Enable fill_n and memset optimizations for construction and assignment

This commit is contained in:
Charles Schlosser
2024-12-14 14:25:04 +00:00
parent af59ada0ac
commit c01ff45312
7 changed files with 105 additions and 22 deletions

View File

@@ -888,6 +888,32 @@ struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Weak> {
}
};
template <typename DstXprType, typename SrcPlainObject, typename Weak>
struct Assignment<DstXprType, CwiseNullaryOp<scalar_constant_op<typename DstXprType::Scalar>, SrcPlainObject>,
assign_op<typename DstXprType::Scalar, typename DstXprType::Scalar>, Dense2Dense, Weak> {
using Scalar = typename DstXprType::Scalar;
using NullaryOp = scalar_constant_op<Scalar>;
using SrcXprType = CwiseNullaryOp<NullaryOp, SrcPlainObject>;
using Functor = assign_op<Scalar, Scalar>;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
const Functor& /*func*/) {
eigen_fill_impl<DstXprType>::run(dst, src);
}
};
template <typename DstXprType, typename SrcPlainObject, typename Weak>
struct Assignment<DstXprType, CwiseNullaryOp<scalar_zero_op<typename DstXprType::Scalar>, SrcPlainObject>,
assign_op<typename DstXprType::Scalar, typename DstXprType::Scalar>, Dense2Dense, Weak> {
using Scalar = typename DstXprType::Scalar;
using NullaryOp = scalar_zero_op<Scalar>;
using SrcXprType = CwiseNullaryOp<NullaryOp, SrcPlainObject>;
using Functor = assign_op<Scalar, Scalar>;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
const Functor& /*func*/) {
eigen_zero_impl<DstXprType>::run(dst, src);
}
};
// Generic assignment through evalTo.
// TODO: not sure we have to keep that one, but it helps porting current code to new evaluator mechanism.
// Note that the last template argument "Weak" is needed to make it possible to perform

View File

@@ -71,6 +71,10 @@ class CwiseNullaryOp : public internal::dense_xpr_base<CwiseNullaryOp<NullaryOp,
eigen_assert(rows >= 0 && (RowsAtCompileTime == Dynamic || RowsAtCompileTime == rows) && cols >= 0 &&
(ColsAtCompileTime == Dynamic || ColsAtCompileTime == cols));
}
EIGEN_DEVICE_FUNC CwiseNullaryOp(Index size, const NullaryOp& func = NullaryOp())
: CwiseNullaryOp(RowsAtCompileTime == 1 ? 1 : size, RowsAtCompileTime == 1 ? size : 1, func) {
EIGEN_STATIC_ASSERT(CwiseNullaryOp::IsVectorAtCompileTime, YOU_TRIED_CALLING_A_VECTOR_METHOD_ON_A_MATRIX);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index rows() const { return m_rows.value(); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index cols() const { return m_cols.value(); }
@@ -480,9 +484,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setEqualSpace
* \sa Zero(), Zero(Index)
*/
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ConstantReturnType DenseBase<Derived>::Zero(
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ZeroReturnType DenseBase<Derived>::Zero(
Index rows, Index cols) {
return Constant(rows, cols, Scalar(0));
return ZeroReturnType(rows, cols);
}
/** \returns an expression of a zero vector.
@@ -502,9 +506,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::Constan
* \sa Zero(), Zero(Index,Index)
*/
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ConstantReturnType DenseBase<Derived>::Zero(
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ZeroReturnType DenseBase<Derived>::Zero(
Index size) {
return Constant(size, Scalar(0));
return ZeroReturnType(size);
}
/** \returns an expression of a fixed-size zero matrix or vector.
@@ -518,8 +522,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::Constan
* \sa Zero(Index), Zero(Index,Index)
*/
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ConstantReturnType DenseBase<Derived>::Zero() {
return Constant(Scalar(0));
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ZeroReturnType DenseBase<Derived>::Zero() {
return ZeroReturnType(RowsAtCompileTime, ColsAtCompileTime);
}
/** \returns true if *this is approximately equal to the zero matrix,

View File

@@ -243,6 +243,8 @@ class DenseBase
#ifndef EIGEN_PARSED_BY_DOXYGEN
/** \internal Represents a matrix with all coefficients equal to one another*/
typedef CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> ConstantReturnType;
/** \internal Represents a matrix with all coefficients equal to zero*/
typedef CwiseNullaryOp<internal::scalar_zero_op<Scalar>, PlainObject> ZeroReturnType;
/** \internal \deprecated Represents a vector with linearly spaced coefficients that allows sequential access only. */
EIGEN_DEPRECATED typedef CwiseNullaryOp<internal::linspaced_op<Scalar>, PlainObject> SequentialLinSpacedReturnType;
/** \internal Represents a vector with linearly spaced coefficients that allows random access. */
@@ -328,9 +330,9 @@ class DenseBase
template <typename CustomNullaryOp>
EIGEN_DEVICE_FUNC static const CwiseNullaryOp<CustomNullaryOp, PlainObject> NullaryExpr(const CustomNullaryOp& func);
EIGEN_DEVICE_FUNC static const ConstantReturnType Zero(Index rows, Index cols);
EIGEN_DEVICE_FUNC static const ConstantReturnType Zero(Index size);
EIGEN_DEVICE_FUNC static const ConstantReturnType Zero();
EIGEN_DEVICE_FUNC static const ZeroReturnType Zero(Index rows, Index cols);
EIGEN_DEVICE_FUNC static const ZeroReturnType Zero(Index size);
EIGEN_DEVICE_FUNC static const ZeroReturnType Zero();
EIGEN_DEVICE_FUNC static const ConstantReturnType Ones(Index rows, Index cols);
EIGEN_DEVICE_FUNC static const ConstantReturnType Ones(Index size);
EIGEN_DEVICE_FUNC static const ConstantReturnType Ones();

View File

@@ -256,10 +256,13 @@ class DiagonalMatrix : public DiagonalBase<DiagonalMatrix<Scalar_, SizeAtCompile
typedef DiagonalWrapper<const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, DiagonalVectorType>>
InitializeReturnType;
typedef DiagonalWrapper<const CwiseNullaryOp<internal::scalar_zero_op<Scalar>, DiagonalVectorType>>
ZeroInitializeReturnType;
/** Initializes a diagonal matrix of size SizeAtCompileTime with coefficients set to zero */
EIGEN_DEVICE_FUNC static const InitializeReturnType Zero() { return DiagonalVectorType::Zero().asDiagonal(); }
EIGEN_DEVICE_FUNC static const ZeroInitializeReturnType Zero() { return DiagonalVectorType::Zero().asDiagonal(); }
/** Initializes a diagonal matrix of size dim with coefficients set to zero */
EIGEN_DEVICE_FUNC static const InitializeReturnType Zero(Index size) {
EIGEN_DEVICE_FUNC static const ZeroInitializeReturnType Zero(Index size) {
return DiagonalVectorType::Zero(size).asDiagonal();
}
/** Initializes a identity matrix of size SizeAtCompileTime */

View File

@@ -54,19 +54,26 @@ template <typename Xpr, int Options, int OuterStride_>
struct eigen_fill_helper<Map<Xpr, Options, OuterStride<OuterStride_>>>
: eigen_fill_helper<Map<Xpr, Options, Stride<OuterStride_, 0>>> {};
template <typename Xpr, bool use_fill = eigen_fill_helper<Xpr>::value>
struct eigen_fill_impl {
template <typename Xpr>
struct eigen_fill_impl<Xpr, /*use_fill*/ false> {
using Scalar = typename Xpr::Scalar;
using Func = scalar_constant_op<Scalar>;
using PlainObject = typename Xpr::PlainObject;
using Constant = CwiseNullaryOp<Func, PlainObject>;
using Constant = typename PlainObject::ConstantReturnType;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const Scalar& val) {
dst = Constant(dst.rows(), dst.cols(), Func(val));
const Constant src(dst.rows(), dst.cols(), val);
run(dst, src);
}
template <typename SrcXpr>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const SrcXpr& src) {
call_dense_assignment_loop(dst, src, assign_op<Scalar, Scalar>());
}
};
#if !EIGEN_COMP_MSVC
#ifndef EIGEN_GPU_COMPILE_PHASE
#if EIGEN_COMP_MSVC || defined(EIGEN_GPU_COMPILE_PHASE)
template <typename Xpr>
struct eigen_fill_impl<Xpr, /*use_fill*/ true> : eigen_fill_impl<Xpr, /*use_fill*/ false> {};
#else
template <typename Xpr>
struct eigen_fill_impl<Xpr, /*use_fill*/ true> {
using Scalar = typename Xpr::Scalar;
@@ -74,19 +81,33 @@ struct eigen_fill_impl<Xpr, /*use_fill*/ true> {
EIGEN_USING_STD(fill_n);
fill_n(dst.data(), dst.size(), val);
}
template <typename SrcXpr>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const SrcXpr& src) {
resize_if_allowed(dst, src, assign_op<Scalar, Scalar>());
const Scalar& val = src.functor()();
run(dst, val);
}
};
#endif
#endif
template <typename Xpr>
struct eigen_memset_helper {
static constexpr bool value = std::is_trivial<typename Xpr::Scalar>::value && eigen_fill_helper<Xpr>::value;
};
template <typename Xpr, bool use_memset = eigen_memset_helper<Xpr>::value>
struct eigen_zero_impl {
template <typename Xpr>
struct eigen_zero_impl<Xpr, /*use_memset*/ false> {
using Scalar = typename Xpr::Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst) { eigen_fill_impl<Xpr, false>::run(dst, Scalar(0)); }
using PlainObject = typename Xpr::PlainObject;
using Zero = typename PlainObject::ZeroReturnType;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst) {
const Zero src(dst.rows(), dst.cols());
run(dst, src);
}
template <typename SrcXpr>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const SrcXpr& src) {
call_dense_assignment_loop(dst, src, assign_op<Scalar, Scalar>());
}
};
template <typename Xpr>
@@ -104,6 +125,11 @@ struct eigen_zero_impl<Xpr, /*use_memset*/ true> {
EIGEN_USING_STD(memset);
memset(dst_ptr, 0, num_bytes);
}
template <typename SrcXpr>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const SrcXpr& src) {
resize_if_allowed(dst, src, assign_op<Scalar, Scalar>());
run(dst);
}
};
} // namespace internal

View File

@@ -19,7 +19,6 @@ namespace internal {
template <typename Scalar>
struct scalar_constant_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const scalar_constant_op& other) : m_other(other.m_other) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const Scalar& other) : m_other(other) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()() const { return m_other; }
template <typename PacketType>
@@ -37,6 +36,18 @@ struct functor_traits<scalar_constant_op<Scalar> > {
};
};
template <typename Scalar>
struct scalar_zero_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_zero_op() = default;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()() const { return Scalar(0); }
template <typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const {
return internal::pzero<PacketType>(PacketType());
}
};
template <typename Scalar>
struct functor_traits<scalar_zero_op<Scalar>> : functor_traits<scalar_constant_op<Scalar>> {};
template <typename Scalar>
struct scalar_identity_op {
template <typename IndexType>

View File

@@ -505,6 +505,17 @@ struct stem_function {
template <typename XprType, typename Device>
struct DeviceWrapper;
namespace internal {
template <typename Xpr>
struct eigen_fill_helper;
template <typename Xpr, bool use_fill = eigen_fill_helper<Xpr>::value>
struct eigen_fill_impl;
template <typename Xpr>
struct eigen_memset_helper;
template <typename Xpr, bool use_memset = eigen_memset_helper<Xpr>::value>
struct eigen_zero_impl;
} // namespace internal
} // end namespace Eigen
#endif // EIGEN_FORWARDDECLARATIONS_H