Modernize tensor contraction code: bug fixes, dead code removal, and cleanup

libeigen/eigen!2248

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-03-29 18:03:06 -07:00
parent 732ebc8cc2
commit 09581fda38
5 changed files with 133 additions and 156 deletions

View File

@@ -99,7 +99,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
// and temporary buffers, required for executing the tensor contraction.
// They are responsible for cleaning it up after contraction is done.
static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
static constexpr bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
const Index m = this->m_i_size;
const Index n = this->m_j_size;
@@ -176,7 +176,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (n == 1) num_threads = 1;
if (num_threads == 1) {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Unaligned>(buffer);
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
if (!IsEvalInSyncMode) done();
return;
}
@@ -258,22 +262,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// optimization.
if (parallelize_by_sharding_dim_only) parallel_pack = false;
// TODO(ezhulnev): With if constexpr we don't need SyncEvalParallelContext.
if (IsEvalInSyncMode) {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, NoCallback()) \
.run()
TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
} else {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, std::move(done))
TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
}
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
EIGEN_IF_CONSTEXPR(IsEvalInSyncMode) {
EvalParallelContext<NoCallback, lhs_c(), rhs_c(), rhs_r(), Alignment> ctx(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, NoCallback());
ctx.run();
}
else {
auto* ctx = new EvalParallelContext<DoneCallback, lhs_c(), rhs_c(), rhs_r(), Alignment>(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, std::move(done));
ctx->run();
}
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
}
// ------------------------------------------------------------------------ //
@@ -432,11 +436,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gn_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
/*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
} else {
@@ -444,11 +448,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gm_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
/*rhs_blocks=*/nullptr);
}
}
@@ -461,7 +465,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
kernel_.deallocate(device_, packed_mem_);
if (parallelize_by_sharding_dim_only_) {
kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
kernel_.deallocate(device_, thread_local_pre_allocated_mem_);
delete[] can_use_thread_local_packed_;
}
}
@@ -585,7 +589,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// different from the thread that was used for packing.
// Handle for pre-allocated thread local memory buffers.
BlockMemHandle thread_local_pre_alocated_mem_;
BlockMemHandle thread_local_pre_allocated_mem_;
// Only one of these will be initialized depending on shard_by_col value
// (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
@@ -648,7 +652,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <typename BlockType, bool is_rhs>
class ThreadLocalBlocksInitialize {
static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
static const bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static_assert(kIsLhs || kIsRhs, "Unknown block type");
using Blocks = ThreadLocalBlocks<BlockType>;
@@ -668,10 +672,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
private:
// NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
// TensorContractionKernel::allocateSlices into template specializations.
// Also explicit specializations are not allowed at class scope in C++03,
// EvalCtx type parameter is just a workaround for that limitation.
// Explicit specializations are not allowed at class scope, so EvalCtx is
// a dummy template parameter to make these partial specializations.
template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
struct ThreadLocalBlocksAllocator;
@@ -684,7 +686,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/ctx.gn_,
/*num_slices=*/1,
/*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
}
@@ -703,7 +704,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/0,
/*num_slices=*/1,
/*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
}
@@ -1015,10 +1015,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void operator=(const EvalParallelContext&) = delete;
};
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>;
// ------------------------------------------------------------------------ //
// EvalShardedByInnerDimContext orchestrates sync/async contraction
@@ -1086,11 +1082,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
private:
// The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
static const Index packet_size = internal::packet_traits<RhsScalar>::size;
static constexpr Index packet_size = internal::packet_traits<RhsScalar>::size;
const Self* evaluator; // TensorContraction evaluator
// These fields required fromTENSOR_CONTRACTION_DISPATCH macro.
// These fields cache values from the evaluator for use in processBlock dispatch.
bool m_lhs_inner_dim_contiguous;
bool m_rhs_inner_dim_contiguous;
bool m_rhs_inner_dim_reordered;
@@ -1131,7 +1127,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
//
// TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
// sense only if number of threads >= ~128?
static const Index l0_size = 4;
static constexpr Index l0_size = 4;
Index l0_ranges;
// Keep count of pending gemm tasks for each l0 range.
@@ -1144,9 +1140,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, begin, end,
/*num_threads=*/internal::convert_index<int>(num_blocks)));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
evaluator->template evalGemmPartialWithoutOutputKernel<lhs_c(), rhs_c(), rhs_r(), Alignment>(
buf, begin, end, /*num_threads=*/internal::convert_index<int>(num_blocks));
},
m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered);
// Check if it was the last task in l0 range.
const Index l0_index = block_idx / l0_size;