// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2009-2014 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H namespace Eigen { // The product of a diagonal matrix with a sparse matrix can be easily // implemented using expression template. // We have two consider very different cases: // 1 - diag * row-major sparse // => each inner vector <=> scalar * sparse vector product // => so we can reuse CwiseUnaryOp::InnerIterator // 2 - diag * col-major sparse // => each inner vector <=> densevector * sparse vector cwise product // => again, we can reuse specialization of CwiseBinaryOp::InnerIterator // for that particular case // The two other cases are symmetric. #ifndef EIGEN_TEST_EVALUATORS namespace internal { template struct traits > { typedef typename remove_all::type _Lhs; typedef typename remove_all::type _Rhs; typedef typename _Lhs::Scalar Scalar; // propagate the index type of the sparse matrix typedef typename conditional< is_diagonal<_Lhs>::ret, typename traits::Index, typename traits::Index>::type Index; typedef Sparse StorageKind; typedef MatrixXpr XprKind; enum { RowsAtCompileTime = _Lhs::RowsAtCompileTime, ColsAtCompileTime = _Rhs::ColsAtCompileTime, MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime, MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime, SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags), Flags = (SparseFlags&RowMajorBit), CoeffReadCost = Dynamic }; }; enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor}; template class sparse_diagonal_product_inner_iterator_selector; } // end namespace internal template class SparseDiagonalProduct : public SparseMatrixBase >, internal::no_assignment_operator { typedef typename Lhs::Nested LhsNested; typedef typename Rhs::Nested RhsNested; typedef typename internal::remove_all::type _LhsNested; typedef typename internal::remove_all::type _RhsNested; enum { LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor, RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor }; public: EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct) typedef internal::sparse_diagonal_product_inner_iterator_selector <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator; // We do not want ReverseInnerIterator for diagonal-sparse products, // but this dummy declaration is needed to make diag * sparse * diag compile. class ReverseInnerIterator; EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product"); } EIGEN_STRONG_INLINE Index rows() const { return Index(m_lhs.rows()); } EIGEN_STRONG_INLINE Index cols() const { return Index(m_rhs.cols()); } EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; RhsNested m_rhs; }; #endif namespace internal { #ifndef EIGEN_TEST_EVALUATORS template class sparse_diagonal_product_inner_iterator_selector : public CwiseUnaryOp,const Rhs>::InnerIterator { typedef typename CwiseUnaryOp,const Rhs>::InnerIterator Base; typedef typename Rhs::Index Index; public: inline sparse_diagonal_product_inner_iterator_selector( const SparseDiagonalProductType& expr, Index outer) : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer) {} }; template class sparse_diagonal_product_inner_iterator_selector : public CwiseBinaryOp< scalar_product_op, const typename Rhs::ConstInnerVectorReturnType, const typename Lhs::DiagonalVectorType>::InnerIterator { typedef typename CwiseBinaryOp< scalar_product_op, const typename Rhs::ConstInnerVectorReturnType, const typename Lhs::DiagonalVectorType>::InnerIterator Base; typedef typename Rhs::Index Index; Index m_outer; public: inline sparse_diagonal_product_inner_iterator_selector( const SparseDiagonalProductType& expr, Index outer) : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer) {} inline Index outer() const { return m_outer; } inline Index col() const { return m_outer; } }; template class sparse_diagonal_product_inner_iterator_selector : public CwiseUnaryOp,const Lhs>::InnerIterator { typedef typename CwiseUnaryOp,const Lhs>::InnerIterator Base; typedef typename Lhs::Index Index; public: inline sparse_diagonal_product_inner_iterator_selector( const SparseDiagonalProductType& expr, Index outer) : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer) {} }; template class sparse_diagonal_product_inner_iterator_selector : public CwiseBinaryOp< scalar_product_op, const typename Lhs::ConstInnerVectorReturnType, const Transpose >::InnerIterator { typedef typename CwiseBinaryOp< scalar_product_op, const typename Lhs::ConstInnerVectorReturnType, const Transpose >::InnerIterator Base; typedef typename Lhs::Index Index; Index m_outer; public: inline sparse_diagonal_product_inner_iterator_selector( const SparseDiagonalProductType& expr, Index outer) : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer) {} inline Index outer() const { return m_outer; } inline Index row() const { return m_outer; } }; #else // EIGEN_TEST_EVALUATORS enum { SDP_AsScalarProduct, SDP_AsCwiseProduct }; template struct sparse_diagonal_product_evaluator; template struct product_evaluator, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar> : public sparse_diagonal_product_evaluator { typedef Product XprType; typedef evaluator type; typedef evaluator nestedType; enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags typedef sparse_diagonal_product_evaluator Base; product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} }; template struct product_evaluator, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar> : public sparse_diagonal_product_evaluator, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> { typedef Product XprType; typedef evaluator type; typedef evaluator nestedType; enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags typedef sparse_diagonal_product_evaluator, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {} }; template struct sparse_diagonal_product_evaluator { protected: typedef typename evaluator::InnerIterator SparseXprInnerIterator; typedef typename SparseXprType::Scalar Scalar; typedef typename SparseXprType::Index Index; public: class InnerIterator : public SparseXprInnerIterator { public: InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) {} EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } protected: typename DiagonalCoeffType::Scalar m_coeff; }; sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) {} protected: typename evaluator::nestedType m_sparseXprImpl; typename evaluator::nestedType m_diagCoeffImpl; }; template struct sparse_diagonal_product_evaluator { typedef typename SparseXprType::Scalar Scalar; typedef typename SparseXprType::Index Index; typedef CwiseBinaryOp, const typename SparseXprType::ConstInnerVectorReturnType, const DiagCoeffType> CwiseProductType; typedef typename evaluator::type CwiseProductEval; typedef typename evaluator::InnerIterator CwiseProductIterator; class InnerIterator { public: InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) : m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)), m_cwiseIter(m_cwiseEval, 0), m_outer(outer) {} inline Scalar value() const { return m_cwiseIter.value(); } inline Index index() const { return m_cwiseIter.index(); } inline Index outer() const { return m_outer; } inline Index col() const { return SparseXprType::IsRowMajor ? m_cwiseIter.index() : m_outer; } inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : m_cwiseIter.index(); } EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_cwiseIter; return *this; } inline operator bool() const { return m_cwiseIter; } protected: CwiseProductEval m_cwiseEval; CwiseProductIterator m_cwiseIter; Index m_outer; }; sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff) {} protected: typename nested_eval::type m_sparseXprNested; typename nested_eval::type m_diagCoeffNested; }; #endif // EIGEN_TEST_EVALUATORS } // end namespace internal #ifndef EIGEN_TEST_EVALUATORS // SparseMatrixBase functions template template const SparseDiagonalProduct SparseMatrixBase::operator*(const DiagonalBase &other) const { return SparseDiagonalProduct(this->derived(), other.derived()); } #endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H