Replace Eigen type metaprogramming with corresponding std types and make use of alias templates

This commit is contained in:
Erik Schultheis
2022-03-16 16:43:40 +00:00
committed by Antonio Sánchez
parent 514f90c9ff
commit 421cbf0866
191 changed files with 1147 additions and 1221 deletions

View File

@@ -27,23 +27,21 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
enum {
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
};
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
// Most of the code is assuming that both input tensors are ColMajor. If the
// inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
// If we want to compute A * B = C, where A is LHS and B is RHS, the code
// will pretend B is LHS and A is RHS.
typedef typename internal::conditional<
static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
typedef typename internal::conditional<
static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
typedef std::conditional_t<
static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
typedef std::conditional_t<
static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;
static const int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
@@ -63,8 +61,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typedef DSizes<Index, NumDims> Dimensions;
// typedefs needed in evalTo
typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;