Make cross product uses nested/nested_eval

This commit is contained in:
Gael Guennebaud
2014-08-01 14:47:33 +02:00
parent 26d2cdefd4
commit fc13b37c55
2 changed files with 33 additions and 6 deletions

View File

@@ -30,8 +30,13 @@ MatrixBase<Derived>::cross(const MatrixBase<OtherDerived>& other) const
// Note that there is no need for an expression here since the compiler
// optimize such a small temporary very well (even within a complex expression)
#ifndef EIGEN_TEST_EVALUATORS
typename internal::nested<Derived,2>::type lhs(derived());
typename internal::nested<OtherDerived,2>::type rhs(other.derived());
#else
typename internal::nested_eval<Derived,2>::type lhs(derived());
typename internal::nested_eval<OtherDerived,2>::type rhs(other.derived());
#endif
return typename cross_product_return_type<OtherDerived>::type(
numext::conj(lhs.coeff(1) * rhs.coeff(2) - lhs.coeff(2) * rhs.coeff(1)),
numext::conj(lhs.coeff(2) * rhs.coeff(0) - lhs.coeff(0) * rhs.coeff(2)),
@@ -76,8 +81,13 @@ MatrixBase<Derived>::cross3(const MatrixBase<OtherDerived>& other) const
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(Derived,4)
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(OtherDerived,4)
#ifndef EIGEN_TEST_EVALUATORS
typedef typename internal::nested<Derived,2>::type DerivedNested;
typedef typename internal::nested<OtherDerived,2>::type OtherDerivedNested;
#else
typedef typename internal::nested_eval<Derived,2>::type DerivedNested;
typedef typename internal::nested_eval<OtherDerived,2>::type OtherDerivedNested;
#endif
DerivedNested lhs(derived());
OtherDerivedNested rhs(other.derived());
@@ -103,21 +113,29 @@ VectorwiseOp<ExpressionType,Direction>::cross(const MatrixBase<OtherDerived>& ot
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(OtherDerived,3)
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
#ifndef EIGEN_TEST_EVALUATORS
typename internal::nested<ExpressionType,2>::type mat(_expression());
typename internal::nested<OtherDerived,2>::type vec(other.derived());
#else
typename internal::nested_eval<ExpressionType,2>::type mat(_expression());
typename internal::nested_eval<OtherDerived,2>::type vec(other.derived());
#endif
CrossReturnType res(_expression().rows(),_expression().cols());
if(Direction==Vertical)
{
eigen_assert(CrossReturnType::RowsAtCompileTime==3 && "the matrix must have exactly 3 rows");
res.row(0) = (_expression().row(1) * other.coeff(2) - _expression().row(2) * other.coeff(1)).conjugate();
res.row(1) = (_expression().row(2) * other.coeff(0) - _expression().row(0) * other.coeff(2)).conjugate();
res.row(2) = (_expression().row(0) * other.coeff(1) - _expression().row(1) * other.coeff(0)).conjugate();
res.row(0) = (mat.row(1) * vec.coeff(2) - mat.row(2) * vec.coeff(1)).conjugate();
res.row(1) = (mat.row(2) * vec.coeff(0) - mat.row(0) * vec.coeff(2)).conjugate();
res.row(2) = (mat.row(0) * vec.coeff(1) - mat.row(1) * vec.coeff(0)).conjugate();
}
else
{
eigen_assert(CrossReturnType::ColsAtCompileTime==3 && "the matrix must have exactly 3 columns");
res.col(0) = (_expression().col(1) * other.coeff(2) - _expression().col(2) * other.coeff(1)).conjugate();
res.col(1) = (_expression().col(2) * other.coeff(0) - _expression().col(0) * other.coeff(2)).conjugate();
res.col(2) = (_expression().col(0) * other.coeff(1) - _expression().col(1) * other.coeff(0)).conjugate();
res.col(0) = (mat.col(1) * vec.coeff(2) - mat.col(2) * vec.coeff(1)).conjugate();
res.col(1) = (mat.col(2) * vec.coeff(0) - mat.col(0) * vec.coeff(2)).conjugate();
res.col(2) = (mat.col(0) * vec.coeff(1) - mat.col(1) * vec.coeff(0)).conjugate();
}
return res;
}