From 7fdf2189516e2d4548c04ed4e18b7a28c28eb77c Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 17 Jun 2010 10:17:22 +0200 Subject: [PATCH] makes trmv works with the triangular matrix on the right --- .../Core/products/TriangularMatrixVector.h | 64 ++++++++++++++++--- test/product_trmv.cpp | 6 ++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index 1a2b183aa..935054a5a 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -25,12 +25,26 @@ #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H #define EIGEN_TRIANGULARMATRIXVECTOR_H -template struct ei_product_triangular_vector_selector; +template +struct ei_product_triangular_vector_selector +{ + static EIGEN_DONT_INLINE void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename ei_traits::Scalar alpha) + { + typedef Transpose TrRhs; TrRhs trRhs(rhs); + typedef Transpose TrLhs; TrLhs trLhs(lhs); + typedef Transpose TrRes; TrRes trRes(res); + ei_product_triangular_vector_selector + ::run(trRhs,trLhs,trRes,alpha); + } +}; + template -struct ei_product_triangular_vector_selector +struct ei_product_triangular_vector_selector { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -74,7 +88,7 @@ struct ei_product_triangular_vector_selector -struct ei_product_triangular_vector_selector +struct ei_product_triangular_vector_selector { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -119,12 +133,17 @@ struct ei_product_triangular_vector_selector -struct ei_traits > - : ei_traits, Lhs, Rhs> > +template +struct ei_traits > + : ei_traits, Lhs, Rhs> > {}; -template +template +struct ei_traits > + : ei_traits, Lhs, Rhs> > +{}; + +template struct TriangularProduct : public ProductBase, Lhs, Rhs > { @@ -143,7 +162,7 @@ struct TriangularProduct * RhsBlasTraits::extractScalarFactor(m_rhs); ei_product_triangular_vector_selector - <_ActualLhsType,_ActualRhsType,Dest, + } }; +template +struct TriangularProduct + : public ProductBase, Lhs, Rhs > +{ + EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) + + TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} + + template void scaleAndAddTo(Dest& dst, Scalar alpha) const + { + + ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); + + const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); + const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) + * RhsBlasTraits::extractScalarFactor(m_rhs); + + ei_product_triangular_vector_selector + ::Flags)&RowMajorBit) ? RowMajor : ColMajor> + ::run(lhs,rhs,dst,actualAlpha); + } +}; + #endif // EIGEN_TRIANGULARMATRIXVECTOR_H diff --git a/test/product_trmv.cpp b/test/product_trmv.cpp index f0962557a..2f5743187 100644 --- a/test/product_trmv.cpp +++ b/test/product_trmv.cpp @@ -76,6 +76,12 @@ template void trmv(const MatrixType& m) VERIFY((m3.adjoint() * (s1*v1.conjugate())).isApprox(m1.adjoint().template triangularView() * (s1*v1.conjugate()), largerEps)); m3 = m1.template triangularView(); + // check transposed cases: + m3 = m1.template triangularView(); + VERIFY((v1.transpose() * m3).isApprox(v1.transpose() * m1.template triangularView(), largerEps)); + VERIFY((v1.adjoint() * m3).isApprox(v1.adjoint() * m1.template triangularView(), largerEps)); + VERIFY((v1.adjoint() * m3.adjoint()).isApprox(v1.adjoint() * m1.template triangularView().adjoint(), largerEps)); + // TODO check with sub-matrices }