Refactor TriangularView to handle both dense and sparse objects. Introduce a glu_shape<S1,S2> helper to assemble sparse/dense shapes with triagular/seladjoint views.

This commit is contained in:
Gael Guennebaud
2014-07-22 11:35:56 +02:00
parent 2a251ffab0
commit 6daa6a0d16
15 changed files with 520 additions and 284 deletions

View File

@@ -23,6 +23,7 @@ template<typename Lhs, typename Rhs, int Mode,
int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
struct sparse_solve_triangular_selector;
#ifndef EIGEN_TEST_EVALUATORS
// forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
@@ -163,13 +164,166 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
}
};
#else // EIGEN_TEST_EVALUATORS
// forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.rows(); ++i)
{
Scalar tmp = other.coeff(i,col);
Scalar lastVal(0);
Index lastIndex = 0;
for(LhsIterator it(lhsEval, i); it; ++it)
{
lastVal = it.value();
lastIndex = it.index();
if(lastIndex==i)
break;
tmp -= lastVal * other.coeff(lastIndex,col);
}
if (Mode & UnitDiag)
other.coeffRef(i,col) = tmp;
else
{
eigen_assert(lastIndex==i);
other.coeffRef(i,col) = tmp/lastVal;
}
}
}
}
};
// backward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.rows()-1 ; i>=0 ; --i)
{
Scalar tmp = other.coeff(i,col);
Scalar l_ii = 0;
LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
{
eigen_assert(it && it.index()==i);
l_ii = it.value();
++it;
}
else if (it && it.index() == i)
++it;
for(; it; ++it)
{
tmp -= it.value() * other.coeff(it.index(),col);
}
if (Mode & UnitDiag) other.coeffRef(i,col) = tmp;
else other.coeffRef(i,col) = tmp/l_ii;
}
}
}
};
// forward substitution, col-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.cols(); ++i)
{
Scalar& tmp = other.coeffRef(i,col);
if (tmp!=Scalar(0)) // optimization when other is actually sparse
{
LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
{
eigen_assert(it && it.index()==i);
tmp /= it.value();
}
if (it && it.index()==i)
++it;
for(; it; ++it)
other.coeffRef(it.index(), col) -= tmp * it.value();
}
}
}
}
};
// backward substitution, col-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.cols()-1; i>=0; --i)
{
Scalar& tmp = other.coeffRef(i,col);
if (tmp!=Scalar(0)) // optimization when other is actually sparse
{
if(!(Mode & UnitDiag))
{
// TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
LhsIterator it(lhsEval, i);
while(it && it.index()!=i)
++it;
eigen_assert(it && it.index()==i);
other.coeffRef(i,col) /= it.value();
}
LhsIterator it(lhsEval, i);
for(; it && it.index()<i; ++it)
other.coeffRef(it.index(), col) -= tmp * it.value();
}
}
}
}
};
#endif // EIGEN_TEST_EVALUATORS
} // end namespace internal
template<typename ExpressionType,int Mode>
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const
{
eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@@ -178,21 +332,23 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDer
typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
OtherCopy otherCopy(other.derived());
internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy);
if (copy)
other = otherCopy;
}
template<typename ExpressionType,int Mode>
#ifndef EIGEN_TEST_EVALUATORS
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
TriangularViewImpl<ExpressionType,Mode,Sparse>::solve(const MatrixBase<OtherDerived>& other) const
{
typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
solveInPlace(res);
return res;
}
#endif
// pure sparse path
@@ -290,11 +446,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
} // end namespace internal
template<typename ExpressionType,int Mode>
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
{
eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
// enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@@ -303,7 +459,7 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<Ot
// typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
// OtherCopy otherCopy(other.derived());
internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived());
// if (copy)
// other = otherCopy;