From 842e31cf5c8fd31f394156ada84a1aeeab89ef7e Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 29 Sep 2014 13:37:49 +0200 Subject: [PATCH] Let KroneckerProduct exploits the recently introduced generic InnerIterator class. --- .../KroneckerProduct/KroneckerTensorProduct.h | 33 ++++++------------- unsupported/test/kronecker_product.cpp | 12 +++++++ 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h index 608c72021..b459360df 100644 --- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h +++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h @@ -157,40 +157,27 @@ void KroneckerProductSparse::evalTo(Dest& dst) const dst.resizeNonZeros(0); // 1 - evaluate the operands if needed: - typedef typename internal::nested_eval::type Lhs1; + typedef typename internal::nested_eval::type Lhs1; typedef typename internal::remove_all::type Lhs1Cleaned; const Lhs1 lhs1(m_A); - typedef typename internal::nested_eval::type Rhs1; + typedef typename internal::nested_eval::type Rhs1; typedef typename internal::remove_all::type Rhs1Cleaned; const Rhs1 rhs1(m_B); - - // 2 - construct a SparseView for dense operands - typedef typename internal::conditional::StorageKind,Sparse>::value, Lhs1, SparseView >::type Lhs2; - typedef typename internal::remove_all::type Lhs2Cleaned; - const Lhs2 lhs2(lhs1); - typedef typename internal::conditional::StorageKind,Sparse>::value, Rhs1, SparseView >::type Rhs2; - typedef typename internal::remove_all::type Rhs2Cleaned; - const Rhs2 rhs2(rhs1); - - // 3 - construct respective evaluators - typedef typename internal::evaluator::type LhsEval; - LhsEval lhsEval(lhs2); - typedef typename internal::evaluator::type RhsEval; - RhsEval rhsEval(rhs2); - - typedef typename LhsEval::InnerIterator LhsInnerIterator; - typedef typename RhsEval::InnerIterator RhsInnerIterator; + + // 2 - construct respective iterators + typedef InnerIterator LhsInnerIterator; + typedef InnerIterator RhsInnerIterator; // compute number of non-zeros per innervectors of dst { VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols()); for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA) - for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA) + for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++; VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols()); for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB) - for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB) + for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++; Matrix nnzAB = nnzB * nnzA.transpose(); @@ -201,9 +188,9 @@ void KroneckerProductSparse::evalTo(Dest& dst) const { for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB) { - for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA) + for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) { - for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB) + for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) { const DestIndex i = DestIndex(itA.row() * Br + itB.row()), diff --git a/unsupported/test/kronecker_product.cpp b/unsupported/test/kronecker_product.cpp index 753a2d417..02411a262 100644 --- a/unsupported/test/kronecker_product.cpp +++ b/unsupported/test/kronecker_product.cpp @@ -216,5 +216,17 @@ void test_kronecker_product() sC2 = kroneckerProduct(sA,sB); dC = kroneckerProduct(dA,dB); VERIFY_IS_APPROX(MatrixXf(sC2),dC); + + sC2 = kroneckerProduct(dA,sB); + dC = kroneckerProduct(dA,dB); + VERIFY_IS_APPROX(MatrixXf(sC2),dC); + + sC2 = kroneckerProduct(sA,dB); + dC = kroneckerProduct(dA,dB); + VERIFY_IS_APPROX(MatrixXf(sC2),dC); + + sC2 = kroneckerProduct(2*sA,sB); + dC = kroneckerProduct(2*dA,dB); + VERIFY_IS_APPROX(MatrixXf(sC2),dC); } }