From 553a0ae924d8307944ab8465bbe2106c449ea2d7 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 1 Mar 2012 10:13:13 +0100 Subject: [PATCH] simplify and speedup sparse * dense matrix products --- Eigen/src/SparseCore/SparseDenseProduct.h | 127 +++++++++++++++++----- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h index b372853dc..0bdaee21e 100644 --- a/Eigen/src/SparseCore/SparseDenseProduct.h +++ b/Eigen/src/SparseCore/SparseDenseProduct.h @@ -149,6 +149,102 @@ struct traits > typedef Dense StorageKind; typedef MatrixXpr XprKind; }; + +template +struct sparse_time_dense_product_impl; + +template +struct sparse_time_dense_product_impl +{ + typedef typename internal::remove_all::type Lhs; + typedef typename internal::remove_all::type Rhs; + typedef typename internal::remove_all::type Res; + typedef typename Lhs::Index Index; + typedef typename Lhs::InnerIterator LhsInnerIterator; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index c=0; c +struct sparse_time_dense_product_impl +{ + typedef typename internal::remove_all::type Lhs; + typedef typename internal::remove_all::type Rhs; + typedef typename internal::remove_all::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index c=0; c +struct sparse_time_dense_product_impl +{ + typedef typename internal::remove_all::type Lhs; + typedef typename internal::remove_all::type Rhs; + typedef typename internal::remove_all::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index j=0; j +struct sparse_time_dense_product_impl +{ + typedef typename internal::remove_all::type Lhs; + typedef typename internal::remove_all::type Rhs; + typedef typename internal::remove_all::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index j=0; j +inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) +{ + sparse_time_dense_product_impl::run(lhs, rhs, res, alpha); +} + } // end namespace internal template @@ -163,27 +259,7 @@ class SparseTimeDenseProduct template void scaleAndAddTo(Dest& dest, Scalar alpha) const { - typedef typename _LhsNested::InnerIterator LhsInnerIterator; - enum { - LhsIsRowMajor = (_LhsNested::Flags&RowMajorBit)==RowMajorBit, - RhsIsVector = _RhsNested::ColsAtCompileTime==1 - }; - Index j=0; - for(j=0; j void scaleAndAddTo(Dest& dest, Scalar alpha) const { - typedef typename _RhsNested::InnerIterator RhsInnerIterator; - enum { RhsIsRowMajor = (_RhsNested::Flags&RowMajorBit)==RowMajorBit }; - for(Index j=0; j lhs_t(m_lhs); + Transpose rhs_t(m_rhs); + Transpose dest_t(dest); + internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); } private: