Modernize tensor contraction code: bug fixes, dead code removal, and cleanup

libeigen/eigen!2248

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
Rasmus Munk Larsen
2026-03-29 18:03:06 -07:00
parent 732ebc8cc2
commit 09581fda38
5 changed files with 133 additions and 156 deletions

View File

@@ -40,7 +40,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern
std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>;
enum { Flags = 0 };
static constexpr int Flags = 0;
};
template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
@@ -168,7 +168,7 @@ template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename S
struct TensorContractionKernel {
// True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
// (otherwise beta should be always equal to 1).
enum { HasBeta = false };
static constexpr bool HasBeta = false;
EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_,
StorageIndex bk_, StorageIndex bn_)
@@ -247,8 +247,59 @@ struct TensorContractionKernel {
const StorageIndex bn;
};
// Dispatches a contraction operation over all 8 combinations of the three
// runtime boolean parameters (lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
// rhs_inner_dim_reordered), passing them as compile-time bool_constant
// tags to the callable `fn`.
template <typename Func>
EIGEN_STRONG_INLINE void tensor_contraction_dispatch(Func&& fn, bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered) {
if (lhs_inner_dim_contiguous) {
if (rhs_inner_dim_contiguous) {
if (rhs_inner_dim_reordered)
fn(bool_constant<true>{}, bool_constant<true>{}, bool_constant<true>{});
else
fn(bool_constant<true>{}, bool_constant<true>{}, bool_constant<false>{});
} else {
if (rhs_inner_dim_reordered)
fn(bool_constant<true>{}, bool_constant<false>{}, bool_constant<true>{});
else
fn(bool_constant<true>{}, bool_constant<false>{}, bool_constant<false>{});
}
} else {
if (rhs_inner_dim_contiguous) {
if (rhs_inner_dim_reordered)
fn(bool_constant<false>{}, bool_constant<true>{}, bool_constant<true>{});
else
fn(bool_constant<false>{}, bool_constant<true>{}, bool_constant<false>{});
} else {
if (rhs_inner_dim_reordered)
fn(bool_constant<false>{}, bool_constant<false>{}, bool_constant<true>{});
else
fn(bool_constant<false>{}, bool_constant<false>{}, bool_constant<false>{});
}
}
}
} // end namespace internal
// Legacy macros kept for backward compatibility with code that overrides them
// (e.g. TensorFlow Lite restricts template instantiations for binary size).
// New Eigen code should use internal::tensor_contraction_dispatch() instead.
#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
::Eigen::internal::tensor_contraction_dispatch( \
[&](auto lhs_c, auto rhs_c, auto rhs_r) { METHOD<lhs_c(), rhs_c(), rhs_r(), ALIGNMENT> ARGS; }, \
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered)
#endif
#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
::Eigen::internal::tensor_contraction_dispatch( \
[&](auto lhs_c, auto rhs_c, auto rhs_r) { (new METHOD<DONE, lhs_c(), rhs_c(), rhs_r(), ALIGNMENT> ARGS)->FN; }, \
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered)
#endif
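For illustration only (not part of this diff): a hedged sketch of calling the new helper directly, where MyKernel is a hypothetical function templated on the three flags; only tensor_contraction_dispatch itself comes from the code above.
// Sketch only: MyKernel is hypothetical. The bool_constant tags received by the
// generic lambda are constexpr-callable, so `tag()` is a valid template argument.
template <bool LhsContiguous, bool RhsContiguous, bool RhsReordered>
void MyKernel(float* /*out*/) { /* flag-specialized implementation */ }

void RunMyKernel(float* out, bool lhs_c, bool rhs_c, bool rhs_r) {
  ::Eigen::internal::tensor_contraction_dispatch(
      [&](auto lhs, auto rhs, auto reordered) { MyKernel<lhs(), rhs(), reordered()>(out); },
      lhs_c, rhs_c, rhs_r);
}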
// Tensor contraction params that should enable to get from output matrix
// 2-dimensional coordinates to the output tensor dimensions.
struct TensorContractionParams {
@@ -281,16 +332,8 @@ struct NoOpOutputKernel {
* \param[in] num_cols Number of available columns
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
const TensorContractionParams& params, Index i, Index j, Index num_rows,
Index num_cols) const {
EIGEN_UNUSED_VARIABLE(output_mapper);
EIGEN_UNUSED_VARIABLE(params);
EIGEN_UNUSED_VARIABLE(i);
EIGEN_UNUSED_VARIABLE(j);
EIGEN_UNUSED_VARIABLE(num_rows);
EIGEN_UNUSED_VARIABLE(num_cols);
}
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>&,
const TensorContractionParams&, Index, Index, Index, Index) const {}
};
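For context, a hedged sketch (not part of this commit) of a user-defined output kernel with the same call signature; it applies a ReLU to each finished output block, assuming, as in the no-op kernel above, that output_mapper already addresses the block starting at (i, j).
// Sketch only: clamp every coefficient of the just-computed block at zero.
struct ReluOutputKernel {
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
                                      const TensorContractionParams& /*params*/, Index /*i*/, Index /*j*/,
                                      Index num_rows, Index num_cols) const {
    for (Index col = 0; col < num_cols; ++col)
      for (Index row = 0; row < num_rows; ++row)
        output_mapper(row, col) = numext::maxi(output_mapper(row, col), Scalar(0));
  }
};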
/** Tensor contraction class.
@@ -350,14 +393,12 @@ struct TensorContractionEvaluatorBase {
using EvaluatorPointerType = typename Storage::Type;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
enum {
IsAligned = true,
PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
BlockAccess = false,
PreferBlockAccess = false,
CoordAccess = false, // to be implemented
RawAccess = true
};
static constexpr bool IsAligned = true;
static constexpr bool PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1);
static constexpr bool BlockAccess = false;
static constexpr bool PreferBlockAccess = false;
static constexpr bool CoordAccess = false; // to be implemented
static constexpr bool RawAccess = true;
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
using TensorBlock = internal::TensorBlockNotImplemented;
@@ -397,7 +438,7 @@ struct TensorContractionEvaluatorBase {
device),
m_device(device),
m_output_kernel(op.outputKernel()),
m_result(NULL) {
m_result(nullptr) {
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -581,8 +622,8 @@ struct TensorContractionEvaluatorBase {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
m_leftImpl.evalSubExprsIfNeeded(NULL);
m_rightImpl.evalSubExprsIfNeeded(NULL);
m_leftImpl.evalSubExprsIfNeeded(nullptr);
m_rightImpl.evalSubExprsIfNeeded(nullptr);
if (data) {
evalTo(data);
return false;
@@ -609,72 +650,6 @@ struct TensorContractionEvaluatorBase {
}
#endif // EIGEN_USE_THREADS
#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
if (this->m_lhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, true, true, ALIGNMENT> ARGS; \
} else { \
METHOD<true, true, false, ALIGNMENT> ARGS; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, false, true, ALIGNMENT> ARGS; \
} else { \
METHOD<true, false, false, ALIGNMENT> ARGS; \
} \
} \
} else { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, true, true, ALIGNMENT> ARGS; \
} else { \
METHOD<false, true, false, ALIGNMENT> ARGS; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, false, true, ALIGNMENT> ARGS; \
} else { \
METHOD<false, false, false, ALIGNMENT> ARGS; \
} \
} \
}
#endif
#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
if (this->m_lhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
} \
} \
} else { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
} \
} \
}
#endif
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
}
@@ -867,9 +842,9 @@ struct TensorContractionEvaluatorBase {
m_leftImpl.cleanup();
m_rightImpl.cleanup();
if (m_result != NULL) {
if (m_result != nullptr) {
m_device.deallocate(m_result);
m_result = NULL;
m_result = nullptr;
}
}
@@ -957,7 +932,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void evalProduct(Scalar* buffer) const {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Alignment>(buffer);
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
}
};

View File

@@ -16,7 +16,8 @@
namespace Eigen {
namespace internal {
enum { ShardByRow = 0, ShardByCol = 1 };
constexpr int ShardByRow = 0;
constexpr int ShardByCol = 1;
// Default Blocking Strategy
template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex,

View File

@@ -17,7 +17,8 @@ namespace Eigen {
namespace internal {
enum { Rhs = 0, Lhs = 1 };
constexpr int Rhs = 0;
constexpr int Lhs = 1;
/*
* Implementation of the Eigen blas_data_mapper class for tensors.
@@ -34,7 +35,7 @@ class BaseTensorContractionMapper;
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader {
enum { DirectOffsets = false };
static constexpr bool DirectOffsets = false;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) {}
@@ -44,7 +45,7 @@ struct CoeffLoader {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
eigen_assert(false && "unsupported");
return NULL;
return nullptr;
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
@@ -62,7 +63,7 @@ struct CoeffLoader {
template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
enum { DirectOffsets = true };
static constexpr bool DirectOffsets = true;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
@@ -100,7 +101,7 @@ class SimpleTensorContractionMapper {
m_contract_strides(contract_strides),
m_k_strides(k_strides) {}
enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };
static constexpr bool DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
m_tensor.offsetBuffer(offset);
@@ -123,7 +124,6 @@ class SimpleTensorContractionMapper {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
const bool left = (side == Lhs);
EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val = left ? row : col;
Index linidx = 0;
EIGEN_UNROLL_LOOP
@@ -164,7 +164,6 @@ class SimpleTensorContractionMapper {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col,
const Index distance) const {
const bool left = (side == Lhs);
EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
Index linidx[2] = {0, 0};
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
@@ -363,12 +362,10 @@ class TensorContractionSubMapper {
using LinearMapper = Self;
using SubMapper = Self;
enum {
// We can use direct offsets iff the parent mapper supports then and we can compute the strides.
// TODO: we should also enable direct offsets for the Rhs case.
UseDirectOffsets =
ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
};
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
// TODO: we should also enable direct offsets for the Rhs case.
static constexpr bool UseDirectOffsets =
ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0);
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
@@ -429,8 +426,9 @@ class TensorContractionSubMapper {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
if (UseDirectOffsets) {
m_base_mapper.storePacket(i, 0, p);
} else {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@@ -451,7 +449,7 @@ class TensorContractionSubMapper {
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
static_assert(std::is_same<PacketT, PacketT>::value, "YOU_MADE_A_PROGRAMMING_MISTAKE");
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);

View File

@@ -99,7 +99,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
// and temporary buffers, required for executing the tensor contraction.
// They are responsible for cleaning it up after contraction is done.
static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
static constexpr bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
const Index m = this->m_i_size;
const Index n = this->m_j_size;
@@ -176,7 +176,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (n == 1) num_threads = 1;
if (num_threads == 1) {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Unaligned>(buffer);
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
if (!IsEvalInSyncMode) done();
return;
}
@@ -258,22 +262,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// optimization.
if (parallelize_by_sharding_dim_only) parallel_pack = false;
// TODO(ezhulnev): With if constexpr we don't need SyncEvalParallelContext.
if (IsEvalInSyncMode) {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, NoCallback()) \
.run()
TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
} else {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, std::move(done))
TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
}
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
EIGEN_IF_CONSTEXPR(IsEvalInSyncMode) {
EvalParallelContext<NoCallback, lhs_c(), rhs_c(), rhs_r(), Alignment> ctx(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, NoCallback());
ctx.run();
}
else {
auto* ctx = new EvalParallelContext<DoneCallback, lhs_c(), rhs_c(), rhs_r(), Alignment>(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, std::move(done));
ctx->run();
}
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
}
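A reduced sketch (illustrative only) of the if-constexpr shape used above, with toy context types standing in for EvalParallelContext and assuming plain C++17 if constexpr behind EIGEN_IF_CONSTEXPR: the sync path keeps its context on the stack and blocks in run(), while the async path heap-allocates a self-owning context and returns immediately.
#include <type_traits>
#include <utility>

struct NoCallback {};  // marker: evaluate synchronously, no completion callback

struct ToySyncContext {
  void run() { /* do all the work; return only when the contraction is finished */ }
};

template <typename Done>
struct ToyAsyncContext {
  explicit ToyAsyncContext(Done d) : done(std::move(d)) {}
  void run() { /* schedule work; on completion: */ done(); delete this; }
  Done done;
};

template <typename DoneCallback>
void RunContraction(DoneCallback done) {
  if constexpr (std::is_same<DoneCallback, NoCallback>::value) {
    ToySyncContext ctx;  // stack-allocated, blocking
    ctx.run();
  } else {
    (new ToyAsyncContext<DoneCallback>(std::move(done)))->run();  // self-owning, non-blocking
  }
}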
// ------------------------------------------------------------------------ //
@@ -432,11 +436,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gn_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
/*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
} else {
@@ -444,11 +448,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gm_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
/*rhs_blocks=*/nullptr);
}
}
@@ -461,7 +465,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
kernel_.deallocate(device_, packed_mem_);
if (parallelize_by_sharding_dim_only_) {
kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
kernel_.deallocate(device_, thread_local_pre_allocated_mem_);
delete[] can_use_thread_local_packed_;
}
}
@@ -585,7 +589,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// different from the thread that was used for packing.
// Handle for pre-allocated thread local memory buffers.
BlockMemHandle thread_local_pre_alocated_mem_;
BlockMemHandle thread_local_pre_allocated_mem_;
// Only one of these will be initialized depending on shard_by_col value
// (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
@@ -648,7 +652,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <typename BlockType, bool is_rhs>
class ThreadLocalBlocksInitialize {
static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
static const bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static_assert(kIsLhs || kIsRhs, "Unknown block type");
using Blocks = ThreadLocalBlocks<BlockType>;
@@ -668,10 +672,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
private:
// NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
// TensorContractionKernel::allocateSlices into template specializations.
// Also explicit specializations are not allowed at class scope in C++03,
// EvalCtx type parameter is just a workaround for that limitation.
// Explicit specializations are not allowed at class scope, so EvalCtx is
// a dummy template parameter that turns these into partial specializations.
template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
struct ThreadLocalBlocksAllocator;
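A minimal sketch of the idiom described in the comment above (illustrative only, with hypothetical names): explicit specializations of a member template are ill-formed at class scope, but adding a defaulted Dummy parameter turns them into partial specializations, which are allowed.
struct Outer {
  // Primary template; Dummy exists only so the specializations below are partial.
  template <bool pack_rhs, typename Dummy = void>
  struct Allocator;

  template <typename Dummy>
  struct Allocator</*pack_rhs=*/true, Dummy> {
    static void allocate() { /* pack rhs blocks */ }
  };

  template <typename Dummy>
  struct Allocator</*pack_rhs=*/false, Dummy> {
    static void allocate() { /* pack lhs blocks */ }
  };
};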
@@ -684,7 +686,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/ctx.gn_,
/*num_slices=*/1,
/*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
}
@@ -703,7 +704,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/0,
/*num_slices=*/1,
/*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
}
@@ -1015,10 +1015,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void operator=(const EvalParallelContext&) = delete;
};
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>;
// ------------------------------------------------------------------------ //
// EvalShardedByInnerDimContext orchestrates sync/async contraction
@@ -1086,11 +1082,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
private:
// The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
static const Index packet_size = internal::packet_traits<RhsScalar>::size;
static constexpr Index packet_size = internal::packet_traits<RhsScalar>::size;
const Self* evaluator; // TensorContraction evaluator
// These fields required fromTENSOR_CONTRACTION_DISPATCH macro.
// These fields cache values from the evaluator for use in processBlock dispatch.
bool m_lhs_inner_dim_contiguous;
bool m_rhs_inner_dim_contiguous;
bool m_rhs_inner_dim_reordered;
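A hypothetical helper (not taken from the code above) showing one way to honor the packet-size constraint just described, by rounding a per-block inner dimension up to a multiple of the packet size.
// Sketch only: round block_k up so the GEMM kernel always sees a k that is a
// multiple of packet_size (packet_size is assumed to be > 0).
template <typename Index>
Index RoundUpToPacketMultiple(Index block_k, Index packet_size) {
  return ((block_k + packet_size - 1) / packet_size) * packet_size;
}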
@@ -1131,7 +1127,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
//
// TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
// sense only if number of threads >= ~128?
static const Index l0_size = 4;
static constexpr Index l0_size = 4;
Index l0_ranges;
// Keep count of pending gemm tasks for each l0 range.
@@ -1144,9 +1140,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, begin, end,
/*num_threads=*/internal::convert_index<int>(num_blocks)));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
evaluator->template evalGemmPartialWithoutOutputKernel<lhs_c(), rhs_c(), rhs_r(), Alignment>(
buf, begin, end, /*num_threads=*/internal::convert_index<int>(num_blocks));
},
m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered);
// Check if it was the last task in l0 range.
const Index l0_index = block_idx / l0_size;

View File

@@ -131,7 +131,7 @@ static void ContractionSizes(::benchmark::Benchmark* b) {
static void ThreadPoolSizes(::benchmark::Benchmark* b) {
for (int size : {64, 256, 512, 1024}) {
for (int threads : {2, 4, 8}) {
for (int threads : {1, 2, 4, 8, 16}) {
b->Args({size, size, size, threads});
}
}