From aa5acdb352418d35e27c75148f66321e281af22b Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Thu, 27 Sep 2012 02:20:36 +0800 Subject: [PATCH] Create class MatrixPowerBase for further extension (like specialization for triangular or self-adjoint matrices) --- .../Eigen/src/MatrixFunctions/MatrixPower.h | 144 +++++++----------- .../src/MatrixFunctions/MatrixPowerBase.h | 107 +++++++++---- 2 files changed, 137 insertions(+), 114 deletions(-) diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h index 15ede1c2a..996c24240 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h @@ -31,57 +31,19 @@ namespace Eigen { * \include MatrixPower_optimal.cpp * Output: \verbinclude MatrixPower_optimal.out */ -template class MatrixPower +template +class MatrixPower : public MatrixPowerBase,MatrixType> { - private: - static const int Rows = MatrixType::RowsAtCompileTime; - static const int Cols = MatrixType::ColsAtCompileTime; - static const int Options = MatrixType::Options; - static const int MaxRows = MatrixType::MaxRowsAtCompileTime; - static const int MaxCols = MatrixType::MaxColsAtCompileTime; - - typedef typename MatrixType::Scalar Scalar; - typedef typename MatrixType::RealScalar RealScalar; - typedef typename MatrixType::Index Index; - typedef Matrix,Rows,Cols,Options,MaxRows,MaxCols> ComplexMatrix; - - const MatrixType* m_A; - MatrixType m_tmp1, m_tmp2; - ComplexMatrix m_T, m_U, m_fT; - char m_flag; - - RealScalar modfAndInit(RealScalar, RealScalar*); - - template - void apply(const Derived&, ResultType&, bool&); - - template - void computeIntPower(ResultType&, RealScalar); - - template - void computeIntPower(const Derived&, ResultType&, RealScalar); - - template - void computeFracPower(ResultType&, RealScalar); - public: - /** - * \brief Constructor. - * - * \param[in] A the base of the matrix power. - */ - explicit MatrixPower(const MatrixType& A); + EIGEN_MATRIX_POWER_PUBLIC_INTERFACE(MatrixPower) /** * \brief Constructor. * * \param[in] A the base of the matrix power. */ - template - explicit MatrixPower(const MatrixBase& A); - - /** \brief Destructor. */ - ~MatrixPower(); + template + explicit MatrixPower(const MatrixExpression& A); /** * \brief Return the expression \f$ A^p \f$. @@ -111,40 +73,47 @@ template class MatrixPower */ template void compute(const Derived& b, ResultType& res, RealScalar p); - - Index rows() const { return m_A->rows(); } - Index cols() const { return m_A->cols(); } + + private: + using Base::m_A; + MatrixType m_tmp1, m_tmp2; + ComplexMatrix m_T, m_U, m_fT; + bool m_init; + + RealScalar modfAndInit(RealScalar, RealScalar*); + + template + void apply(const Derived&, ResultType&, bool&); + + template + void computeIntPower(ResultType&, RealScalar); + + template + void computeIntPower(const Derived&, ResultType&, RealScalar); + + template + void computeFracPower(ResultType&, RealScalar); }; template -MatrixPower::MatrixPower(const MatrixType& A) : - m_A(&A), - m_flag(0) +template +MatrixPower::MatrixPower(const MatrixExpression& A) : + Base(A), + m_init(false) { /* empty body */ } -template -template -MatrixPower::MatrixPower(const MatrixBase& A) : - m_A(new MatrixType(A)), - m_flag(2) -{ /* empty body */ } - -template -MatrixPower::~MatrixPower() -{ if (m_flag & 2) delete m_A; } - template void MatrixPower::compute(MatrixType& res, RealScalar p) { - switch (m_A->cols()) { + switch (m_A.cols()) { case 0: break; case 1: - res(0,0) = std::pow(m_A->coeff(0,0), p); + res(0,0) = std::pow(m_A.coeff(0,0), p); break; default: RealScalar intpart, x = modfAndInit(p, &intpart); - res = MatrixType::Identity(m_A->rows(), m_A->cols()); + res = MatrixType::Identity(m_A.rows(), m_A.cols()); computeIntPower(res, intpart); computeFracPower(res, x); } @@ -154,11 +123,11 @@ template template void MatrixPower::compute(const Derived& b, ResultType& res, RealScalar p) { - switch (m_A->cols()) { + switch (m_A.cols()) { case 0: break; case 1: - res = std::pow(m_A->coeff(0,0), p) * b; + res = std::pow(m_A.coeff(0,0), p) * b; break; default: RealScalar intpart, x = modfAndInit(p, &intpart); @@ -168,20 +137,19 @@ void MatrixPower::compute(const Derived& b, ResultType& res, RealSca } template -typename MatrixType::RealScalar MatrixPower::modfAndInit(RealScalar x, RealScalar* intpart) +typename MatrixPower::Base::RealScalar MatrixPower::modfAndInit(RealScalar x, RealScalar* intpart) { static RealScalar maxAbsEival, minAbsEival; *intpart = std::floor(x); RealScalar res = x - *intpart; - if (!(m_flag & 1) && res) { - const ComplexSchur schurOfA(*m_A); + if (!m_init && res) { + const ComplexSchur schurOfA(m_A); m_T = schurOfA.matrixT(); m_U = schurOfA.matrixU(); - m_flag |= 1; + m_init = true; - const Array - absTdiag = m_T.diagonal().array().abs(); + const RealArray absTdiag = m_T.diagonal().array().abs(); maxAbsEival = absTdiag.maxCoeff(); minAbsEival = absTdiag.minCoeff(); } @@ -211,8 +179,8 @@ void MatrixPower::computeIntPower(ResultType& res, RealScalar p) { RealScalar pp = std::abs(p); - if (p<0) m_tmp1 = m_A->inverse(); - else m_tmp1 = *m_A; + if (p<0) m_tmp1 = m_A.inverse(); + else m_tmp1 = m_A; while (pp >= 1) { if (std::fmod(pp, 2) >= 1) @@ -226,8 +194,8 @@ template template void MatrixPower::computeIntPower(const Derived& b, ResultType& res, RealScalar p) { - if (b.cols() >= m_A->cols()) { - m_tmp2 = MatrixType::Identity(m_A->rows(), m_A->cols()); + if (b.cols() >= m_A.cols()) { + m_tmp2 = MatrixType::Identity(m_A.rows(), m_A.cols()); computeIntPower(m_tmp2, p); res.noalias() = m_tmp2 * b; } @@ -241,20 +209,20 @@ void MatrixPower::computeIntPower(const Derived& b, ResultType& res, return; } else if (p>0) { - m_tmp1 = *m_A; + m_tmp1 = m_A; } - else if (m_A->cols() > 2 && b.cols()*(pp-applyings) <= m_A->cols()*squarings) { - PartialPivLU A(*m_A); + else if (m_A.cols() > 2 && b.cols()*(pp-applyings) <= m_A.cols()*squarings) { + PartialPivLU A(m_A); res = A.solve(b); for (--pp; pp >= 1; --pp) res = A.solve(res); return; } else { - m_tmp1 = m_A->inverse(); + m_tmp1 = m_A.inverse(); } - while (b.cols()*(pp-applyings) > m_A->cols()*squarings) { + while (b.cols()*(pp-applyings) > m_A.cols()*squarings) { if (std::fmod(pp, 2) >= 1) { apply(b, res, init); --applyings; @@ -330,13 +298,13 @@ class MatrixPowerReturnValue : public ReturnByValue(A)), m_p(p), m_del(true) { } + : m_pow(*new MatrixPower(A)), m_p(p), m_del(true) { } MatrixPowerReturnValue(MatrixPower& pow, RealScalar p) - : m_pow(&pow), m_p(p), m_del(false) { } + : m_pow(pow), m_p(p), m_del(false) { } ~MatrixPowerReturnValue() - { if (m_del) delete m_pow; } + { if (m_del) delete &m_pow; } /** * \brief Compute the matrix power. @@ -346,17 +314,17 @@ class MatrixPowerReturnValue : public ReturnByValue inline void evalTo(ResultType& res) const - { m_pow->compute(res, m_p); } + { m_pow.compute(res, m_p); } template const MatrixPowerMatrixProduct operator*(const MatrixBase& b) const - { return MatrixPowerMatrixProduct(*m_pow, b.derived(), m_p); } + { return MatrixPowerMatrixProduct(m_pow, b.derived(), m_p); } - Index rows() const { return m_pow->rows(); } - Index cols() const { return m_pow->cols(); } + Index rows() const { return m_pow.rows(); } + Index cols() const { return m_pow.cols(); } private: - MatrixPower* m_pow; + MatrixPower& m_pow; const RealScalar m_p; const bool m_del; // whether to delete the pointer at destruction MatrixPowerReturnValue& operator=(const MatrixPowerReturnValue&); diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h index a809609d5..ca5a604fc 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h @@ -29,30 +29,6 @@ struct recompose_complex_schur<0> { res = (U * (T.template triangularView() * U.adjoint())).real(); } }; -template -struct traits > -{ - typedef MatrixXpr XprKind; - typedef typename remove_all<_Lhs>::type Lhs; - typedef typename remove_all<_Rhs>::type Rhs; - typedef typename remove_all::type PlainObject; - typedef typename scalar_product_traits::ReturnType Scalar; - typedef typename promote_storage_type::StorageKind, - typename traits::StorageKind>::ret StorageKind; - typedef typename promote_index_type::Index, - typename traits::Index>::type Index; - - enum { - RowsAtCompileTime = traits::RowsAtCompileTime, - ColsAtCompileTime = traits::ColsAtCompileTime, - MaxRowsAtCompileTime = traits::MaxRowsAtCompileTime, - MaxColsAtCompileTime = traits::MaxColsAtCompileTime, - Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) - | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, - CoeffReadCost = 0 - }; -}; - template inline int binary_powering_cost(T p, int* squarings) { @@ -121,7 +97,8 @@ inline int matrix_power_get_pade_degree(long double normIminusT) } } // namespace internal -template class MatrixPowerTriangularAtomic +template +class MatrixPowerTriangularAtomic { private: typedef typename MatrixType::Scalar Scalar; @@ -239,10 +216,88 @@ void MatrixPowerTriangularAtomic::computeBig(MatrixType& res, R compute2x2(res, p); } +#define EIGEN_MATRIX_POWER_PUBLIC_INTERFACE(Derived) \ + typedef MatrixPowerBase,MatrixType> Base; \ + using typename Base::Scalar; \ + using typename Base::RealScalar; \ + using typename Base::ComplexMatrix; \ + using typename Base::RealArray; + #define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \ - typedef MatrixPowerProductBase Base; \ + typedef MatrixPowerProductBase Base; \ EIGEN_DENSE_PUBLIC_INTERFACE(Derived) +namespace internal { +template +struct traits > +{ + typedef MatrixXpr XprKind; + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename remove_all::type PlainObject; + typedef typename scalar_product_traits::ReturnType Scalar; + typedef typename promote_storage_type::StorageKind, + typename traits::StorageKind>::ret StorageKind; + typedef typename promote_index_type::Index, + typename traits::Index>::type Index; + + enum { + RowsAtCompileTime = traits::RowsAtCompileTime, + ColsAtCompileTime = traits::ColsAtCompileTime, + MaxRowsAtCompileTime = traits::MaxRowsAtCompileTime, + MaxColsAtCompileTime = traits::MaxColsAtCompileTime, + Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) + | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, + CoeffReadCost = 0 + }; +}; +} // namespace internal + +template +class MatrixPowerBase +{ + protected: + static const int Rows = MatrixType::RowsAtCompileTime; + static const int Cols = MatrixType::ColsAtCompileTime; + static const int Options = MatrixType::Options; + static const int MaxRows = MatrixType::MaxRowsAtCompileTime; + static const int MaxCols = MatrixType::MaxColsAtCompileTime; + + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::RealScalar RealScalar; + typedef typename MatrixType::Index Index; + typedef Matrix,Rows,Cols,Options,MaxRows,MaxCols> ComplexMatrix; + typedef Array RealArray; + + const MatrixType& m_A; + const bool m_del; // whether to delete the pointer at destruction + + public: + explicit MatrixPowerBase(const MatrixType& A) : + m_A(A), + m_del(false) + { /* empty body */ } + + template + explicit MatrixPowerBase(const MatrixBase& A) : + m_A(*new MatrixType(A)), + m_del(true) + { /* empty body */ } + + ~MatrixPowerBase() + { if (m_del) delete &m_A; } + + void compute(MatrixType& res, RealScalar p) + { static_cast(this)->compute(res,p); } + + template + void compute(const OtherDerived& b, ResultType& res, RealScalar p) + { static_cast(this)->compute(b,res,p); } + + Index rows() const { return m_A.rows(); } + Index cols() const { return m_A.cols(); } +}; + template class MatrixPowerProductBase : public MatrixBase {