From bb2d70d211a8fc8184b690b75d29ba484edace0e Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Fri, 22 Apr 2011 22:36:45 +0100 Subject: [PATCH] Implement evaluators for ArrayWrapper and MatrixWrapper. --- Eigen/src/Core/ArrayWrapper.h | 12 ++++ Eigen/src/Core/CoreEvaluators.h | 85 ++++++++++++++++++++++- Eigen/src/Core/util/ForwardDeclarations.h | 1 + test/evaluators.cpp | 12 ++++ 4 files changed, 109 insertions(+), 1 deletion(-) diff --git a/Eigen/src/Core/ArrayWrapper.h b/Eigen/src/Core/ArrayWrapper.h index 7ba01de36..6c7e2b198 100644 --- a/Eigen/src/Core/ArrayWrapper.h +++ b/Eigen/src/Core/ArrayWrapper.h @@ -119,6 +119,12 @@ class ArrayWrapper : public ArrayBase > template inline void evalTo(Dest& dst) const { dst = m_expression; } + const typename internal::remove_all::type& + nestedExpression() const + { + return m_expression; + } + protected: const NestedExpressionType m_expression; }; @@ -214,6 +220,12 @@ class MatrixWrapper : public MatrixBase > m_expression.const_cast_derived().template writePacket(index, x); } + const typename internal::remove_all::type& + nestedExpression() const + { + return m_expression; + } + protected: const NestedExpressionType m_expression; }; diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 6b08c78a0..47835f576 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -106,7 +106,7 @@ protected: typename evaluator::type m_argImpl; }; -// -------------------- Matrix and Array-------------------- +// -------------------- Matrix and Array -------------------- // // evaluator_impl is a common base class for the // Matrix and Array evaluators. @@ -704,6 +704,89 @@ protected: }; +// -------------------- MatrixWrapper and ArrayWrapper -------------------- +// +// evaluator_impl_wrapper_base is a common base class for the +// MatrixWrapper and ArrayWrapper evaluators. + +template +struct evaluator_impl_wrapper_base +{ + evaluator_impl_wrapper_base(const ArgType& arg) : m_argImpl(arg) {} + + typedef typename ArgType::Index Index; + typedef typename ArgType::Scalar Scalar; + typedef typename ArgType::CoeffReturnType CoeffReturnType; + typedef typename ArgType::PacketScalar PacketScalar; + typedef typename ArgType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(row, col); + } + + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template + PacketReturnType packet(Index row, Index col) const + { + return m_argImpl.template packet(row, col); + } + + template + PacketReturnType packet(Index index) const + { + return m_argImpl.template packet(index); + } + + template + void writePacket(Index row, Index col, const PacketScalar& x) + { + m_argImpl.template writePacket(row, col, x); + } + + template + void writePacket(Index index, const PacketScalar& x) + { + m_argImpl.template writePacket(index, x); + } + +protected: + typename evaluator::type m_argImpl; +}; + +template +struct evaluator_impl > + : evaluator_impl_wrapper_base +{ + evaluator_impl(const MatrixWrapper& wrapper) + : evaluator_impl_wrapper_base(wrapper.nestedExpression()) + { } +}; + +template +struct evaluator_impl > + : evaluator_impl_wrapper_base +{ + evaluator_impl(const ArrayWrapper& wrapper) + : evaluator_impl_wrapper_base(wrapper.nestedExpression()) + { } +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 7fbccf98c..ce784daed 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -133,6 +133,7 @@ template class WithFormat; template struct CommaInitializer; template class ReturnByValue; template class ArrayWrapper; +template class MatrixWrapper; namespace internal { template struct solve_retval_base; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 4c55736eb..da6b9064b 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -180,4 +180,16 @@ void test_evaluators() VectorXd vec1(6); VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.rowwise().sum()); VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.colwise().sum().transpose()); + + // test MatrixWrapper and ArrayWrapper + mat1.setRandom(6,6); + arr1.setRandom(6,6); + VERIFY_IS_APPROX_EVALUATOR(mat2, arr1.matrix()); + VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array()); + VERIFY_IS_APPROX_EVALUATOR(mat2, (arr1 + 2).matrix()); + VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array() + 2); + mat2.array() = arr1 * arr1; + VERIFY_IS_APPROX(mat2, (arr1 * arr1).matrix()); + arr2.matrix() = MatrixXd::Identity(6,6); + VERIFY_IS_APPROX(arr2, MatrixXd::Identity(6,6).array()); }