Add support for dense.cwiseProduct(sparse)

This also fixes a regression regarding (dense*sparse).diagonal()
This commit is contained in:
Gael Guennebaud
2015-11-04 17:42:07 +01:00
parent f6b1deebab
commit 902750826b
6 changed files with 26 additions and 16 deletions

View File

@@ -423,10 +423,10 @@ Derived& SparseMatrixBase<Derived>::operator-=(const DiagonalBase<OtherDerived>&
template<typename Derived>
template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type
SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
{
return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived());
}
} // end namespace Eigen

View File

@@ -262,20 +262,18 @@ template<typename Derived> class SparseMatrixBase
Derived& operator*=(const Scalar& other);
Derived& operator/=(const Scalar& other);
#define EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE \
CwiseBinaryOp< \
internal::scalar_product_op< \
typename internal::scalar_product_traits< \
typename internal::traits<Derived>::Scalar, \
typename internal::traits<OtherDerived>::Scalar \
>::ReturnType \
>, \
const Derived, \
const OtherDerived \
>
template<typename OtherDerived> struct CwiseProductDenseReturnType {
typedef CwiseBinaryOp<internal::scalar_product_op<typename internal::scalar_product_traits<
typename internal::traits<Derived>::Scalar,
typename internal::traits<OtherDerived>::Scalar
>::ReturnType>,
const Derived,
const OtherDerived
> Type;
};
template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
EIGEN_STRONG_INLINE const typename CwiseProductDenseReturnType<OtherDerived>::Type
cwiseProduct(const MatrixBase<OtherDerived> &other) const;
// sparse * diagonal

View File

@@ -49,7 +49,6 @@ const int InnerRandomAccessPattern = 0x2 | CoherentAccessPattern;
const int OuterRandomAccessPattern = 0x4 | CoherentAccessPattern;
const int RandomAccessPattern = 0x8 | OuterRandomAccessPattern | InnerRandomAccessPattern;
template<typename Derived> class SparseMatrixBase;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class DynamicSparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseVector;