From 77833f932040b46e79b6fa6522594a6df27e8a76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Thu, 28 Mar 2024 18:43:50 +0000 Subject: [PATCH] Allow symbols to be used in compile-time expressions. --- Eigen/src/Core/ArithmeticSequence.h | 54 +--- Eigen/src/Core/IndexedView.h | 33 +- Eigen/src/Core/util/Constants.h | 4 +- Eigen/src/Core/util/ForwardDeclarations.h | 2 + Eigen/src/Core/util/IndexedViewHelper.h | 360 ++++++++++++++-------- Eigen/src/Core/util/IntegralConstant.h | 35 +-- Eigen/src/Core/util/SymbolicIndex.h | 314 ++++++++++++++----- Eigen/src/plugins/IndexedViewMethods.inc | 145 +++++---- doc/TutorialSlicingIndexing.dox | 6 +- test/indexed_view.cpp | 313 +++++++++++++++++++ 10 files changed, 902 insertions(+), 364 deletions(-) diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h index ae3fac3d6..ae6373dda 100644 --- a/Eigen/src/Core/ArithmeticSequence.h +++ b/Eigen/src/Core/ArithmeticSequence.h @@ -61,26 +61,28 @@ seqN(FirstType first, SizeType size, IncrType incr); template class ArithmeticSequence { public: - ArithmeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {} - ArithmeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {} + constexpr ArithmeticSequence() = default; + constexpr ArithmeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {} + constexpr ArithmeticSequence(FirstType first, SizeType size, IncrType incr) + : m_first(first), m_size(size), m_incr(incr) {} enum { - SizeAtCompileTime = internal::get_fixed_value::value, + // SizeAtCompileTime = internal::get_fixed_value::value, IncrAtCompileTime = internal::get_fixed_value::value }; /** \returns the size, i.e., number of elements, of the sequence */ - Index size() const { return m_size; } + constexpr Index size() const { return m_size; } /** \returns the first element \f$ a_0 \f$ in the sequence */ - Index first() const { return m_first; } + constexpr Index first() const { return m_first; } /** \returns the value \f$ a_i \f$ at index \a i in the sequence. */ - Index operator[](Index i) const { return m_first + i * m_incr; } + constexpr Index operator[](Index i) const { return m_first + i * m_incr; } - const FirstType& firstObject() const { return m_first; } - const SizeType& sizeObject() const { return m_size; } - const IncrType& incrObject() const { return m_incr; } + constexpr const FirstType& firstObject() const { return m_first; } + constexpr const SizeType& sizeObject() const { return m_size; } + constexpr const IncrType& incrObject() const { return m_incr; } protected: FirstType m_first; @@ -88,7 +90,7 @@ class ArithmeticSequence { IncrType m_incr; public: - auto reverse() const -> decltype(Eigen::seqN(m_first + (m_size + fix<-1>()) * m_incr, m_size, -m_incr)) { + constexpr auto reverse() const -> decltype(Eigen::seqN(m_first + (m_size + fix<-1>()) * m_incr, m_size, -m_incr)) { return seqN(m_first + (m_size + fix<-1>()) * m_incr, m_size, -m_incr); } }; @@ -201,38 +203,6 @@ auto lastN(SizeType size) -> decltype(seqN(Eigen::placeholders::last + fix<1>() } // namespace placeholders -namespace internal { - -// Convert a symbolic span into a usable one (i.e., remove last/end "keywords") -template -struct make_size_type { - typedef std::conditional_t::value, Index, T> type; -}; - -template -struct IndexedViewCompatibleType, XprSize> { - typedef ArithmeticSequence::type, IncrType> type; -}; - -template -ArithmeticSequence::type, IncrType> makeIndexedViewCompatible( - const ArithmeticSequence& ids, Index size, SpecializedType) { - return ArithmeticSequence::type, IncrType>( - eval_expr_given_size(ids.firstObject(), size), eval_expr_given_size(ids.sizeObject(), size), ids.incrObject()); -} - -template -struct get_compile_time_incr > { - enum { value = get_fixed_value::value }; -}; - -template -constexpr Index get_runtime_incr(const ArithmeticSequence& x) EIGEN_NOEXCEPT { - return static_cast(x.incrObject()); -} - -} // end namespace internal - /** \namespace Eigen::indexing * \ingroup Core_Module * diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h index b90ecb1e6..454e560e4 100644 --- a/Eigen/src/Core/IndexedView.h +++ b/Eigen/src/Core/IndexedView.h @@ -20,8 +20,8 @@ namespace internal { template struct traits> : traits { enum { - RowsAtCompileTime = int(array_size::value), - ColsAtCompileTime = int(array_size::value), + RowsAtCompileTime = int(IndexedViewHelper::SizeAtCompileTime), + ColsAtCompileTime = int(IndexedViewHelper::SizeAtCompileTime), MaxRowsAtCompileTime = RowsAtCompileTime, MaxColsAtCompileTime = ColsAtCompileTime, @@ -30,8 +30,8 @@ struct traits> : traits { : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0 : XprTypeIsRowMajor, - RowIncr = int(get_compile_time_incr::value), - ColIncr = int(get_compile_time_incr::value), + RowIncr = int(IndexedViewHelper::IncrAtCompileTime), + ColIncr = int(IndexedViewHelper::IncrAtCompileTime), InnerIncr = IsRowMajor ? ColIncr : RowIncr, OuterIncr = IsRowMajor ? RowIncr : ColIncr, @@ -47,24 +47,23 @@ struct traits> : traits { is_same, std::conditional_t>::value, InnerStrideAtCompileTime = - InnerIncr < 0 || InnerIncr == DynamicIndex || XprInnerStride == Dynamic || InnerIncr == UndefinedIncr + InnerIncr < 0 || InnerIncr == DynamicIndex || XprInnerStride == Dynamic || InnerIncr == Undefined ? Dynamic : XprInnerStride * InnerIncr, OuterStrideAtCompileTime = - OuterIncr < 0 || OuterIncr == DynamicIndex || XprOuterstride == Dynamic || OuterIncr == UndefinedIncr + OuterIncr < 0 || OuterIncr == DynamicIndex || XprOuterstride == Dynamic || OuterIncr == Undefined ? Dynamic : XprOuterstride * OuterIncr, - ReturnAsScalar = is_same::value && is_same::value, + ReturnAsScalar = is_single_range::value && is_single_range::value, ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike, ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock), // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag, // but this is too strict regarding negative strides... - DirectAccessMask = - (int(InnerIncr) != UndefinedIncr && int(OuterIncr) != UndefinedIncr && InnerIncr >= 0 && OuterIncr >= 0) - ? DirectAccessBit - : 0, + DirectAccessMask = (int(InnerIncr) != Undefined && int(OuterIncr) != Undefined && InnerIncr >= 0 && OuterIncr >= 0) + ? DirectAccessBit + : 0, FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, FlagsLvalueBit = is_lvalue::value ? LvalueBit : 0, FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0, @@ -153,10 +152,10 @@ class IndexedViewImpl : public internal::generic_xpr_base::size(m_rowIndices); } /** \returns number of columns */ - Index cols() const { return internal::index_list_size(m_colIndices); } + Index cols() const { return IndexedViewHelper::size(m_colIndices); } /** \returns the nested expression */ const internal::remove_all_t& nestedExpression() const { return m_xpr; } @@ -198,16 +197,16 @@ class IndexedViewImpl IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {} Index rowIncrement() const { - if (traits::RowIncr != DynamicIndex && traits::RowIncr != UndefinedIncr) { + if (traits::RowIncr != DynamicIndex && traits::RowIncr != Undefined) { return traits::RowIncr; } - return get_runtime_incr(this->rowIndices()); + return IndexedViewHelper::incr(this->rowIndices()); } Index colIncrement() const { - if (traits::ColIncr != DynamicIndex && traits::ColIncr != UndefinedIncr) { + if (traits::ColIncr != DynamicIndex && traits::ColIncr != Undefined) { return traits::ColIncr; } - return get_runtime_incr(this->colIndices()); + return IndexedViewHelper::incr(this->colIndices()); } Index innerIncrement() const { return traits::IsRowMajor ? colIncrement() : rowIncrement(); } diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 8b06c676b..9f4a2d8ef 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -29,9 +29,9 @@ const int Dynamic = -1; */ const int DynamicIndex = 0xffffff; -/** This value means that the increment to go from one value to another in a sequence is not constant for each step. +/** This value means that the requested value is not defined. */ -const int UndefinedIncr = 0xfffffe; +const int Undefined = 0xfffffe; /** This value means +Infinity; it is currently used only as the p parameter to MatrixBase::lpNorm(). * The value Infinity there means the L-infinity norm. diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c312939ca..2f2ba9b20 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -91,6 +91,8 @@ template class IndexedView; template class Reshaped; +template +class ArithmeticSequence; template class VectorBlock; diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h index 9d1b34833..c1870024a 100644 --- a/Eigen/src/Core/util/IndexedViewHelper.h +++ b/Eigen/src/Core/util/IndexedViewHelper.h @@ -17,6 +17,9 @@ namespace Eigen { namespace internal { struct symbolic_last_tag {}; + +struct all_t {}; + } // namespace internal namespace placeholders { @@ -42,131 +45,7 @@ typedef symbolic::SymbolExpr last_t; * * \sa end */ -static const last_t last; - -} // namespace placeholders - -namespace internal { - -// Replace symbolic last/end "keywords" by their true runtime value -inline Index eval_expr_given_size(Index x, Index /* size */) { return x; } - -template -FixedInt eval_expr_given_size(FixedInt x, Index /*size*/) { - return x; -} - -template -Index eval_expr_given_size(const symbolic::BaseExpr& x, Index size) { - return x.derived().eval(Eigen::placeholders::last = size - 1); -} - -// Extract increment/step at compile time -template -struct get_compile_time_incr { - enum { value = UndefinedIncr }; -}; - -template -constexpr Index get_runtime_incr(const T&) EIGEN_NOEXCEPT { - return Index(1); -} - -// Analogue of std::get<0>(x), but tailored for our needs. -template -EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT { - return x.first(); -} - -// IndexedViewCompatibleType/makeIndexedViewCompatible turn an arbitrary object of type T into something usable by -// MatrixSlice The generic implementation is a no-op -template -struct IndexedViewCompatibleType { - typedef T type; -}; - -template -const T& makeIndexedViewCompatible(const T& x, Index /*size*/, Q) { - return x; -} - -//-------------------------------------------------------------------------------- -// Handling of a single Index -//-------------------------------------------------------------------------------- - -struct SingleRange { - enum { SizeAtCompileTime = 1 }; - SingleRange(Index val) : m_value(val) {} - Index operator[](Index) const { return m_value; } - static EIGEN_CONSTEXPR Index size() EIGEN_NOEXCEPT { return 1; } - Index first() const EIGEN_NOEXCEPT { return m_value; } - Index m_value; -}; - -template <> -struct get_compile_time_incr { - enum { value = 1 }; // 1 or 0 ?? -}; - -// Turn a single index into something that looks like an array (i.e., that exposes a .size(), and operator[](int) -// methods) -template -struct IndexedViewCompatibleType::value>> { - // Here we could simply use Array, but maybe it's less work for the compiler to use - // a simpler wrapper as SingleRange - // typedef Eigen::Array type; - typedef SingleRange type; -}; - -template -struct IndexedViewCompatibleType::value>> { - typedef SingleRange type; -}; - -template -std::enable_if_t::value, SingleRange> makeIndexedViewCompatible(const T& id, Index size, - SpecializedType) { - return eval_expr_given_size(id, size); -} - -//-------------------------------------------------------------------------------- -// Handling of all -//-------------------------------------------------------------------------------- - -struct all_t { - all_t() {} -}; - -// Convert a symbolic 'all' into a usable range type -template -struct AllRange { - enum { SizeAtCompileTime = XprSize }; - AllRange(Index size = XprSize) : m_size(size) {} - EIGEN_CONSTEXPR Index operator[](Index i) const EIGEN_NOEXCEPT { return i; } - EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_size.value(); } - EIGEN_CONSTEXPR Index first() const EIGEN_NOEXCEPT { return 0; } - variable_if_dynamic m_size; -}; - -template -struct IndexedViewCompatibleType { - typedef AllRange type; -}; - -template -inline AllRange::value> makeIndexedViewCompatible(all_t, XprSizeType size, - SpecializedType) { - return AllRange::value>(size); -} - -template -struct get_compile_time_incr> { - enum { value = 1 }; -}; - -} // end namespace internal - -namespace placeholders { +static constexpr const last_t last; typedef symbolic::AddExpr, symbolic::ValueExpr>> @@ -186,28 +65,251 @@ typedef Eigen::internal::all_t all_t; * \sa last */ #ifdef EIGEN_PARSED_BY_DOXYGEN -static const auto lastp1 = last + fix<1>; +static constexpr auto lastp1 = last + fix<1>; #else // Using a FixedExpr<1> expression is important here to make sure the compiler // can fully optimize the computation starting indices with zero overhead. -static const lastp1_t lastp1(last + fix<1>()); +static constexpr lastp1_t lastp1(last + fix<1>()); #endif /** \var end * \ingroup Core_Module * \sa lastp1 */ -static const lastp1_t end = lastp1; +static constexpr lastp1_t end = lastp1; /** \var all * \ingroup Core_Module * Can be used as a parameter to DenseBase::operator()(const RowIndices&, const ColIndices&) to index all rows or * columns */ -static const Eigen::internal::all_t all; +static constexpr Eigen::internal::all_t all; } // namespace placeholders +namespace internal { + +// Evaluate a symbolic expression or constant given the "size" of an object, allowing +// any symbols like `last` to be evaluated. The default here assumes a dynamic constant. +template +struct SymbolicExpressionEvaluator { + static constexpr Index ValueAtCompileTime = Undefined; + static Index eval(const Expr& expr, Index /*size*/) { return static_cast(expr); } +}; + +// Symbolic expression with size known at compile-time. +template +struct SymbolicExpressionEvaluator::value>> { + static constexpr Index ValueAtCompileTime = + Expr::Derived::eval_at_compile_time(Eigen::placeholders::last = fix); + static Index eval(const Expr& expr, Index /*size*/) { + return expr.eval(Eigen::placeholders::last = fix); + } +}; + +// Symbolic expression with dynamic size. +template +struct SymbolicExpressionEvaluator::value>> { + static constexpr Index ValueAtCompileTime = Undefined; + static Index eval(const Expr& expr, Index size) { return expr.eval(Eigen::placeholders::last = size - 1); } +}; + +// Fixed int. +template +struct SymbolicExpressionEvaluator, SizeAtCompileTime, void> { + static constexpr Index ValueAtCompileTime = static_cast(N); + static Index eval(const FixedInt& /*expr*/, Index /*size*/) { return ValueAtCompileTime; } +}; + +//-------------------------------------------------------------------------------- +// Handling of generic indices (e.g. array) +//-------------------------------------------------------------------------------- + +// Potentially wrap indices in a type that is better-suited for IndexedView evaluation. +template +struct IndexedViewHelperIndicesWrapper { + using type = Indices; + static const type& CreateIndexSequence(const Indices& indices, Index /*nested_size*/) { return indices; } +}; + +// Extract compile-time and runtime first, size, increments. +template +struct IndexedViewHelper { + static constexpr Index FirstAtCompileTime = Undefined; + static constexpr Index SizeAtCompileTime = array_size::value; + static constexpr Index IncrAtCompileTime = Undefined; + + static constexpr Index first(const Indices& indices) { return static_cast(indices[0]); } + static constexpr Index size(const Indices& indices) { return index_list_size(indices); } + static constexpr Index incr(const Indices& /*indices*/) { return Undefined; } +}; + +//-------------------------------------------------------------------------------- +// Handling of ArithmeticSequence +//-------------------------------------------------------------------------------- + +template +class ArithmeticSequenceRange { + public: + static constexpr Index FirstAtCompileTime = FirstAtCompileTime_; + static constexpr Index SizeAtCompileTime = SizeAtCompileTime_; + static constexpr Index IncrAtCompileTime = IncrAtCompileTime_; + + constexpr ArithmeticSequenceRange(Index first, Index size, Index incr) : first_{first}, size_{size}, incr_{incr} {} + constexpr Index operator[](Index i) const { return first() + i * incr(); } + constexpr Index first() const noexcept { return first_.value(); } + constexpr Index size() const noexcept { return size_.value(); } + constexpr Index incr() const noexcept { return incr_.value(); } + + private: + variable_if_dynamicindex first_; + variable_if_dynamic size_; + variable_if_dynamicindex incr_; +}; + +template +struct IndexedViewHelperIndicesWrapper, NestedSizeAtCompileTime, + void> { + static constexpr Index EvalFirstAtCompileTime = + SymbolicExpressionEvaluator::ValueAtCompileTime; + static constexpr Index EvalSizeAtCompileTime = + SymbolicExpressionEvaluator::ValueAtCompileTime; + static constexpr Index EvalIncrAtCompileTime = + SymbolicExpressionEvaluator::ValueAtCompileTime; + + static constexpr Index FirstAtCompileTime = + (int(EvalFirstAtCompileTime) == Undefined) ? Index(DynamicIndex) : EvalFirstAtCompileTime; + static constexpr Index SizeAtCompileTime = + (int(EvalSizeAtCompileTime) == Undefined) ? Index(Dynamic) : EvalSizeAtCompileTime; + static constexpr Index IncrAtCompileTime = + (int(EvalIncrAtCompileTime) == Undefined) ? Index(DynamicIndex) : EvalIncrAtCompileTime; + + using Indices = ArithmeticSequence; + using type = ArithmeticSequenceRange; + + static type CreateIndexSequence(const Indices& indices, Index nested_size) { + Index first = + SymbolicExpressionEvaluator::eval(indices.firstObject(), nested_size); + Index size = + SymbolicExpressionEvaluator::eval(indices.sizeObject(), nested_size); + Index incr = + SymbolicExpressionEvaluator::eval(indices.incrObject(), nested_size); + return type(first, size, incr); + } +}; + +template +struct IndexedViewHelper, void> { + public: + using Indices = ArithmeticSequenceRange; + static constexpr Index FirstAtCompileTime = Indices::FirstAtCompileTime; + static constexpr Index SizeAtCompileTime = Indices::SizeAtCompileTime; + static constexpr Index IncrAtCompileTime = Indices::IncrAtCompileTime; + static Index first(const Indices& indices) { return indices.first(); } + static Index size(const Indices& indices) { return indices.size(); } + static Index incr(const Indices& indices) { return indices.incr(); } +}; + +//-------------------------------------------------------------------------------- +// Handling of a single index. +//-------------------------------------------------------------------------------- + +template +class SingleRange { + public: + static constexpr Index FirstAtCompileTime = ValueAtCompileTime; + static constexpr Index SizeAtCompileTime = Index(1); + static constexpr Index IncrAtCompileTime = Index(1); // Needs to be 1 to be treated as block-like. + + constexpr SingleRange(Index v) noexcept : value_(v) {} + constexpr Index operator[](Index) const noexcept { return first(); } + constexpr Index first() const noexcept { return value_.value(); } + constexpr Index size() const noexcept { return SizeAtCompileTime; } + constexpr Index incr() const noexcept { return IncrAtCompileTime; } + + private: + variable_if_dynamicindex value_; +}; + +template +struct is_single_range : public std::false_type {}; + +template +struct is_single_range> : public std::true_type {}; + +template +struct IndexedViewHelperIndicesWrapper< + SingleIndex, NestedSizeAtCompileTime, + std::enable_if_t::value || symbolic::is_symbolic::value>> { + static constexpr Index EvalValueAtCompileTime = + SymbolicExpressionEvaluator::ValueAtCompileTime; + static constexpr Index ValueAtCompileTime = + (int(EvalValueAtCompileTime) == Undefined) ? Index(DynamicIndex) : EvalValueAtCompileTime; + using type = SingleRange; + static type CreateIndexSequence(const SingleIndex& index, Index nested_size) { + return type(SymbolicExpressionEvaluator::eval(index, nested_size)); + } +}; + +template +struct IndexedViewHelperIndicesWrapper, NestedSizeAtCompileTime, void> { + using type = SingleRange; + static type CreateIndexSequence(const FixedInt& /*index*/) { return type(Index(N)); } +}; + +template +struct IndexedViewHelper, void> { + using Indices = SingleRange; + static constexpr Index FirstAtCompileTime = Indices::FirstAtCompileTime; + static constexpr Index SizeAtCompileTime = Indices::SizeAtCompileTime; + static constexpr Index IncrAtCompileTime = Indices::IncrAtCompileTime; + + static constexpr Index first(const Indices& indices) { return indices.first(); } + static constexpr Index size(const Indices& /*indices*/) { return SizeAtCompileTime; } + static constexpr Index incr(const Indices& /*indices*/) { return IncrAtCompileTime; } +}; + +//-------------------------------------------------------------------------------- +// Handling of all +//-------------------------------------------------------------------------------- + +// Convert a symbolic 'all' into a usable range type +template +class AllRange { + public: + static constexpr Index FirstAtCompileTime = Index(0); + static constexpr Index SizeAtCompileTime = SizeAtCompileTime_; + static constexpr Index IncrAtCompileTime = Index(1); + constexpr AllRange(Index size) : size_(size) {} + constexpr Index operator[](Index i) const noexcept { return i; } + constexpr Index first() const noexcept { return FirstAtCompileTime; } + constexpr Index size() const noexcept { return size_.value(); } + constexpr Index incr() const noexcept { return IncrAtCompileTime; } + + private: + variable_if_dynamic size_; +}; + +template +struct IndexedViewHelperIndicesWrapper { + using type = AllRange; + static type CreateIndexSequence(const all_t& /*indices*/, Index nested_size) { return type(nested_size); } +}; + +template +struct IndexedViewHelper, void> { + using Indices = AllRange; + static constexpr Index FirstAtCompileTime = Indices::FirstAtCompileTime; + static constexpr Index SizeAtCompileTime = Indices::SizeAtCompileTime; + static constexpr Index IncrAtCompileTime = Indices::IncrAtCompileTime; + + static Index first(const Indices& indices) { return indices.first(); } + static Index size(const Indices& indices) { return indices.size(); } + static Index incr(const Indices& indices) { return indices.incr(); } +}; + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_INDEXED_VIEW_HELPER_H diff --git a/Eigen/src/Core/util/IntegralConstant.h b/Eigen/src/Core/util/IntegralConstant.h index 279d553d9..2eb5fd9d0 100644 --- a/Eigen/src/Core/util/IntegralConstant.h +++ b/Eigen/src/Core/util/IntegralConstant.h @@ -54,65 +54,60 @@ class VariableAndFixedInt; template class FixedInt { public: - static const int value = N; - EIGEN_CONSTEXPR operator int() const { return value; } + static constexpr int value = N; + constexpr operator int() const { return N; } - EIGEN_CONSTEXPR - FixedInt() = default; + constexpr FixedInt() = default; + constexpr FixedInt(std::integral_constant) {} - EIGEN_CONSTEXPR - FixedInt(std::integral_constant) {} - - EIGEN_CONSTEXPR - FixedInt(VariableAndFixedInt other) { + constexpr FixedInt(VariableAndFixedInt other) { #ifndef EIGEN_INTERNAL_DEBUGGING EIGEN_UNUSED_VARIABLE(other); #endif eigen_internal_assert(int(other) == N); } - EIGEN_CONSTEXPR - FixedInt<-N> operator-() const { return FixedInt<-N>(); } + constexpr FixedInt<-N> operator-() const { return FixedInt<-N>(); } template - EIGEN_CONSTEXPR FixedInt operator+(FixedInt) const { + constexpr FixedInt operator+(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator-(FixedInt) const { + constexpr FixedInt operator-(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator*(FixedInt) const { + constexpr FixedInt operator*(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator/(FixedInt) const { + constexpr FixedInt operator/(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator%(FixedInt) const { + constexpr FixedInt operator%(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator|(FixedInt) const { + constexpr FixedInt operator|(FixedInt) const { return FixedInt(); } template - EIGEN_CONSTEXPR FixedInt operator&(FixedInt) const { + constexpr FixedInt operator&(FixedInt) const { return FixedInt(); } // Needed in C++14 to allow fix(): - EIGEN_CONSTEXPR FixedInt operator()() const { return *this; } + constexpr FixedInt operator()() const { return *this; } - VariableAndFixedInt operator()(int val) const { return VariableAndFixedInt(val); } + constexpr VariableAndFixedInt operator()(int val) const { return VariableAndFixedInt(val); } }; /** \internal diff --git a/Eigen/src/Core/util/SymbolicIndex.h b/Eigen/src/Core/util/SymbolicIndex.h index 136942c35..befb485e8 100644 --- a/Eigen/src/Core/util/SymbolicIndex.h +++ b/Eigen/src/Core/util/SymbolicIndex.h @@ -44,6 +44,8 @@ namespace symbolic { template class Symbol; +template +class SymbolValue; template class NegateExpr; template @@ -52,136 +54,123 @@ template class ProductExpr; template class QuotientExpr; - -// A simple wrapper around an integral value to provide the eval method. -// We could also use a free-function symbolic_eval... template -class ValueExpr { - public: - ValueExpr(IndexType val) : m_value(val) {} - template - IndexType eval_impl(const T&) const { - return m_value; - } - - protected: - IndexType m_value; -}; - -// Specialization for compile-time value, -// It is similar to ValueExpr(N) but this version helps the compiler to generate better code. -template -class ValueExpr > { - public: - ValueExpr() {} - template - EIGEN_CONSTEXPR Index eval_impl(const T&) const { - return N; - } -}; +class ValueExpr; /** \class BaseExpr * \ingroup Core_Module * Common base class of any symbolic expressions */ -template +template class BaseExpr { public: - const Derived& derived() const { return *static_cast(this); } + using Derived = Derived_; + constexpr const Derived& derived() const { return *static_cast(this); } /** Evaluate the expression given the \a values of the symbols. * - * \param values defines the values of the symbols, it can either be a SymbolValue or a std::tuple of SymbolValue - * as constructed by SymbolExpr::operator= operator. + * \param values defines the values of the symbols, as constructed by SymbolExpr::operator= operator. * */ - template - Index eval(const T& values) const { - return derived().eval_impl(values); + template + constexpr Index eval(const SymbolValue&... values) const { + return derived().eval_impl(values...); } - template - Index eval(Types&&... values) const { - return derived().eval_impl(std::make_tuple(values...)); + /** Evaluate the expression at compile time given the \a values of the symbols. + * + * If a value is not known at compile-time, returns Eigen::Undefined. + * + */ + template + static constexpr Index eval_at_compile_time(const SymbolValue&...) { + return Derived::eval_at_compile_time_impl(SymbolValue{}...); } - NegateExpr operator-() const { return NegateExpr(derived()); } + constexpr NegateExpr operator-() const { return NegateExpr(derived()); } - AddExpr > operator+(Index b) const { return AddExpr >(derived(), b); } - AddExpr > operator-(Index a) const { return AddExpr >(derived(), -a); } - ProductExpr > operator*(Index a) const { + constexpr AddExpr> operator+(Index b) const { + return AddExpr>(derived(), b); + } + constexpr AddExpr> operator-(Index a) const { + return AddExpr>(derived(), -a); + } + constexpr ProductExpr> operator*(Index a) const { return ProductExpr >(derived(), a); } - QuotientExpr > operator/(Index a) const { + constexpr QuotientExpr> operator/(Index a) const { return QuotientExpr >(derived(), a); } - friend AddExpr > operator+(Index a, const BaseExpr& b) { + friend constexpr AddExpr> operator+(Index a, const BaseExpr& b) { return AddExpr >(b.derived(), a); } - friend AddExpr, ValueExpr<> > operator-(Index a, const BaseExpr& b) { + friend constexpr AddExpr, ValueExpr<>> operator-(Index a, const BaseExpr& b) { return AddExpr, ValueExpr<> >(-b.derived(), a); } - friend ProductExpr, Derived> operator*(Index a, const BaseExpr& b) { + friend constexpr ProductExpr, Derived> operator*(Index a, const BaseExpr& b) { return ProductExpr, Derived>(a, b.derived()); } - friend QuotientExpr, Derived> operator/(Index a, const BaseExpr& b) { + friend constexpr QuotientExpr, Derived> operator/(Index a, const BaseExpr& b) { return QuotientExpr, Derived>(a, b.derived()); } template - AddExpr > > operator+(internal::FixedInt) const { + constexpr AddExpr>> operator+(internal::FixedInt) const { return AddExpr > >(derived(), ValueExpr >()); } template - AddExpr > > operator-(internal::FixedInt) const { + constexpr AddExpr>> operator-(internal::FixedInt) const { return AddExpr > >(derived(), ValueExpr >()); } template - ProductExpr > > operator*(internal::FixedInt) const { + constexpr ProductExpr>> operator*(internal::FixedInt) const { return ProductExpr > >(derived(), ValueExpr >()); } template - QuotientExpr > > operator/(internal::FixedInt) const { + constexpr QuotientExpr>> operator/(internal::FixedInt) const { return QuotientExpr > >(derived(), ValueExpr >()); } template - friend AddExpr > > operator+(internal::FixedInt, const BaseExpr& b) { + friend constexpr AddExpr>> operator+(internal::FixedInt, + const BaseExpr& b) { return AddExpr > >(b.derived(), ValueExpr >()); } template - friend AddExpr, ValueExpr > > operator-(internal::FixedInt, - const BaseExpr& b) { + friend constexpr AddExpr, ValueExpr>> operator-(internal::FixedInt, + const BaseExpr& b) { return AddExpr, ValueExpr > >(-b.derived(), ValueExpr >()); } template - friend ProductExpr >, Derived> operator*(internal::FixedInt, const BaseExpr& b) { + friend constexpr ProductExpr>, Derived> operator*(internal::FixedInt, + const BaseExpr& b) { return ProductExpr >, Derived>(ValueExpr >(), b.derived()); } template - friend QuotientExpr >, Derived> operator/(internal::FixedInt, const BaseExpr& b) { + friend constexpr QuotientExpr>, Derived> operator/(internal::FixedInt, + const BaseExpr& b) { return QuotientExpr >, Derived>(ValueExpr >(), b.derived()); } template - AddExpr operator+(const BaseExpr& b) const { + constexpr AddExpr operator+(const BaseExpr& b) const { return AddExpr(derived(), b.derived()); } template - AddExpr > operator-(const BaseExpr& b) const { + constexpr AddExpr> operator-(const BaseExpr& b) const { return AddExpr >(derived(), -b.derived()); } template - ProductExpr operator*(const BaseExpr& b) const { + constexpr ProductExpr operator*(const BaseExpr& b) const { return ProductExpr(derived(), b.derived()); } template - QuotientExpr operator/(const BaseExpr& b) const { + constexpr QuotientExpr operator/(const BaseExpr& b) const { return QuotientExpr(derived(), b.derived()); } }; @@ -193,21 +182,137 @@ struct is_symbolic { enum { value = internal::is_convertible >::value }; }; +// A simple wrapper around an integral value to provide the eval method. +// We could also use a free-function symbolic_eval... +template +class ValueExpr : BaseExpr> { + public: + constexpr ValueExpr() = default; + constexpr ValueExpr(IndexType val) : value_(val) {} + template + constexpr IndexType eval_impl(const SymbolValue&...) const { + return value_; + } + template + static constexpr IndexType eval_at_compile_time_impl(const SymbolValue&...) { + return IndexType(Undefined); + } + + protected: + IndexType value_; +}; + +// Specialization for compile-time value, +// It is similar to ValueExpr(N) but this version helps the compiler to generate better code. +template +class ValueExpr> : public BaseExpr>> { + public: + constexpr ValueExpr() = default; + constexpr ValueExpr(internal::FixedInt) {} + template + constexpr Index eval_impl(const SymbolValue&...) const { + return Index(N); + } + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + return Index(N); + } +}; + /** Represents the actual value of a symbol identified by its tag * * It is the return type of SymbolValue::operator=, and most of the time this is only way it is used. */ +template +class SymbolValue : public BaseExpr> {}; + template -class SymbolValue { +class SymbolValue : public BaseExpr> { public: + constexpr SymbolValue() = default; + /** Default constructor from the value \a val */ - SymbolValue(Index val) : m_value(val) {} + constexpr SymbolValue(Index val) : value_(val) {} /** \returns the stored value of the symbol */ - Index value() const { return m_value; } + constexpr Index value() const { return value_; } + + /** \returns the stored value of the symbol at compile time, or Undefined if not known. */ + static constexpr Index value_at_compile_time() { return Index(Undefined); } + + template + constexpr Index eval_impl(const SymbolValue&...) const { + return value(); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + return value_at_compile_time(); + } protected: - Index m_value; + Index value_; +}; + +template +class SymbolValue> : public BaseExpr>> { + public: + constexpr SymbolValue() = default; + + /** Default constructor from the value \a val */ + constexpr SymbolValue(internal::FixedInt){}; + + /** \returns the stored value of the symbol */ + constexpr Index value() const { return static_cast(N); } + + /** \returns the stored value of the symbol at compile time, or Undefined if not known. */ + static constexpr Index value_at_compile_time() { return static_cast(N); } + + template + constexpr Index eval_impl(const SymbolValue&...) const { + return value(); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + return value_at_compile_time(); + } +}; + +// Find and return a symbol value based on the tag. +template +struct EvalSymbolValueHelper; + +// Empty base case, symbol not found. +template +struct EvalSymbolValueHelper { + static constexpr Index eval_impl() { + eigen_assert(false && "Symbol not found."); + return Index(Undefined); + } + static constexpr Index eval_at_compile_time_impl() { return Index(Undefined); } +}; + +// We found a symbol value matching the provided Tag! +template +struct EvalSymbolValueHelper, OtherTypes...> { + static constexpr Index eval_impl(const SymbolValue& symbol, const OtherTypes&...) { + return symbol.value(); + } + static constexpr Index eval_at_compile_time_impl(const SymbolValue& symbol, const OtherTypes&...) { + return symbol.value_at_compile_time(); + } +}; + +// No symbol value in first value, recursive search starting with next. +template +struct EvalSymbolValueHelper { + static constexpr Index eval_impl(const T1&, const OtherTypes&... values) { + return EvalSymbolValueHelper::eval_impl(values...); + } + static constexpr Index eval_at_compile_time_impl(const T1&, const OtherTypes&...) { + return EvalSymbolValueHelper::eval_at_compile_time_impl(OtherTypes{}...); + } }; /** Expression of a symbol uniquely identified by the template parameter type \c tag */ @@ -217,32 +322,47 @@ class SymbolExpr : public BaseExpr > { /** Alias to the template parameter \c tag */ typedef tag Tag; - SymbolExpr() {} + constexpr SymbolExpr() = default; /** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag. * * The returned object should be passed to ExprBase::eval() to evaluate a given expression with this specified * runtime-time value. */ - SymbolValue operator=(Index val) const { return SymbolValue(val); } + constexpr SymbolValue operator=(Index val) const { return SymbolValue(val); } - Index eval_impl(const SymbolValue& values) const { return values.value(); } + template + constexpr SymbolValue> operator=(internal::FixedInt) const { + return SymbolValue>{internal::FixedInt{}}; + } - // C++14 versions suitable for multiple symbols - template - Index eval_impl(const std::tuple& values) const { - return std::get >(values).value(); + template + constexpr Index eval_impl(const SymbolValue&... values) const { + return EvalSymbolValueHelper...>::eval_impl(values...); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + return EvalSymbolValueHelper...>::eval_at_compile_time_impl( + SymbolValue{}...); } }; template class NegateExpr : public BaseExpr > { public: - NegateExpr(const Arg0& arg0) : m_arg0(arg0) {} + constexpr NegateExpr() = default; + constexpr NegateExpr(const Arg0& arg0) : m_arg0(arg0) {} - template - Index eval_impl(const T& values) const { - return -m_arg0.eval_impl(values); + template + constexpr Index eval_impl(const SymbolValue&... values) const { + return -m_arg0.eval_impl(values...); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + constexpr Index v = Arg0::eval_at_compile_time_impl(SymbolValue{}...); + return (v == Undefined) ? Undefined : -v; } protected: @@ -252,11 +372,19 @@ class NegateExpr : public BaseExpr > { template class AddExpr : public BaseExpr > { public: - AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + constexpr AddExpr() = default; + constexpr AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} - template - Index eval_impl(const T& values) const { - return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); + template + constexpr Index eval_impl(const SymbolValue&... values) const { + return m_arg0.eval_impl(values...) + m_arg1.eval_impl(values...); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue{}...); + constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue{}...); + return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 + v1; } protected: @@ -267,11 +395,19 @@ class AddExpr : public BaseExpr > { template class ProductExpr : public BaseExpr > { public: - ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + constexpr ProductExpr() = default; + constexpr ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} - template - Index eval_impl(const T& values) const { - return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); + template + constexpr Index eval_impl(const SymbolValue&... values) const { + return m_arg0.eval_impl(values...) * m_arg1.eval_impl(values...); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue{}...); + constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue{}...); + return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 * v1; } protected: @@ -282,11 +418,19 @@ class ProductExpr : public BaseExpr > { template class QuotientExpr : public BaseExpr > { public: - QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + constexpr QuotientExpr() = default; + constexpr QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} - template - Index eval_impl(const T& values) const { - return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); + template + constexpr Index eval_impl(const SymbolValue&... values) const { + return m_arg0.eval_impl(values...) / m_arg1.eval_impl(values...); + } + + template + static constexpr Index eval_at_compile_time_impl(const SymbolValue&...) { + constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue{}...); + constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue{}...); + return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 / v1; } protected: diff --git a/Eigen/src/plugins/IndexedViewMethods.inc b/Eigen/src/plugins/IndexedViewMethods.inc index 26e7b5fc1..c3df42971 100644 --- a/Eigen/src/plugins/IndexedViewMethods.inc +++ b/Eigen/src/plugins/IndexedViewMethods.inc @@ -9,51 +9,47 @@ #if !defined(EIGEN_PARSED_BY_DOXYGEN) -protected: +public: // define some aliases to ease readability template -using IvcRowType = typename internal::IndexedViewCompatibleType::type; +using IvcRowType = typename internal::IndexedViewHelperIndicesWrapper::type; template -using IvcColType = typename internal::IndexedViewCompatibleType::type; +using IvcColType = typename internal::IndexedViewHelperIndicesWrapper::type; template -using IvcType = typename internal::IndexedViewCompatibleType::type; - -typedef typename internal::IndexedViewCompatibleType::type IvcIndex; +using IvcSizeType = typename internal::IndexedViewHelperIndicesWrapper::type; template inline IvcRowType ivcRow(const Indices& indices) const { - return internal::makeIndexedViewCompatible( - indices, internal::variable_if_dynamic(derived().rows()), Specialized); + return internal::IndexedViewHelperIndicesWrapper::CreateIndexSequence(indices, + derived().rows()); } template inline IvcColType ivcCol(const Indices& indices) const { - return internal::makeIndexedViewCompatible( - indices, internal::variable_if_dynamic(derived().cols()), Specialized); + return internal::IndexedViewHelperIndicesWrapper::CreateIndexSequence(indices, + derived().cols()); } template -inline IvcType ivcSize(const Indices& indices) const { - return internal::makeIndexedViewCompatible( - indices, internal::variable_if_dynamic(derived().size()), Specialized); +inline IvcSizeType ivcSize(const Indices& indices) const { + return internal::IndexedViewHelperIndicesWrapper::CreateIndexSequence(indices, + derived().size()); + ; } // this helper class assumes internal::valid_indexed_view_overload::value == true -template , IvcColType>>::ReturnAsScalar, - bool UseBlock = - internal::traits, IvcColType>>::ReturnAsBlock, - bool UseGeneric = internal::traits< - IndexedView, IvcColType>>::ReturnAsIndexedView> +template struct IndexedViewSelector; // Generic template -struct IndexedViewSelector { +struct IndexedViewSelector< + RowIndices, ColIndices, + std::enable_if_t< + internal::traits, IvcColType>>::ReturnAsIndexedView>> { using ReturnType = IndexedView, IvcColType>; using ConstReturnType = IndexedView, IvcColType>; @@ -68,60 +64,73 @@ struct IndexedViewSelector { // Block template -struct IndexedViewSelector { - using IndexedViewType = IndexedView, IvcColType>; - using ConstIndexedViewType = IndexedView, IvcColType>; +struct IndexedViewSelector, IvcColType>>::ReturnAsBlock>> { + using ActualRowIndices = IvcRowType; + using ActualColIndices = IvcColType; + using IndexedViewType = IndexedView; + using ConstIndexedViewType = IndexedView; using ReturnType = typename internal::traits::BlockType; using ConstReturnType = typename internal::traits::BlockType; + using RowHelper = internal::IndexedViewHelper; + using ColHelper = internal::IndexedViewHelper; static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { - IvcRowType actualRowIndices = derived.ivcRow(rowIndices); - IvcColType actualColIndices = derived.ivcCol(colIndices); - return ReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices), - internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); + auto actualRowIndices = derived.ivcRow(rowIndices); + auto actualColIndices = derived.ivcCol(colIndices); + return ReturnType(derived, RowHelper::first(actualRowIndices), ColHelper::first(actualColIndices), + RowHelper::size(actualRowIndices), ColHelper::size(actualColIndices)); } static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { - IvcRowType actualRowIndices = derived.ivcRow(rowIndices); - IvcColType actualColIndices = derived.ivcCol(colIndices); - return ConstReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices), - internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); + auto actualRowIndices = derived.ivcRow(rowIndices); + auto actualColIndices = derived.ivcCol(colIndices); + return ConstReturnType(derived, RowHelper::first(actualRowIndices), ColHelper::first(actualColIndices), + RowHelper::size(actualRowIndices), ColHelper::size(actualColIndices)); } }; -// Symbolic +// Scalar template -struct IndexedViewSelector { +struct IndexedViewSelector, IvcColType>>::ReturnAsScalar>> { using ReturnType = typename DenseBase::Scalar&; using ConstReturnType = typename DenseBase::CoeffReturnType; - + using ActualRowIndices = IvcRowType; + using ActualColIndices = IvcColType; + using RowHelper = internal::IndexedViewHelper; + using ColHelper = internal::IndexedViewHelper; static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { - return derived(internal::eval_expr_given_size(rowIndices, derived.rows()), - internal::eval_expr_given_size(colIndices, derived.cols())); + auto actualRowIndices = derived.ivcRow(rowIndices); + auto actualColIndices = derived.ivcCol(colIndices); + return derived(RowHelper::first(actualRowIndices), ColHelper::first(actualColIndices)); } static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { - return derived(internal::eval_expr_given_size(rowIndices, derived.rows()), - internal::eval_expr_given_size(colIndices, derived.cols())); + auto actualRowIndices = derived.ivcRow(rowIndices); + auto actualColIndices = derived.ivcCol(colIndices); + return derived(RowHelper::first(actualRowIndices), ColHelper::first(actualColIndices)); } }; // this helper class assumes internal::is_valid_index_type::value == false -template ::value, - bool UseBlock = !UseSymbolic && internal::get_compile_time_incr>::value == 1, - bool UseGeneric = !UseSymbolic && !UseBlock> +template struct VectorIndexedViewSelector; // Generic template -struct VectorIndexedViewSelector { +struct VectorIndexedViewSelector< + Indices, std::enable_if_t>::value && + internal::IndexedViewHelper>::IncrAtCompileTime != 1>> { static constexpr bool IsRowMajor = DenseBase::IsRowMajor; + using ZeroIndex = internal::SingleRange; + using RowMajorReturnType = IndexedView>; + using ConstRowMajorReturnType = IndexedView>; - using RowMajorReturnType = IndexedView>; - using ConstRowMajorReturnType = IndexedView>; - - using ColMajorReturnType = IndexedView, IvcIndex>; - using ConstColMajorReturnType = IndexedView, IvcIndex>; + using ColMajorReturnType = IndexedView, ZeroIndex>; + using ConstColMajorReturnType = IndexedView, ZeroIndex>; using ReturnType = typename internal::conditional::type; using ConstReturnType = @@ -129,49 +138,53 @@ struct VectorIndexedViewSelector { template = true> static inline RowMajorReturnType run(Derived& derived, const Indices& indices) { - return RowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices)); + return RowMajorReturnType(derived, ZeroIndex(0), derived.ivcCol(indices)); } template = true> static inline ConstRowMajorReturnType run(const Derived& derived, const Indices& indices) { - return ConstRowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices)); + return ConstRowMajorReturnType(derived, ZeroIndex(0), derived.ivcCol(indices)); } template = true> static inline ColMajorReturnType run(Derived& derived, const Indices& indices) { - return ColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0)); + return ColMajorReturnType(derived, derived.ivcRow(indices), ZeroIndex(0)); } template = true> static inline ConstColMajorReturnType run(const Derived& derived, const Indices& indices) { - return ConstColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0)); + return ConstColMajorReturnType(derived, derived.ivcRow(indices), ZeroIndex(0)); } }; // Block template -struct VectorIndexedViewSelector { - using ReturnType = VectorBlock::value>; - using ConstReturnType = VectorBlock::value>; - +struct VectorIndexedViewSelector< + Indices, std::enable_if_t>::value && + internal::IndexedViewHelper>::IncrAtCompileTime == 1>> { + using Helper = internal::IndexedViewHelper>; + using ReturnType = VectorBlock; + using ConstReturnType = VectorBlock; static inline ReturnType run(Derived& derived, const Indices& indices) { - IvcType actualIndices = derived.ivcSize(indices); - return ReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices)); + auto actualIndices = derived.ivcSize(indices); + return ReturnType(derived, Helper::first(actualIndices), Helper::size(actualIndices)); } static inline ConstReturnType run(const Derived& derived, const Indices& indices) { - IvcType actualIndices = derived.ivcSize(indices); - return ConstReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices)); + auto actualIndices = derived.ivcSize(indices); + return ConstReturnType(derived, Helper::first(actualIndices), Helper::size(actualIndices)); } }; // Symbolic template -struct VectorIndexedViewSelector { +struct VectorIndexedViewSelector>::value>> { using ReturnType = typename DenseBase::Scalar&; using ConstReturnType = typename DenseBase::CoeffReturnType; - - static inline ReturnType run(Derived& derived, const Indices& id) { - return derived(internal::eval_expr_given_size(id, derived.size())); + using Helper = internal::IndexedViewHelper>; + static inline ReturnType run(Derived& derived, const Indices& indices) { + auto actualIndices = derived.ivcSize(indices); + return derived(Helper::first(actualIndices)); } - static inline ConstReturnType run(const Derived& derived, const Indices& id) { - return derived(internal::eval_expr_given_size(id, derived.size())); + static inline ConstReturnType run(const Derived& derived, const Indices& indices) { + auto actualIndices = derived.ivcSize(indices); + return derived(Helper::first(actualIndices)); } }; diff --git a/doc/TutorialSlicingIndexing.dox b/doc/TutorialSlicingIndexing.dox index 7f8955431..6ebaa2d6d 100644 --- a/doc/TutorialSlicingIndexing.dox +++ b/doc/TutorialSlicingIndexing.dox @@ -86,12 +86,12 @@ Here are some examples for a 2D array/matrix \c A and a 1D array/vector \c v. - First \c n odd rows A + First \c n odd rows of A \code A(seqN(1,n,2), all) \endcode - The last past one column + The second-last column \code A(all, last-1) \endcode \code A.col(A.cols()-2) \endcode @@ -158,7 +158,7 @@ It is equivalent to: \endcode We can revisit the even columns of A example as follows: -\code A(all, seq(0,last,fix<2>)) +\code A(all, seq(fix<0>,last,fix<2>)) \endcode diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index d3cf4a679..f165e8b46 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -527,10 +527,323 @@ void check_indexed_view() { } } +void check_tutorial_examples() { + constexpr int kRows = 11; + constexpr int kCols = 21; + Matrix A = Matrix::Random(); + Vector v = Vector::Random(); + + { + auto slice = A(seqN(fix<0>, fix<5>, fix<2>), seqN(fix<2>, fix<7>, fix<1>)); + EIGEN_UNUSED_VARIABLE(slice); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), 5); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), 7); + } + { + auto slice = A(seqN(fix<0>, fix<5>, fix<2>), indexing::all); + EIGEN_UNUSED_VARIABLE(slice); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), 5); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), kCols); + } + + // Examples from slicing tutorial. + // Bottom-left corner. + { + Index i = 3; + Index n = 5; + auto slice = A(seq(i, indexing::last), seqN(0, n)); + auto block = A.bottomLeftCorner(A.rows() - i, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), Dynamic); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), Dynamic); + VERIFY_IS_EQUAL(slice, block); + } + { + auto i = fix<3>; + auto n = fix<5>; + auto slice = A(seq(i, indexing::last), seqN(fix<0>, n)); + auto block = A.bottomLeftCorner(fix - i, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), A.RowsAtCompileTime - i); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), n); + VERIFY_IS_EQUAL(slice, block); + } + + // Block starting at i,j of size m,n. + { + Index i = 4; + Index j = 2; + Index m = 3; + Index n = 5; + auto slice = A(seqN(i, m), seqN(j, n)); + auto block = A.block(i, j, m, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto i = fix<4>; + auto j = fix<2>; + auto m = fix<3>; + auto n = fix<5>; + auto slice = A(seqN(i, m), seqN(j, n)); + auto block = A.block(i, j, m, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Block starting at i0,j0 and ending at i1,j1. + { + Index i0 = 4; + Index i1 = 7; + Index j0 = 3; + Index j1 = 5; + auto slice = A(seq(i0, i1), seq(j0, j1)); + auto block = A.block(i0, j0, i1 - i0 + 1, j1 - j0 + 1); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto i0 = fix<4>; + auto i1 = fix<7>; + auto j0 = fix<3>; + auto j1 = fix<5>; + auto slice = A(seq(i0, i1), seq(j0, j1)); + auto block = A.block(i0, j0, i1 - i0 + fix<1>, j1 - j0 + fix<1>); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Even columns of A. + { + auto slice = A(all, seq(0, last, 2)); + auto block = + Eigen::Map, 0, OuterStride<2 * kRows>>(A.data(), kRows, (kCols + 1) / 2); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto slice = A(all, seq(fix<0>, last, fix<2>)); + auto block = Eigen::Map, 0, OuterStride<2 * kRows>>(A.data()); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // First n odd rows of A. + { + Index n = 3; + auto slice = A(seqN(1, n, 2), all); + auto block = Eigen::Map, 0, Stride>(A.data() + 1, n, kCols); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto n = fix<3>; + auto slice = A(seqN(fix<1>, n, fix<2>), all); + auto block = Eigen::Map, 0, Stride>(A.data() + 1); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // The second-last column. + { + auto slice = A(all, last - 1); + auto block = A.col(A.cols() - 2); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto slice = A(all, last - fix<1>); + auto block = A.col(fix - fix<2>); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // The middle row. + { + auto slice = A(last / 2, all); + auto block = A.row((A.rows() - 1) / 2); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto slice = A(last / fix<2>, all); + auto block = A.row(fix<(kRows - 1) / 2>); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Last elements of v starting at i. + { + Index i = 7; + auto slice = v(seq(i, last)); + auto block = v.tail(v.size() - i); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto i = fix<7>; + auto slice = v(seq(i, last)); + auto block = v.tail(fix - i); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Last n elements of v. + { + Index n = 6; + auto slice = v(seq(last + 1 - n, last)); + auto block = v.tail(n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto n = fix<6>; + auto slice = v(seq(last + fix<1> - n, last)); + auto block = v.tail(n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Last n elements of v. + { + Index n = 6; + auto slice = v(lastN(n)); + auto block = v.tail(n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto n = fix<6>; + auto slice = v(lastN(n)); + auto block = v.tail(n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Bottom-right corner of A of size m times n. + { + Index m = 3; + Index n = 6; + auto slice = A(lastN(m), lastN(n)); + auto block = A.bottomRightCorner(m, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + auto m = fix<3>; + auto n = fix<6>; + auto slice = A(lastN(m), lastN(n)); + auto block = A.bottomRightCorner(m, n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Last n columns with a stride of 3. + { + Index n = 4; + constexpr Index stride = 3; + auto slice = A(all, lastN(n, stride)); + auto block = Eigen::Map, 0, OuterStride>( + A.data() + (kCols - 1 - (n - 1) * stride) * kRows, A.rows(), n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + { + constexpr auto n = fix<4>; + constexpr auto stride = fix<3>; + auto slice = A(all, lastN(n, stride)); + auto block = Eigen::Map, 0, OuterStride>( + A.data() + (kCols - 1 - (n - 1) * stride) * kRows, A.rows(), n); + VERIFY_IS_EQUAL(int(slice.RowsAtCompileTime), int(block.RowsAtCompileTime)); + VERIFY_IS_EQUAL(int(slice.ColsAtCompileTime), int(block.ColsAtCompileTime)); + VERIFY_IS_EQUAL(slice, block); + } + + // Compile time size and increment. + { + auto slice1 = v(seq(last - fix<7>, last - fix<2>)); + auto slice2 = v(seqN(last - 7, fix<6>)); + VERIFY_IS_EQUAL(slice1, slice2); + VERIFY_IS_EQUAL(int(slice1.SizeAtCompileTime), 6); + VERIFY_IS_EQUAL(int(slice2.SizeAtCompileTime), 6); + auto slice3 = A(all, seq(fix<0>, last, fix<2>)); + VERIFY_IS_EQUAL(int(slice3.RowsAtCompileTime), kRows); + VERIFY_IS_EQUAL(int(slice3.ColsAtCompileTime), (kCols + 1) / 2); + } + + // Reverse order. + { + auto slice = A(all, seq(20, 10, fix<-2>)); + auto block = Eigen::Map, 0, OuterStride<-2 * kRows>>( + A.data() + 20 * kRows, A.rows(), (20 - 10 + 2) / 2); + VERIFY_IS_EQUAL(slice, block); + } + { + Index n = 10; + auto slice1 = A(seqN(last, n, fix<-1>), all); + auto slice2 = A(lastN(n).reverse(), all); + VERIFY_IS_EQUAL(slice1, slice2); + } + + // Array of indices. + { + std::vector ind{4, 2, 5, 5, 3}; + auto slice1 = A(all, ind); + for (int i = 0; i < ind.size(); ++i) { + VERIFY_IS_EQUAL(slice1.col(i), A.col(ind[i])); + } + + auto slice2 = A(all, {4, 2, 5, 5, 3}); + VERIFY_IS_EQUAL(slice1, slice2); + + Eigen::ArrayXi indarray(5); + indarray << 4, 2, 5, 5, 3; + auto slice3 = A(all, indarray); + VERIFY_IS_EQUAL(slice1, slice3); + } + + // Custom index list. + { + struct pad { + Index size() const { return out_size; } + Index operator[](Index i) const { return std::max(0, i - (out_size - in_size)); } + Index in_size, out_size; + }; + + auto slice = A(pad{3, 5}, pad{3, 5}); + Eigen::MatrixXd B = slice; + VERIFY_IS_EQUAL(B.block(2, 2, 3, 3), A.block(0, 0, 3, 3)); + } +} + EIGEN_DECLARE_TEST(indexed_view) { for (int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1(check_indexed_view()); } + CALL_SUBTEST_1(check_tutorial_examples()); // static checks of some internals: STATIC_CHECK((internal::is_valid_index_type::value));