diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index c00c1488c..f138b12d2 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -162,6 +162,9 @@ template class MatrixBase #ifndef EIGEN_PARSED_BY_DOXYGEN template Derived& lazyAssign(const ProductBase& other); + + template + Derived& lazyAssign(const MatrixPowerProductBase& other); #endif // not EIGEN_PARSED_BY_DOXYGEN template diff --git a/Eigen/src/Core/NoAlias.h b/Eigen/src/Core/NoAlias.h index fcf2c479c..ac1396f68 100644 --- a/Eigen/src/Core/NoAlias.h +++ b/Eigen/src/Core/NoAlias.h @@ -81,8 +81,8 @@ class NoAlias EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct& other) { return m_expression.derived() -= CoeffBasedProduct(other.lhs(), other.rhs()); } - template - EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase& other) + template + EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase& other) { other.derived().evalTo(m_expression); return m_expression; } #endif diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 1a3e14b30..58e1d87dc 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -272,7 +272,7 @@ template class MatrixFunctionReturnValue; template class MatrixSquareRootReturnValue; template class MatrixLogarithmReturnValue; template class MatrixPowerReturnValue; -template class MatrixPowerProductBase; +template class MatrixPowerProductBase; namespace internal { template diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h index 08affb2b5..7aeb69c00 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h @@ -55,14 +55,14 @@ template class MatrixPower RealScalar modfAndInit(RealScalar, RealScalar*); - template - void apply(const PlainObject&, ResultType&, bool&); + template + void apply(const Derived&, ResultType&, bool&); template void computeIntPower(ResultType&, RealScalar); - template - void computeIntPower(const PlainObject&, ResultType&, RealScalar); + template + void computeIntPower(const Derived&, ResultType&, RealScalar); template void computeFracPower(ResultType&, RealScalar); @@ -101,8 +101,8 @@ template class MatrixPower * \param[out] res \f$ A^p b \f$, where A is specified in the * constructor. */ - template - void compute(const PlainObject& b, ResultType& res, RealScalar p); + template + void compute(const Derived& b, ResultType& res, RealScalar p); Index rows() const { return m_A.rows(); } Index cols() const { return m_A.cols(); } @@ -133,8 +133,8 @@ void MatrixPower::compute(MatrixType& res, RealScalar p) } template -template -void MatrixPower::compute(const PlainObject& b, ResultType& res, RealScalar p) +template +void MatrixPower::compute(const Derived& b, ResultType& res, RealScalar p) { switch (m_A.cols()) { case 0: @@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower::modfAndInit(RealScalar } template -template -void MatrixPower::apply(const PlainObject& b, ResultType& res, bool& init) +template +void MatrixPower::apply(const Derived& b, ResultType& res, bool& init) { if (init) res = m_tmp1 * res; @@ -206,8 +206,8 @@ void MatrixPower::computeIntPower(ResultType& res, RealScalar p) } template -template -void MatrixPower::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p) +template +void MatrixPower::computeIntPower(const Derived& b, ResultType& res, RealScalar p) { if (b.cols() >= m_A.cols()) { m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols()); @@ -262,14 +262,13 @@ void MatrixPower::computeFracPower(ResultType& res, RealScalar p) } } -template -class MatrixPowerMatrixProduct : public MatrixPowerProductBase > +template +class MatrixPowerMatrixProduct : public MatrixPowerProductBase,Lhs,Rhs> { public: - typedef MatrixPowerProductBase > Base; - EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct) + EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct) - MatrixPowerMatrixProduct(MatrixPower& pow, const PlainObject& b, RealScalar p) + MatrixPowerMatrixProduct(MatrixPower& pow, const Rhs& b, RealScalar p) : m_pow(pow), m_b(b), m_p(p) { } template @@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase& m_pow; - const PlainObject& m_b; + MatrixPower& m_pow; + const Rhs& m_b; const RealScalar m_p; MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&); }; @@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue inline void evalTo(ResultType& res) const - { MatrixPower(m_A).compute(res, m_p); } + { MatrixPower(m_A.eval()).compute(res, m_p); } Index rows() const { return m_A.rows(); } Index cols() const { return m_A.cols(); } @@ -350,8 +349,8 @@ class MatrixPowerEvaluator { m_pow.compute(res, m_p); } template - const MatrixPowerMatrixProduct operator*(const MatrixBase& b) const - { return MatrixPowerMatrixProduct(m_pow, b.derived(), m_p); } + const MatrixPowerMatrixProduct operator*(const MatrixBase& b) const + { return MatrixPowerMatrixProduct(m_pow, b.derived(), m_p); } Index rows() const { return m_pow.rows(); } Index cols() const { return m_pow.cols(); } @@ -363,9 +362,9 @@ class MatrixPowerEvaluator }; namespace internal { -template -struct nested > -{ typedef PlainObject const& type; }; +template +struct nested > +{ typedef typename MatrixPowerMatrixProduct::PlainObject const& type; }; template struct traits > @@ -375,28 +374,10 @@ template struct traits > { typedef MatrixType ReturnType; }; -template -struct traits > -{ - typedef MatrixXpr XprKind; - typedef typename scalar_product_traits::ReturnType Scalar; - typedef typename promote_storage_type::StorageKind, - typename traits::StorageKind>::ret StorageKind; - typedef typename promote_index_type::Index, - typename traits::Index>::type Index; - - enum { - RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits::RowsAtCompileTime, - traits::RowsAtCompileTime), - ColsAtCompileTime = traits::ColsAtCompileTime, - MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits::MaxRowsAtCompileTime, - traits::MaxRowsAtCompileTime), - MaxColsAtCompileTime = traits::MaxColsAtCompileTime, - Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) - | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, - CoeffReadCost = 0 - }; -}; +template +struct traits > +: traits,Lhs,Rhs> > +{ }; } template diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h index 0a18fe1c1..28617ff6f 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h @@ -29,9 +29,29 @@ struct recompose_complex_schur<0> { res = (U * (T.template triangularView() * U.adjoint())).real(); } }; -template -struct traits > : traits -{ }; +template +struct traits > +{ + typedef MatrixXpr XprKind; + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename remove_all::type PlainObject; + typedef typename scalar_product_traits::ReturnType Scalar; + typedef typename promote_storage_type::StorageKind, + typename traits::StorageKind>::ret StorageKind; + typedef typename promote_index_type::Index, + typename traits::Index>::type Index; + + enum { + RowsAtCompileTime = traits::RowsAtCompileTime, + ColsAtCompileTime = traits::ColsAtCompileTime, + MaxRowsAtCompileTime = traits::MaxRowsAtCompileTime, + MaxColsAtCompileTime = traits::MaxColsAtCompileTime, + Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) + | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, + CoeffReadCost = 0 + }; +}; template inline int binary_powering_cost(T p, int* squarings) @@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic::computeBig(MatrixType& res, R compute2x2(res, p); } -template +#define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \ + typedef MatrixPowerProductBase Base; \ + EIGEN_DENSE_PUBLIC_INTERFACE(Derived) + +template class MatrixPowerProductBase : public MatrixBase { public: typedef MatrixBase Base; - typedef typename Base::PlainObject PlainObject; EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase) + + typedef typename Base::PlainObject PlainObject; inline Index rows() const { return derived().rows(); } inline Index cols() const { return derived().cols(); } @@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase mutable PlainObject m_result; }; +template +template +Derived& MatrixBase::lazyAssign(const MatrixPowerProductBase& other) +{ + other.derived().evalTo(derived()); + return derived(); +} + } // namespace Eigen #endif // EIGEN_MATRIX_POWER