From 5d43b4049dd7eb4d5e742a4441ee164bb886e6fe Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 24 Oct 2011 09:33:24 +0200 Subject: [PATCH] factorize solving with guess --- .../Eigen/src/IterativeSolvers/BiCGSTAB.h | 40 +++-------- .../src/IterativeSolvers/ConjugateGradient.h | 70 ++++++++----------- unsupported/Eigen/src/SparseExtra/Solve.h | 42 ++++++++++- 3 files changed, 80 insertions(+), 72 deletions(-) diff --git a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h index ee2ab128d..798f85da5 100644 --- a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h +++ b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h @@ -106,9 +106,6 @@ class BiCGSTAB; namespace internal { -template -class bicgstab_solve_retval_with_guess; - template< typename _MatrixType, typename _Preconditioner> struct traits > { @@ -204,19 +201,19 @@ public: * \sa compute() */ template - inline const internal::bicgstab_solve_retval_with_guess + inline const internal::solve_retval_with_guess solveWithGuess(const MatrixBase& b, const Guess& x0) const { eigen_assert(m_isInitialized && "BiCGSTAB is not initialized."); eigen_assert(Base::rows()==b.rows() && "BiCGSTAB::solve(): invalid number of rows of the right hand side matrix b"); - return internal::bicgstab_solve_retval_with_guess + return internal::solve_retval_with_guess (*this, b.derived(), x0); } /** \internal */ template - void _solve(const Rhs& b, Dest& x) const + void _solveWithGuess(const Rhs& b, Dest& x) const { for(int j=0; j + void _solve(const Rhs& b, Dest& x) const + { + x.setOnes(); + _solveWithGuess(b,x); + } + protected: }; @@ -247,33 +252,10 @@ struct solve_retval, Rhs> template void evalTo(Dest& dst) const { - dst.setOnes(); dec()._solve(rhs(),dst); } }; -template -class bicgstab_solve_retval_with_guess - : public solve_retval_base -{ - typedef Eigen::internal::solve_retval_base Base; - using Base::dec; - using Base::rhs; - public: - bicgstab_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess) - : Base(cg, rhs), m_guess(guess) - {} - - template void evalTo(Dest& dst) const - { - dst = m_guess; - dec()._solve(rhs(), dst); - } - protected: - const Guess& m_guess; - -}; - } #endif // EIGEN_BICGSTAB_H diff --git a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h index 2a78337c5..ced3e310c 100644 --- a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h +++ b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h @@ -37,6 +37,7 @@ namespace internal { * \param tol_error On input the tolerance error, on output an estimation of the relative error. */ template +EIGEN_DONT_INLINE void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, const Preconditioner& precond, int& iters, typename Dest::RealScalar& tol_error) @@ -59,7 +60,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, VectorType z(n), tmp(n); RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM RealScalar absInit = absNew; // the initial absolute value - + int i = 0; while ((i < maxIters) && (absNew > tol*tol*absInit)) { @@ -89,9 +90,6 @@ class ConjugateGradient; namespace internal { -template -class conjugate_gradient_solve_retval_with_guess; - template< typename _MatrixType, int _UpLo, typename _Preconditioner> struct traits > { @@ -193,32 +191,43 @@ public: * \sa compute() */ template - inline const internal::conjugate_gradient_solve_retval_with_guess + inline const internal::solve_retval_with_guess solveWithGuess(const MatrixBase& b, const Guess& x0) const { eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); eigen_assert(Base::rows()==b.rows() && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b"); - return internal::conjugate_gradient_solve_retval_with_guess + return internal::solve_retval_with_guess (*this, b.derived(), x0); } + + /** \internal */ + template + void _solveWithGuess(const Rhs& b, Dest& x) const + { + m_iterations = Base::m_maxIterations; + m_error = Base::m_tolerance; + + for(int j=0; jtemplate selfadjointView(), b.col(j), xj, + Base::m_preconditioner, m_iterations, m_error); + } + + m_isInitialized = true; + m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; + } /** \internal */ template void _solve(const Rhs& b, Dest& x) const { - for(int j=0; jtemplate selfadjointView(), b.col(j), xj, - Base::m_preconditioner, m_iterations, m_error); - } - - m_isInitialized = true; - m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; + x.setOnes(); + _solveWithGuess(b,x); } protected: @@ -228,7 +237,7 @@ protected: namespace internal { - template +template struct solve_retval, Rhs> : solve_retval_base, Rhs> { @@ -237,33 +246,10 @@ struct solve_retval, Rhs> template void evalTo(Dest& dst) const { - dst.setOnes(); dec()._solve(rhs(),dst); } }; -template -class conjugate_gradient_solve_retval_with_guess - : public solve_retval_base -{ - typedef Eigen::internal::solve_retval_base Base; - using Base::dec; - using Base::rhs; - public: - conjugate_gradient_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess) - : Base(cg, rhs), m_guess(guess) - {} - - template void evalTo(Dest& dst) const - { - dst = m_guess; - dec()._solve(rhs(), dst); - } - protected: - const Guess& m_guess; - -}; - } #endif // EIGEN_CONJUGATE_GRADIENT_H diff --git a/unsupported/Eigen/src/SparseExtra/Solve.h b/unsupported/Eigen/src/SparseExtra/Solve.h index 19449e9de..5b6c859ae 100644 --- a/unsupported/Eigen/src/SparseExtra/Solve.h +++ b/unsupported/Eigen/src/SparseExtra/Solve.h @@ -76,7 +76,47 @@ template struct sparse_solve_retval_b using Base::cols; \ sparse_solve_retval(const DecompositionType& dec, const Rhs& rhs) \ : Base(dec, rhs) {} - + + + +template struct solve_retval_with_guess; + +template +struct traits > +{ + typedef typename DecompositionType::MatrixType MatrixType; + typedef Matrix ReturnType; +}; + +template struct solve_retval_with_guess + : public ReturnByValue > +{ + typedef typename DecompositionType::Index Index; + + solve_retval_with_guess(const DecompositionType& dec, const Rhs& rhs, const Guess& guess) + : m_dec(dec), m_rhs(rhs), m_guess(guess) + {} + + inline Index rows() const { return m_dec.cols(); } + inline Index cols() const { return m_rhs.cols(); } + + template inline void evalTo(Dest& dst) const + { + dst = m_guess; + m_dec._solveWithGuess(m_rhs,dst); + } + + protected: + const DecompositionType& m_dec; + const typename Rhs::Nested m_rhs; + const typename Guess::Nested m_guess; +}; + } // namepsace internal #endif // EIGEN_SPARSE_SOLVE_H