Implement evaluators for sparse times diagonal products.

This commit is contained in:
Gael Guennebaud
2014-06-27 15:54:44 +02:00
parent ae039dde13
commit 73e686c6a4
3 changed files with 139 additions and 3 deletions

View File

@@ -24,8 +24,10 @@ namespace Eigen {
// for that particular case
// The two other cases are symmetric.
#ifndef EIGEN_TEST_EVALUATORS
namespace internal {
template<typename Lhs, typename Rhs>
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
{
@@ -100,9 +102,14 @@ class SparseDiagonalProduct
LhsNested m_lhs;
RhsNested m_rhs;
};
#endif
namespace internal {
#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
@@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector
inline Index row() const { return m_outer; }
};
#else // EIGEN_TEST_EVALUATORS
enum {
SDP_AsScalarProduct,
SDP_AsCwiseProduct
};
template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
struct sparse_diagonal_product_evaluator;
template<typename Lhs, typename Rhs, int Options, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
{
typedef Product<Lhs, Rhs, Options> XprType;
typedef evaluator<XprType> type;
typedef evaluator<XprType> nestedType;
enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
};
template<typename Lhs, typename Rhs, int Options, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
{
typedef Product<Lhs, Rhs, Options> XprType;
typedef evaluator<XprType> type;
typedef evaluator<XprType> nestedType;
enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
};
template<typename SparseXprType, typename DiagonalCoeffType>
struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
{
protected:
typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
typedef typename SparseXprType::Scalar Scalar;
typedef typename SparseXprType::Index Index;
public:
class InnerIterator : public SparseXprInnerIterator
{
public:
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
: SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
{}
EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
protected:
typename DiagonalCoeffType::Scalar m_coeff;
};
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
: m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
{}
protected:
typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
};
template<typename SparseXprType, typename DiagCoeffType>
struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
{
typedef typename SparseXprType::Scalar Scalar;
typedef typename SparseXprType::Index Index;
typedef CwiseBinaryOp<scalar_product_op<Scalar>,
const typename SparseXprType::ConstInnerVectorReturnType,
const DiagCoeffType> CwiseProductType;
typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
class InnerIterator : public CwiseProductIterator
{
public:
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
: CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0),
m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
m_outer(outer)
{
::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0);
}
inline Index outer() const { return m_outer; }
inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; }
inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); }
protected:
Index m_outer;
CwiseProductEval m_cwiseEval;
};
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
: m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
{}
protected:
typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
: SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
};
#endif // EIGEN_TEST_EVALUATORS
} // end namespace internal
// SparseMatrixBase functions
#ifndef EIGEN_TEST_EVALUATORS
// SparseMatrixBase functions
template<typename Derived>
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
{
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
}
#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen