mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Implement evaluators for sparse * sparse with auto pruning.
This commit is contained in:
@@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
||||
res.resize(cols, rows);
|
||||
else
|
||||
res.resize(rows, cols);
|
||||
|
||||
#ifdef EIGEN_TEST_EVALUATORS
|
||||
typename evaluator<Lhs>::type lhsEval(lhs);
|
||||
typename evaluator<Rhs>::type rhsEval(rhs);
|
||||
#endif
|
||||
|
||||
res.reserve(estimated_nnz_prod);
|
||||
double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
|
||||
@@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
||||
// let's do a more accurate determination of the nnz ratio for the current column j of res
|
||||
tempVector.init(ratioColRes);
|
||||
tempVector.setZero();
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
||||
#else
|
||||
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||
#endif
|
||||
{
|
||||
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
||||
tempVector.restart();
|
||||
Scalar x = rhsIt.value();
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
|
||||
#else
|
||||
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
|
||||
#endif
|
||||
{
|
||||
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
||||
}
|
||||
@@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
|
||||
}
|
||||
};
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
// NOTE the 2 others cases (col row *) must never occur since they are caught
|
||||
// by ProductReturnType which transforms it to (col col *) by evaluating rhs.
|
||||
#else
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
|
||||
{
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs;
|
||||
RowMajorMatrixLhs rowLhs(lhs);
|
||||
sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
|
||||
{
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs;
|
||||
RowMajorMatrixRhs rowRhs(rhs);
|
||||
sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
|
||||
{
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
|
||||
ColMajorMatrixRhs colRhs(rhs);
|
||||
internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
|
||||
{
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
|
||||
ColMajorMatrixLhs colLhs(lhs);
|
||||
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
Reference in New Issue
Block a user