Clean implementation of permutation * matrix products.

This commit is contained in:
Gael Guennebaud
2015-06-19 10:51:57 +02:00
parent 06036d8bb1
commit fad36cc814
3 changed files with 108 additions and 196 deletions

View File

@@ -16,25 +16,8 @@ namespace Eigen {
namespace internal {
template<typename PermutationType, typename MatrixType, int Side, bool Transposed>
struct traits<permut_sparsematrix_product_retval<PermutationType, MatrixType, Side, Transposed> >
{
typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
typedef typename MatrixTypeNestedCleaned::Scalar Scalar;
typedef typename MatrixTypeNestedCleaned::StorageIndex StorageIndex;
enum {
SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor,
MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight
};
typedef typename internal::conditional<MoveOuter,
SparseMatrix<Scalar,SrcStorageOrder,StorageIndex>,
SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> >::type ReturnType;
};
template<typename PermutationType, typename MatrixType, int Side, bool Transposed>
struct permut_sparsematrix_product_retval
: public ReturnByValue<permut_sparsematrix_product_retval<PermutationType, MatrixType, Side, Transposed> >
template<typename MatrixType, int Side, bool Transposed>
struct permutation_matrix_product<MatrixType, Side, Transposed, SparseShape>
{
typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
typedef typename MatrixTypeNestedCleaned::Scalar Scalar;
@@ -44,61 +27,55 @@ struct permut_sparsematrix_product_retval
SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor,
MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight
};
typedef typename internal::conditional<MoveOuter,
SparseMatrix<Scalar,SrcStorageOrder,StorageIndex>,
SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> >::type ReturnType;
permut_sparsematrix_product_retval(const PermutationType& perm, const MatrixType& matrix)
: m_permutation(perm), m_matrix(matrix)
{}
inline int rows() const { return m_matrix.rows(); }
inline int cols() const { return m_matrix.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
template<typename Dest,typename PermutationType>
static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat)
{
if(MoveOuter)
{
SparseMatrix<Scalar,SrcStorageOrder,StorageIndex> tmp(m_matrix.rows(), m_matrix.cols());
Matrix<StorageIndex,Dynamic,1> sizes(m_matrix.outerSize());
for(Index j=0; j<m_matrix.outerSize(); ++j)
SparseMatrix<Scalar,SrcStorageOrder,StorageIndex> tmp(mat.rows(), mat.cols());
Matrix<StorageIndex,Dynamic,1> sizes(mat.outerSize());
for(Index j=0; j<mat.outerSize(); ++j)
{
Index jp = m_permutation.indices().coeff(j);
sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = StorageIndex(m_matrix.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).nonZeros());
Index jp = perm.indices().coeff(j);
sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = StorageIndex(mat.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).nonZeros());
}
tmp.reserve(sizes);
for(Index j=0; j<m_matrix.outerSize(); ++j)
for(Index j=0; j<mat.outerSize(); ++j)
{
Index jp = m_permutation.indices().coeff(j);
Index jp = perm.indices().coeff(j);
Index jsrc = ((Side==OnTheRight) ^ Transposed) ? jp : j;
Index jdst = ((Side==OnTheLeft) ^ Transposed) ? jp : j;
for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,jsrc); it; ++it)
for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,jsrc); it; ++it)
tmp.insertByOuterInner(jdst,it.index()) = it.value();
}
dst = tmp;
}
else
{
SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> tmp(m_matrix.rows(), m_matrix.cols());
SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> tmp(mat.rows(), mat.cols());
Matrix<StorageIndex,Dynamic,1> sizes(tmp.outerSize());
sizes.setZero();
PermutationMatrix<Dynamic,Dynamic,StorageIndex> perm;
PermutationMatrix<Dynamic,Dynamic,StorageIndex> perm_cpy;
if((Side==OnTheLeft) ^ Transposed)
perm = m_permutation;
perm_cpy = perm;
else
perm = m_permutation.transpose();
perm_cpy = perm.transpose();
for(Index j=0; j<m_matrix.outerSize(); ++j)
for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,j); it; ++it)
sizes[perm.indices().coeff(it.index())]++;
for(Index j=0; j<mat.outerSize(); ++j)
for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,j); it; ++it)
sizes[perm_cpy.indices().coeff(it.index())]++;
tmp.reserve(sizes);
for(Index j=0; j<m_matrix.outerSize(); ++j)
for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,j); it; ++it)
tmp.insertByOuterInner(perm.indices().coeff(it.index()),j) = it.value();
for(Index j=0; j<mat.outerSize(); ++j)
for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,j); it; ++it)
tmp.insertByOuterInner(perm_cpy.indices().coeff(it.index()),j) = it.value();
dst = tmp;
}
}
protected:
const PermutationType& m_permutation;
typename MatrixType::Nested m_matrix;
};
}
@@ -107,63 +84,17 @@ namespace internal {
template <int ProductTag> struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> { typedef Sparse ret; };
template <int ProductTag> struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> { typedef Sparse ret; };
// TODO, the following need cleaning, this is just a copy-past of the dense case
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, false> pmpr(lhs, rhs);
pmpr.evalTo(dst);
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, false> pmpr(rhs, lhs);
pmpr.evalTo(dst);
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, SparseShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{
permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, true> pmpr(lhs.nestedPermutation(), rhs);
pmpr.evalTo(dst);
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Lhs, Transpose<Rhs>, SparseShape, PermutationShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{
permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, true> pmpr(rhs.nestedPermutation(), lhs);
pmpr.evalTo(dst);
}
};
// TODO, the following two overloads are only needed to define the right temporary type through
// typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType
// while it should be correctly handled by traits<Product<> >::PlainObject
// typename traits<permutation_sparse_matrix_product<Rhs,Lhs,OnTheRight,false> >::ReturnType
// whereas it should be correctly handled by traits<Product<> >::PlainObject
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, PermutationShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
: public evaluator<typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType>::type
: public evaluator<typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType PlainObject;
typedef typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType PlainObject;
typedef typename evaluator<PlainObject>::type Base;
explicit product_evaluator(const XprType& xpr)
@@ -179,10 +110,10 @@ protected:
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, PermutationShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
: public evaluator<typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType>::type
: public evaluator<typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType PlainObject;
typedef typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType PlainObject;
typedef typename evaluator<PlainObject>::type Base;
explicit product_evaluator(const XprType& xpr)