mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Add support for sparse * dense and dense * sparse matrix/vector products
This commit is contained in:
@@ -25,9 +25,29 @@
|
||||
#ifndef EIGEN_SPARSEPRODUCT_H
|
||||
#define EIGEN_SPARSEPRODUCT_H
|
||||
|
||||
template<typename Lhs, typename Rhs> struct ei_sparse_product_mode
|
||||
{
|
||||
enum {
|
||||
|
||||
value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit
|
||||
? SparseTimeSparseProduct
|
||||
: (Lhs::Flags&SparseBit)==SparseBit
|
||||
? SparseTimeDenseProduct
|
||||
: DenseTimeSparseProduct };
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductMode>
|
||||
struct SparseProductReturnType
|
||||
{
|
||||
typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
|
||||
typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
|
||||
|
||||
typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type;
|
||||
};
|
||||
|
||||
// sparse product return type specialization
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct SparseProductReturnType
|
||||
struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct>
|
||||
{
|
||||
typedef typename ei_traits<Lhs>::Scalar Scalar;
|
||||
enum {
|
||||
@@ -47,11 +67,11 @@ struct SparseProductReturnType
|
||||
SparseMatrix<Scalar,0>,
|
||||
const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested;
|
||||
|
||||
typedef SparseProduct<LhsNested, RhsNested> Type;
|
||||
typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type;
|
||||
};
|
||||
|
||||
template<typename LhsNested, typename RhsNested>
|
||||
struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
template<typename LhsNested, typename RhsNested, int ProductMode>
|
||||
struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >
|
||||
{
|
||||
// clean the nested types:
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
@@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
|
||||
MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
|
||||
|
||||
LhsRowMajor = LhsFlags & RowMajorBit,
|
||||
RhsRowMajor = RhsFlags & RowMajorBit,
|
||||
// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit,
|
||||
// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit,
|
||||
|
||||
EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
|
||||
ResultIsSparse = ProductMode==SparseTimeSparseProduct,
|
||||
|
||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
||||
RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ),
|
||||
|
||||
Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
||||
| EvalBeforeAssigningBit
|
||||
@@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
|
||||
CoeffReadCost = Dynamic
|
||||
};
|
||||
|
||||
typedef typename ei_meta_if<ResultIsSparse,
|
||||
SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >,
|
||||
MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base;
|
||||
};
|
||||
|
||||
template<typename LhsNested, typename RhsNested>
|
||||
class SparseProduct : ei_no_assignment_operator,
|
||||
public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> >
|
||||
template<typename LhsNested, typename RhsNested, int ProductMode>
|
||||
class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base
|
||||
{
|
||||
public:
|
||||
|
||||
@@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator,
|
||||
public:
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline SparseProduct(const Lhs& lhs, const Rhs& rhs)
|
||||
EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{
|
||||
ei_assert(lhs.cols() == rhs.rows());
|
||||
|
||||
enum {
|
||||
ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic
|
||||
|| _RhsNested::RowsAtCompileTime==Dynamic
|
||||
|| int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime),
|
||||
AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
|
||||
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested)
|
||||
};
|
||||
// note to the lost user:
|
||||
// * for a dot product use: v1.dot(v2)
|
||||
// * for a coeff-wise product use: v1.cwise()*v2
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
|
||||
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
|
||||
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
|
||||
}
|
||||
|
||||
inline int rows() const { return m_lhs.rows(); }
|
||||
inline int cols() const { return m_rhs.cols(); }
|
||||
EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); }
|
||||
EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); }
|
||||
|
||||
const _LhsNested& lhs() const { return m_lhs; }
|
||||
const _LhsNested& rhs() const { return m_rhs; }
|
||||
EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
|
||||
EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
|
||||
|
||||
protected:
|
||||
LhsNested m_lhs;
|
||||
@@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
|
||||
// return derived();
|
||||
// }
|
||||
|
||||
// sparse = sparse * sparse
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product)
|
||||
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product)
|
||||
{
|
||||
// std::cout << "sparse product to sparse\n";
|
||||
ei_sparse_product_selector<
|
||||
@@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs
|
||||
return derived();
|
||||
}
|
||||
|
||||
// dense = sparse * dense
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
|
||||
{
|
||||
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||
typedef typename _Lhs::InnerIterator LhsInnerIterator;
|
||||
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
|
||||
derived().setZero();
|
||||
for (int j=0; j<product.lhs().outerSize(); ++j)
|
||||
for (LhsInnerIterator i(product.lhs(),j); i; ++i)
|
||||
derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j);
|
||||
return derived();
|
||||
}
|
||||
|
||||
// dense = dense * sparse
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product)
|
||||
{
|
||||
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||
typedef typename _Rhs::InnerIterator RhsInnerIterator;
|
||||
enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit };
|
||||
derived().setZero();
|
||||
for (int j=0; j<product.rhs().outerSize(); ++j)
|
||||
for (RhsInnerIterator i(product.rhs(),j); i; ++i)
|
||||
derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index());
|
||||
return derived();
|
||||
}
|
||||
|
||||
// sparse * sparse
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
inline const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
||||
{
|
||||
enum {
|
||||
ProductIsValid = Derived::ColsAtCompileTime==Dynamic
|
||||
|| OtherDerived::RowsAtCompileTime==Dynamic
|
||||
|| int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
|
||||
AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
|
||||
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
|
||||
};
|
||||
// note to the lost user:
|
||||
// * for a dot product use: v1.dot(v2)
|
||||
// * for a coeff-wise product use: v1.cwise()*v2
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
|
||||
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
|
||||
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
|
||||
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
||||
}
|
||||
|
||||
// sparse * dense
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
|
||||
{
|
||||
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user