fix m = m*m with m sparse (gug found by Frederik Heinz)

This commit is contained in:
Gael Guennebaud
2009-02-12 15:57:13 +00:00
parent 59a1ed0932
commit 20a8bb96eb
2 changed files with 98 additions and 87 deletions

View File

@@ -171,6 +171,55 @@ class SparseProduct : ei_no_assignment_operator,
RhsNested m_rhs;
};
// perform a pseudo in-place sparse * sparse product assuming all matrices are col major
template<typename Lhs, typename Rhs, typename ResultType>
static void ei_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef typename ei_traits<typename ei_cleantype<Lhs>::type>::Scalar Scalar;
// make sure to call innerSize/outerSize since we fake the storage order.
int rows = lhs.innerSize();
int cols = rhs.outerSize();
//int size = lhs.outerSize();
ei_assert(lhs.outerSize() == rhs.innerSize());
// allocate a temporary buffer
AmbiVector<Scalar> tempVector(rows);
// estimate the number of non zero entries
float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
res.resize(rows, cols);
res.startFill(int(ratioRes*rows*cols));
for (int j=0; j<cols; ++j)
{
// let's do a more accurate determination of the nnz ratio for the current column j of res
//float ratioColRes = std::min(ratioLhs * rhs.innerNonZeros(j), 1.f);
// FIXME find a nice way to get the number of nonzeros of a sub matrix (here an inner vector)
float ratioColRes = ratioRes;
tempVector.init(ratioColRes);
tempVector.setZero();
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
{
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
tempVector.restart();
Scalar x = rhsIt.value();
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
{
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
}
}
for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it)
if (ResultType::Flags&RowMajorBit)
res.fill(j,it.index()) = it.value();
else
res.fill(it.index(), j) = it.value();
}
res.endFill();
}
template<typename Lhs, typename Rhs, typename ResultType,
int LhsStorageOrder = ei_traits<Lhs>::Flags&RowMajorBit,
int RhsStorageOrder = ei_traits<Rhs>::Flags&RowMajorBit,
@@ -184,58 +233,21 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
// make sure to call innerSize/outerSize since we fake the storage order.
int rows = lhs.innerSize();
int cols = rhs.outerSize();
//int size = lhs.outerSize();
ei_assert(lhs.outerSize() == rhs.innerSize());
// allocate a temporary buffer
AmbiVector<Scalar> tempVector(rows);
// estimate the number of non zero entries
float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
res.resize(rows, cols);
res.startFill(int(ratioRes*rows*cols));
for (int j=0; j<cols; ++j)
{
// let's do a more accurate determination of the nnz ratio for the current column j of res
//float ratioColRes = std::min(ratioLhs * rhs.innerNonZeros(j), 1.f);
// FIXME find a nice way to get the number of nonzeros of a sub matrix (here an inner vector)
float ratioColRes = ratioRes;
tempVector.init(ratioColRes);
tempVector.setZero();
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
{
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
tempVector.restart();
Scalar x = rhsIt.value();
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
{
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
}
}
for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it)
if (ResultType::Flags&RowMajorBit)
res.fill(j,it.index()) = it.value();
else
res.fill(it.index(), j) = it.value();
}
res.endFill();
typename ei_cleantype<ResultType>::type _res(res.rows(), res.cols());
ei_sparse_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res);
res.swap(_res);
}
};
template<typename Lhs, typename Rhs, typename ResultType>
struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
{
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
// we need a col-major matrix to hold the result
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
SparseTemporaryType _res(res.rows(), res.cols());
ei_sparse_product_selector<Lhs,Rhs,SparseTemporaryType,ColMajor,ColMajor,ColMajor>::run(lhs, rhs, _res);
ei_sparse_product_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res);
res = _res;
}
};
@@ -246,20 +258,21 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
// let's transpose the product to get a column x column product
ei_sparse_product_selector<Rhs,Lhs,ResultType,ColMajor,ColMajor,ColMajor>::run(rhs, lhs, res);
typename ei_cleantype<ResultType>::type _res(res.rows(), res.cols());
ei_sparse_product_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res);
res.swap(_res);
}
};
template<typename Lhs, typename Rhs, typename ResultType>
struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
{
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
// let's transpose the product to get a column x column product
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
SparseTemporaryType _res(res.cols(), res.rows());
ei_sparse_product_selector<Rhs,Lhs,SparseTemporaryType,ColMajor,ColMajor,ColMajor>
::run(rhs, lhs, _res);
ei_sparse_product_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
res = _res.transpose();
}
};
@@ -322,7 +335,7 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs
template<typename Derived>
template<typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
{
{
typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename _Lhs::InnerIterator LhsInnerIterator;