Replace typedef with using in tensor contraction files

libeigen/eigen!2247

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-03-04 08:59:22 -08:00
parent abc3d6014d
commit dd826edb42
3 changed files with 160 additions and 170 deletions

View File

@@ -20,49 +20,49 @@ namespace internal {
template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>> {
// Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename gebp_traits<std::remove_const_t<typename LhsXprType::Scalar>,
std::remove_const_t<typename RhsXprType::Scalar>>::ResScalar Scalar;
using Scalar = typename gebp_traits<std::remove_const_t<typename LhsXprType::Scalar>,
std::remove_const_t<typename RhsXprType::Scalar>>::ResScalar;
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
typedef
typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
typedef typename LhsXprType::Nested LhsNested;
typedef typename RhsXprType::Nested RhsNested;
typedef std::remove_reference_t<LhsNested> LhsNested_;
typedef std::remove_reference_t<RhsNested> RhsNested_;
using StorageKind = typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
typename traits<RhsXprType>::StorageKind>::ret;
using Index =
typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type;
using LhsNested = typename LhsXprType::Nested;
using RhsNested = typename RhsXprType::Nested;
using LhsNested_ = std::remove_reference_t<LhsNested>;
using RhsNested_ = std::remove_reference_t<RhsNested>;
// From NumDims below.
static constexpr int NumDimensions =
traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
static constexpr int Layout = traits<LhsXprType>::Layout;
typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>
PointerType;
using PointerType =
std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>;
enum { Flags = 0 };
};
template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense> {
typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
using type = const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>&;
};
template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1,
typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>>::type> {
typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
using type = TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>;
};
template <typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_,
typename Device_>
struct traits<
TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_>> {
typedef Indices_ Indices;
typedef LeftArgType_ LeftArgType;
typedef RightArgType_ RightArgType;
typedef OutputKernelType_ OutputKernelType;
typedef Device_ Device;
using Indices = Indices_;
using LeftArgType = LeftArgType_;
using RightArgType = RightArgType_;
using OutputKernelType = OutputKernelType_;
using Device = Device_;
// From NumDims below.
static constexpr int NumDimensions =
@@ -72,7 +72,7 @@ struct traits<
// Helper class to allocate and deallocate temporary memory for packed buffers.
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {
typedef void* BlockMemHandle;
using BlockMemHandle = void*;
template <typename Device>
EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm, const Index bk, const Index bn,
@@ -175,25 +175,23 @@ struct TensorContractionKernel {
: m(m_), k(k_), n(n_), bm(bm_), bk(bk_), bn(bn_) {}
// Pack blocks of Lhs and Rhs into contiguous blocks in memory.
typedef LhsScalar* LhsBlock;
typedef RhsScalar* RhsBlock;
using LhsBlock = LhsScalar*;
using RhsBlock = RhsScalar*;
// Packed Lhs/Rhs block memory allocator.
typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> BlockMemAllocator;
typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;
using BlockMemAllocator = TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>;
using BlockMemHandle = typename BlockMemAllocator::BlockMemHandle;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;
typedef internal::gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
LhsPacker;
using LhsPacker = internal::gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>;
typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
RhsPacker;
using RhsPacker =
internal::gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor>;
typedef internal::gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper, Traits::mr, Traits::nr,
/*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
GebpKernel;
using GebpKernel = internal::gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper, Traits::mr, Traits::nr,
/*ConjugateLhs*/ false, /*ConjugateRhs*/ false>;
template <typename Device>
EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, RhsBlock* rhs_block) {
@@ -303,12 +301,12 @@ template <typename Indices, typename LhsXprType, typename RhsXprType,
class TensorContractionOp
: public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> {
public:
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
using Scalar = typename Eigen::internal::traits<TensorContractionOp>::Scalar;
using CoeffReturnType = typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
typename RhsXprType::CoeffReturnType>::ResScalar;
using Nested = typename Eigen::internal::nested<TensorContractionOp>::type;
using StorageKind = typename Eigen::internal::traits<TensorContractionOp>::StorageKind;
using Index = typename Eigen::internal::traits<TensorContractionOp>::Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType& lhs, const RhsXprType& rhs,
const Indices& dims,
@@ -337,19 +335,19 @@ class TensorContractionOp
template <typename Derived>
struct TensorContractionEvaluatorBase {
typedef typename internal::traits<Derived>::Indices Indices;
typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
typedef typename internal::traits<Derived>::RightArgType RightArgType;
typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
typedef typename internal::traits<Derived>::Device Device;
using Indices = typename internal::traits<Derived>::Indices;
using LeftArgType = typename internal::traits<Derived>::LeftArgType;
using RightArgType = typename internal::traits<Derived>::RightArgType;
using OutputKernelType = typename internal::traits<Derived>::OutputKernelType;
using Device = typename internal::traits<Derived>::Device;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
typedef StorageMemory<Scalar, Device> Storage;
typedef typename Storage::Type EvaluatorPointerType;
using XprType = TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>;
using Scalar = std::remove_const_t<typename XprType::Scalar>;
using Index = typename XprType::Index;
using CoeffReturnType = typename XprType::CoeffReturnType;
using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;
using Storage = StorageMemory<Scalar, Device>;
using EvaluatorPointerType = typename Storage::Type;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
enum {
@@ -362,20 +360,20 @@ struct TensorContractionEvaluatorBase {
};
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
typedef internal::TensorBlockNotImplemented TensorBlock;
using TensorBlock = internal::TensorBlockNotImplemented;
//===--------------------------------------------------------------------===//
// Most of the code is assuming that both input tensors are ColMajor. If the
// inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
// If we want to compute A * B = C, where A is LHS and B is RHS, the code
// will pretend B is LHS and A is RHS.
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>
EvalLeftArgType;
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>
EvalRightArgType;
using EvalLeftArgType =
std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>;
using EvalRightArgType =
std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluatorType;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluatorType;
using LeftEvaluatorType = TensorEvaluator<EvalLeftArgType, Device>;
using RightEvaluatorType = TensorEvaluator<EvalRightArgType, Device>;
static constexpr int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
@@ -384,11 +382,11 @@ struct TensorContractionEvaluatorBase {
static constexpr int ContractDims = internal::array_size<Indices>::value;
static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
typedef array<Index, ContractDims> contract_t;
typedef array<Index, LDims - ContractDims> left_nocontract_t;
typedef array<Index, RDims - ContractDims> right_nocontract_t;
using contract_t = array<Index, ContractDims>;
using left_nocontract_t = array<Index, LDims - ContractDims>;
using right_nocontract_t = array<Index, RDims - ContractDims>;
typedef DSizes<Index, NumDims> Dimensions;
using Dimensions = DSizes<Index, NumDims>;
EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType& op, const Device& device)
: m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), op.lhsExpression(),
@@ -697,23 +695,22 @@ struct TensorContractionEvaluatorBase {
const Index rows = m_i_size;
const Index cols = m_k_size;
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
using LhsScalar = std::remove_const_t<typename EvalLeftArgType::Scalar>;
using RhsScalar = std::remove_const_t<typename EvalRightArgType::Scalar>;
using LeftEvaluator = TensorEvaluator<EvalLeftArgType, Device>;
using RightEvaluator = TensorEvaluator<EvalRightArgType, Device>;
const int lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
const int rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, lhs_packet_size, lhs_inner_dim_contiguous, false,
lhs_alignment>
LhsMapper;
using LhsMapper = internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator,
left_nocontract_t, contract_t, lhs_packet_size,
lhs_inner_dim_contiguous, false, lhs_alignment>;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, rhs_alignment>
RhsMapper;
using RhsMapper =
internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, rhs_alignment>;
LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
@@ -727,7 +724,7 @@ struct TensorContractionEvaluatorBase {
internal::general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, false, RhsScalar, RhsMapper,
false>::run(rows, cols, lhs, rhs, buffer, resIncr, alpha);
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
m_output_kernel(OutputMapper(buffer, rows), m_tensor_contraction_params, static_cast<Index>(0),
static_cast<Index>(0), rows, static_cast<Index>(1));
}
@@ -765,29 +762,28 @@ struct TensorContractionEvaluatorBase {
const Index n = this->m_j_size;
// define data mappers for Lhs and Rhs
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
using LhsScalar = std::remove_const_t<typename EvalLeftArgType::Scalar>;
using RhsScalar = std::remove_const_t<typename EvalRightArgType::Scalar>;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
using LeftEvaluator = TensorEvaluator<EvalLeftArgType, Device>;
using RightEvaluator = TensorEvaluator<EvalRightArgType, Device>;
const int lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
const int rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, lhs_packet_size, lhs_inner_dim_contiguous, false,
Unaligned>
LhsMapper;
using LhsMapper =
internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, lhs_packet_size, lhs_inner_dim_contiguous, false, Unaligned>;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Unaligned>
RhsMapper;
using RhsMapper =
internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Unaligned>;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
typedef internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
TensorContractionKernel;
using TensorContractionKernel =
internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>;
// initialize data mappers
LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
@@ -805,15 +801,15 @@ struct TensorContractionEvaluatorBase {
const Index mc = numext::mini(m, blocking.mc());
const Index nc = numext::mini(n, blocking.nc());
typedef typename TensorContractionKernel::LhsBlock LhsBlock;
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
using LhsBlock = typename TensorContractionKernel::LhsBlock;
using RhsBlock = typename TensorContractionKernel::RhsBlock;
LhsBlock blockA;
RhsBlock blockB;
TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
using BlockMemHandle = typename TensorContractionKernel::BlockMemHandle;
const BlockMemHandle packed_mem = kernel.allocate(this->m_device, &blockA, &blockB);
// If a contraction kernel does not support beta, explicitly initialize
@@ -913,14 +909,14 @@ template <typename Indices, typename LeftArgType, typename RightArgType, typenam
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>
: public TensorContractionEvaluatorBase<
TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>> {
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
using Self = TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>;
using Base = TensorContractionEvaluatorBase<Self>;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
using XprType = TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>;
using Scalar = std::remove_const_t<typename XprType::Scalar>;
using Index = typename XprType::Index;
using CoeffReturnType = typename XprType::CoeffReturnType;
using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
@@ -928,8 +924,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
// If we want to compute A * B = C, where A is LHS and B is RHS, the code
// will pretend B is LHS and A is RHS.
typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;
using EvalLeftArgType = std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType>;
using EvalRightArgType = std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType>;
static constexpr int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
@@ -937,14 +933,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
static constexpr int ContractDims = internal::array_size<Indices>::value;
typedef array<Index, ContractDims> contract_t;
typedef array<Index, LDims - ContractDims> left_nocontract_t;
typedef array<Index, RDims - ContractDims> right_nocontract_t;
using contract_t = array<Index, ContractDims>;
using left_nocontract_t = array<Index, LDims - ContractDims>;
using right_nocontract_t = array<Index, RDims - ContractDims>;
static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
// Could we use NumDimensions here?
typedef DSizes<Index, NumDims> Dimensions;
using Dimensions = DSizes<Index, NumDims>;
TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}

View File

@@ -82,7 +82,7 @@ struct CoeffLoader<Tensor, true, MakePointer_> {
}
private:
typedef typename Tensor::Scalar Scalar;
using Scalar = typename Tensor::Scalar;
typename MakePointer_<const Scalar>::Type m_data;
};
@@ -243,9 +243,8 @@ class BaseTensorContractionMapper
: public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, Alignment, MakePointer_> {
public:
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, Alignment, MakePointer_>
ParentMapper;
using ParentMapper = SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, Alignment, MakePointer_>;
EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides, const contract_t& contract_strides,
@@ -330,9 +329,8 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
: public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
inner_dim_contiguous, Alignment, MakePointer_> {
public:
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
Alignment, MakePointer_>
ParentMapper;
using ParentMapper = SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
inner_dim_contiguous, Alignment, MakePointer_>;
EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides, const contract_t& contract_strides,
@@ -358,14 +356,12 @@ template <typename Scalar, typename Index, int side, typename Tensor, typename n
template <class> class MakePointer_ = MakePointer>
class TensorContractionSubMapper {
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
Self;
typedef Self LinearMapper;
typedef Self SubMapper;
using ParentMapper = BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>;
using Self = TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>;
using LinearMapper = Self;
using SubMapper = Self;
enum {
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
@@ -485,15 +481,13 @@ class TensorContractionInputMapper
: public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
public:
typedef Scalar_ Scalar;
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
SubMapper;
typedef SubMapper VectorMapper;
typedef SubMapper LinearMapper;
using Scalar = Scalar_;
using Base = BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>;
using SubMapper = TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>;
using VectorMapper = SubMapper;
using LinearMapper = SubMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides, const contract_t& contract_strides,
@@ -526,7 +520,7 @@ template <typename Scalar_, typename Index_, int side_, typename Tensor_, typena
struct TensorContractionInputMapperTrait<
TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_,
inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_> > {
typedef Tensor_ XprType;
using XprType = Tensor_;
static const bool inner_dim_contiguous = inner_dim_contiguous_;
static const bool inner_dim_reordered = inner_dim_reordered_;
};

View File

@@ -23,16 +23,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
ThreadPoolDevice>
: public TensorContractionEvaluatorBase<TensorEvaluator<
const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice>> {
typedef ThreadPoolDevice Device;
using Device = ThreadPoolDevice;
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
using Self = TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>;
using Base = TensorContractionEvaluatorBase<Self>;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
using XprType = TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>;
using Scalar = std::remove_const_t<typename XprType::Scalar>;
using Index = typename XprType::Index;
using CoeffReturnType = typename XprType::CoeffReturnType;
using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
@@ -40,10 +40,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
// If we want to compute A * B = C, where A is LHS and B is RHS, the code
// will pretend B is LHS and A is RHS.
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>
EvalLeftArgType;
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>
EvalRightArgType;
using EvalLeftArgType =
std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>;
using EvalRightArgType =
std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>;
static constexpr int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
@@ -51,24 +51,24 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
static constexpr int ContractDims = internal::array_size<Indices>::value;
typedef array<Index, LDims> left_dim_mapper_t;
typedef array<Index, RDims> right_dim_mapper_t;
using left_dim_mapper_t = array<Index, LDims>;
using right_dim_mapper_t = array<Index, RDims>;
typedef array<Index, ContractDims> contract_t;
typedef array<Index, LDims - ContractDims> left_nocontract_t;
typedef array<Index, RDims - ContractDims> right_nocontract_t;
using contract_t = array<Index, ContractDims>;
using left_nocontract_t = array<Index, LDims - ContractDims>;
using right_nocontract_t = array<Index, RDims - ContractDims>;
static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
typedef DSizes<Index, NumDims> Dimensions;
using Dimensions = DSizes<Index, NumDims>;
// Type aliases needed in evalTo
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
using LhsScalar = std::remove_const_t<typename EvalLeftArgType::Scalar>;
using RhsScalar = std::remove_const_t<typename EvalRightArgType::Scalar>;
using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
using LeftEvaluator = TensorEvaluator<EvalLeftArgType, Device>;
using RightEvaluator = TensorEvaluator<EvalRightArgType, Device>;
TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}
@@ -335,23 +335,23 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
bool rhs_inner_dim_reordered, int Alignment>
class EvalParallelContext {
public:
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, internal::packet_traits<LhsScalar>::size,
lhs_inner_dim_contiguous, false, Unaligned>
LhsMapper;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, internal::packet_traits<RhsScalar>::size,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
RhsMapper;
using LhsMapper =
internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, internal::packet_traits<LhsScalar>::size,
lhs_inner_dim_contiguous, false, Unaligned>;
using RhsMapper =
internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, internal::packet_traits<RhsScalar>::size,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
typedef internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
TensorContractionKernel;
using TensorContractionKernel =
internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>;
typedef typename TensorContractionKernel::LhsBlock LhsBlock;
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
using LhsBlock = typename TensorContractionKernel::LhsBlock;
using RhsBlock = typename TensorContractionKernel::RhsBlock;
using BlockMemHandle = typename TensorContractionKernel::BlockMemHandle;
EvalParallelContext(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0,
@@ -1195,7 +1195,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
void applyOutputKernel() const {
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
evaluator->m_output_kernel(OutputMapper(result, m), evaluator->m_tensor_contraction_params,
static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
}