mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix mixing types in sparse matrix products.
This commit is contained in:
@@ -17,7 +17,9 @@ namespace internal {
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, bool sortedInsertion = false)
|
||||
{
|
||||
typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
||||
typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
|
||||
typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
|
||||
typedef typename remove_all<ResultType>::type::Scalar ResScalar;
|
||||
|
||||
// make sure to call innerSize/outerSize since we fake the storage order.
|
||||
Index rows = lhs.innerSize();
|
||||
@@ -25,7 +27,7 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r
|
||||
eigen_assert(lhs.outerSize() == rhs.innerSize());
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(bool, mask, rows, 0);
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, values, rows, 0);
|
||||
ei_declare_aligned_stack_constructed_variable(ResScalar, values, rows, 0);
|
||||
ei_declare_aligned_stack_constructed_variable(Index, indices, rows, 0);
|
||||
|
||||
std::memset(mask,0,sizeof(bool)*rows);
|
||||
@@ -51,12 +53,12 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r
|
||||
Index nnz = 0;
|
||||
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||
{
|
||||
Scalar y = rhsIt.value();
|
||||
RhsScalar y = rhsIt.value();
|
||||
Index k = rhsIt.index();
|
||||
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
|
||||
{
|
||||
Index i = lhsIt.index();
|
||||
Scalar x = lhsIt.value();
|
||||
LhsScalar x = lhsIt.value();
|
||||
if(!mask[i])
|
||||
{
|
||||
mask[i] = true;
|
||||
@@ -166,11 +168,12 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
|
||||
RowMajorMatrix rhsRow = rhs;
|
||||
RowMajorMatrix resRow(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow);
|
||||
res = resRow;
|
||||
typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRhs;
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
|
||||
RowMajorRhs rhsRow = rhs;
|
||||
RowMajorRes resRow(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<RowMajorRhs,Lhs,RowMajorRes>(rhsRow, lhs, resRow);
|
||||
res = resRow;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -179,10 +182,11 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
|
||||
RowMajorMatrix lhsRow = lhs;
|
||||
RowMajorMatrix resRow(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
|
||||
typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorLhs;
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
|
||||
RowMajorLhs lhsRow = lhs;
|
||||
RowMajorRes resRow(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorLhs,RowMajorRes>(rhs, lhsRow, resRow);
|
||||
res = resRow;
|
||||
}
|
||||
};
|
||||
@@ -219,10 +223,11 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix lhsCol = lhs;
|
||||
ColMajorMatrix resCol(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
|
||||
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
|
||||
ColMajorLhs lhsCol = lhs;
|
||||
ColMajorRes resCol(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<ColMajorLhs,Rhs,ColMajorRes>(lhsCol, rhs, resCol);
|
||||
res = resCol;
|
||||
}
|
||||
};
|
||||
@@ -232,10 +237,11 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix rhsCol = rhs;
|
||||
ColMajorMatrix resCol(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
|
||||
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
|
||||
ColMajorRhs rhsCol = rhs;
|
||||
ColMajorRes resCol(lhs.rows(), rhs.cols());
|
||||
internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorRhs,ColMajorRes>(lhs, rhsCol, resCol);
|
||||
res = resCol;
|
||||
}
|
||||
};
|
||||
@@ -263,7 +269,8 @@ namespace internal {
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
||||
typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
|
||||
typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
|
||||
Index cols = rhs.outerSize();
|
||||
eigen_assert(lhs.outerSize() == rhs.innerSize());
|
||||
|
||||
@@ -274,12 +281,12 @@ static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs,
|
||||
{
|
||||
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||
{
|
||||
Scalar y = rhsIt.value();
|
||||
RhsScalar y = rhsIt.value();
|
||||
Index k = rhsIt.index();
|
||||
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
|
||||
{
|
||||
Index i = lhsIt.index();
|
||||
Scalar x = lhsIt.value();
|
||||
LhsScalar x = lhsIt.value();
|
||||
res.coeffRef(i,j) += x * y;
|
||||
}
|
||||
}
|
||||
@@ -310,9 +317,9 @@ struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMa
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix lhsCol(lhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res);
|
||||
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
|
||||
ColMajorLhs lhsCol(lhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<ColMajorLhs,Rhs,ResultType>(lhsCol, rhs, res);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -321,9 +328,9 @@ struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMa
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix rhsCol(rhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res);
|
||||
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
|
||||
ColMajorRhs rhsCol(rhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorRhs,ResultType>(lhs, rhsCol, res);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
||||
{
|
||||
// return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
|
||||
|
||||
typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
||||
typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
|
||||
typedef typename remove_all<ResultType>::type::Scalar ResScalar;
|
||||
typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
|
||||
|
||||
// make sure to call innerSize/outerSize since we fake the storage order.
|
||||
@@ -31,7 +32,7 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
||||
eigen_assert(lhs.outerSize() == rhs.innerSize());
|
||||
|
||||
// allocate a temporary buffer
|
||||
AmbiVector<Scalar,StorageIndex> tempVector(rows);
|
||||
AmbiVector<ResScalar,StorageIndex> tempVector(rows);
|
||||
|
||||
// mimics a resizeByInnerOuter:
|
||||
if(ResultType::IsRowMajor)
|
||||
@@ -63,14 +64,14 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
||||
{
|
||||
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
||||
tempVector.restart();
|
||||
Scalar x = rhsIt.value();
|
||||
RhsScalar x = rhsIt.value();
|
||||
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
|
||||
{
|
||||
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
||||
}
|
||||
}
|
||||
res.startVec(j);
|
||||
for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
|
||||
for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
|
||||
res.insertBackByOuterInner(j,it.index()) = it.value();
|
||||
}
|
||||
res.finalize();
|
||||
@@ -85,7 +86,6 @@ struct sparse_sparse_product_with_pruning_selector;
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
|
||||
{
|
||||
typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
@@ -129,8 +129,8 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
|
||||
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
|
||||
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
|
||||
ColMajorMatrixLhs colLhs(lhs);
|
||||
ColMajorMatrixRhs colRhs(rhs);
|
||||
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
|
||||
@@ -149,7 +149,7 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,R
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
|
||||
typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
|
||||
RowMajorMatrixLhs rowLhs(lhs);
|
||||
sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
|
||||
}
|
||||
@@ -161,7 +161,7 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,C
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
|
||||
typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
|
||||
RowMajorMatrixRhs rowRhs(rhs);
|
||||
sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
|
||||
}
|
||||
@@ -173,7 +173,7 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,R
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
|
||||
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
|
||||
ColMajorMatrixRhs colRhs(rhs);
|
||||
internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
|
||||
}
|
||||
@@ -185,7 +185,7 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,C
|
||||
typedef typename ResultType::RealScalar RealScalar;
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
|
||||
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
|
||||
ColMajorMatrixLhs colLhs(lhs);
|
||||
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user