Implement scalar multiples and division by a scalar as a binary-expression with a constant expression.

This slightly complexifies the type of the expressions and implies that we now have to distinguish between scalar*expr and expr*scalar to catch scalar-multiple expression (e.g., see BlasUtil.h), but this brings several advantages:
- it makes it clear on each side the scalar is applied,
- it clearly reflects that we are dealing with a binary-expression,
- the complexity of the type is hidden through macros defined at the end of Macros.h,
- distinguishing between "scalar op expr" and "expr op scalar" is important to support non commutative fields (like quaternions)
- "scalar op expr" is now fully equivalent to "ConstantExpr(scalar) op expr"
- scalar_multiple_op, scalar_quotient1_op and scalar_quotient2_op are not used anymore in officially supported modules (still used in Tensor)
This commit is contained in:
Gael Guennebaud
2016-06-14 11:26:57 +02:00
parent 39781dc1e2
commit 64fcfd314f
12 changed files with 146 additions and 99 deletions

View File

@@ -71,18 +71,17 @@ class DiagonalBase : public EigenBase<Derived>
return InverseReturnType(diagonal().cwiseInverse());
}
typedef DiagonalWrapper<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const DiagonalVectorType> > ScalarMultipleReturnType;
EIGEN_DEVICE_FUNC
inline const ScalarMultipleReturnType
inline const DiagonalWrapper<const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DiagonalVectorType,Scalar,product) >
operator*(const Scalar& scalar) const
{
return ScalarMultipleReturnType(diagonal() * scalar);
return DiagonalWrapper<const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DiagonalVectorType,Scalar,product) >(diagonal() * scalar);
}
EIGEN_DEVICE_FUNC
friend inline const ScalarMultipleReturnType
friend inline const DiagonalWrapper<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,DiagonalVectorType,product) >
operator*(const Scalar& scalar, const DiagonalBase& other)
{
return ScalarMultipleReturnType(other.diagonal() * scalar);
return DiagonalWrapper<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,DiagonalVectorType,product) >(scalar * other.diagonal());
}
};

View File

@@ -35,22 +35,28 @@ struct evaluator<Product<Lhs, Rhs, Options> >
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {}
};
// Catch scalar * ( A * B ) and transform it to (A*scalar) * B
// Catch "scalar * ( A * B )" and transform it to "(A*scalar) * B"
// TODO we should apply that rule only if that's really helpful
template<typename Lhs, typename Rhs, typename Scalar>
struct evaluator_assume_aliasing<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
template<typename Lhs, typename Rhs, typename Scalar1, typename Scalar2, typename Plain1>
struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
const Product<Lhs, Rhs, DefaultProduct> > >
{
static const bool value = true;
};
template<typename Lhs, typename Rhs, typename Scalar>
struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
: public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> >
template<typename Lhs, typename Rhs, typename Scalar1, typename Scalar2, typename Plain1>
struct evaluator<CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
const Product<Lhs, Rhs, DefaultProduct> > >
: public evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1,Lhs,product), Rhs, DefaultProduct> >
{
typedef CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > XprType;
typedef evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> > Base;
typedef CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
const Product<Lhs, Rhs, DefaultProduct> > XprType;
typedef evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1,Lhs,product), Rhs, DefaultProduct> > Base;
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr)
: Base(xpr.functor().m_other * xpr.nestedExpression().lhs() * xpr.nestedExpression().rhs())
: Base(xpr.lhs().functor().m_other * xpr.rhs().lhs() * xpr.rhs().rhs())
{}
};
@@ -171,16 +177,17 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::sub_assign_op<
// Dense ?= scalar * Product
// TODO we should apply that rule if that's really helpful
// for instance, this is not good for inner products
template< typename DstXprType, typename Lhs, typename Rhs, typename AssignFunc, typename Scalar, typename ScalarBis>
struct Assignment<DstXprType, CwiseUnaryOp<internal::scalar_multiple_op<ScalarBis>,
template< typename DstXprType, typename Lhs, typename Rhs, typename AssignFunc, typename Scalar, typename ScalarBis, typename Plain>
struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_product_op<ScalarBis,Scalar>, const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>,Plain>,
const Product<Lhs,Rhs,DefaultProduct> >, AssignFunc, Dense2Dense, Scalar>
{
typedef CwiseUnaryOp<internal::scalar_multiple_op<ScalarBis>,
const Product<Lhs,Rhs,DefaultProduct> > SrcXprType;
typedef CwiseBinaryOp<internal::scalar_product_op<ScalarBis,Scalar>,
const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>,Plain>,
const Product<Lhs,Rhs,DefaultProduct> > SrcXprType;
static EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const AssignFunc& func)
{
call_assignment_no_alias(dst, (src.functor().m_other * src.nestedExpression().lhs())*src.nestedExpression().rhs(), func);
call_assignment_no_alias(dst, (src.lhs().functor().m_other * src.rhs().lhs())*src.rhs().rhs(), func);
}
};

View File

@@ -129,7 +129,7 @@ template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
}
friend EIGEN_DEVICE_FUNC
const SelfAdjointView<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,MatrixType>,UpLo>
const SelfAdjointView<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,MatrixType,product),UpLo>
operator*(const Scalar& s, const SelfAdjointView& mat)
{
return (s*mat.nestedExpression()).template selfadjointView<UpLo>();

View File

@@ -293,6 +293,30 @@ struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
};
// pop scalar multiple
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
: blas_traits<NestedXpr>
{
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType;
static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
static inline Scalar extractScalarFactor(const XprType& x)
{ return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
};
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
: blas_traits<NestedXpr>
{
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
typedef typename Base::ExtractType ExtractType;
static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
static inline Scalar extractScalarFactor(const XprType& x)
{ return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
};
// pop scalar multiple (using deprecated scalar_multiple_op)
template<typename Scalar, typename NestedXpr>
struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
: blas_traits<NestedXpr>

View File

@@ -201,8 +201,6 @@ template<typename Scalar> struct scalar_inverse_op;
template<typename Scalar> struct scalar_square_op;
template<typename Scalar> struct scalar_cube_op;
template<typename Scalar, typename NewType> struct scalar_cast_op;
template<typename Scalar> struct scalar_multiple_op;
template<typename Scalar> struct scalar_quotient1_op;
template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_add_op;
template<typename Scalar> struct scalar_constant_op;

View File

@@ -895,6 +895,51 @@ namespace Eigen {
return EIGEN_CWISE_BINARY_RETURN_TYPE(Derived,OtherDerived,OPNAME)(derived(), other.derived()); \
}
#define EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(EXPR,SCALAR,OPNAME) \
CwiseBinaryOp<EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<typename internal::traits<EXPR>::Scalar,SCALAR>, const EXPR, \
const typename internal::plain_constant_type<EXPR,SCALAR>::type>
#define EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(SCALAR,EXPR,OPNAME) \
CwiseBinaryOp<EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<SCALAR,typename internal::traits<EXPR>::Scalar>, \
const typename internal::plain_constant_type<EXPR,SCALAR>::type, const EXPR>
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) \
EIGEN_DEVICE_FUNC inline \
const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME) \
(METHOD)(const Scalar& scalar) const { \
return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME)(derived(), \
typename internal::plain_constant_type<Derived,Scalar>::type(derived().rows(), derived().cols(), scalar)); \
} \
\
template <typename T> EIGEN_DEVICE_FUNC inline \
typename internal::enable_if<ScalarBinaryOpTraits<Scalar,T>::Defined, \
const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME) >::type \
(METHOD)(const T& scalar) const { \
return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME)(derived(), \
typename internal::plain_constant_type<Derived,T>::type(derived().rows(), derived().cols(), scalar)); \
}
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \
EIGEN_DEVICE_FUNC inline friend \
const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME) \
(METHOD)(const Scalar& scalar, const StorageBaseType& matrix) { \
return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME)( \
typename internal::plain_constant_type<Derived,Scalar>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \
} \
\
template <typename T> EIGEN_DEVICE_FUNC inline friend \
typename internal::enable_if<ScalarBinaryOpTraits<T,Scalar>::Defined, \
const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME) >::type \
(METHOD)(const T& scalar, const StorageBaseType& matrix) { \
return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME)( \
typename internal::plain_constant_type<Derived,T>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \
}
#define EIGEN_MAKE_SCALAR_BINARY_OP(METHOD,OPNAME) \
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME)
#ifdef EIGEN_EXCEPTIONS
# define EIGEN_THROW_X(X) throw X
# define EIGEN_THROW throw

View File

@@ -576,6 +576,20 @@ struct plain_diag_type
>::type type;
};
template<typename Expr,typename Scalar = typename Expr::Scalar>
struct plain_constant_type
{
enum { Options = (traits<Expr>::Flags&RowMajorBit)?RowMajor:0 };
typedef Array<Scalar, traits<Expr>::RowsAtCompileTime, traits<Expr>::ColsAtCompileTime,
Options, traits<Expr>::MaxRowsAtCompileTime,traits<Expr>::MaxColsAtCompileTime> array_type;
typedef Matrix<Scalar, traits<Expr>::RowsAtCompileTime, traits<Expr>::ColsAtCompileTime,
Options, traits<Expr>::MaxRowsAtCompileTime,traits<Expr>::MaxColsAtCompileTime> matrix_type;
typedef CwiseNullaryOp<scalar_constant_op<Scalar>, const typename conditional<is_same< typename traits<Expr>::XprKind, MatrixXpr >::value, matrix_type, array_type>::type > type;
};
template<typename ExpressionType>
struct is_lvalue
{