mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Modernize tensor contraction code: bug fixes, dead code removal, and cleanup
libeigen/eigen!2248 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -99,7 +99,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
|
||||
// and temporary buffers, required for executing the tensor contraction.
|
||||
// They are responsible for cleaning it up after contraction is done.
|
||||
static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
|
||||
static constexpr bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
|
||||
|
||||
const Index m = this->m_i_size;
|
||||
const Index n = this->m_j_size;
|
||||
@@ -176,7 +176,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
if (n == 1) num_threads = 1;
|
||||
|
||||
if (num_threads == 1) {
|
||||
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
|
||||
internal::tensor_contraction_dispatch(
|
||||
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
|
||||
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Unaligned>(buffer);
|
||||
},
|
||||
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
|
||||
if (!IsEvalInSyncMode) done();
|
||||
return;
|
||||
}
|
||||
@@ -258,22 +262,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// optimization.
|
||||
if (parallelize_by_sharding_dim_only) parallel_pack = false;
|
||||
|
||||
// TODO(ezhulnev): With if constexpr we don't need SyncEvalParallelContext.
|
||||
if (IsEvalInSyncMode) {
|
||||
#define CONTEXT_ARGS \
|
||||
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
|
||||
parallelize_by_sharding_dim_only, NoCallback()) \
|
||||
.run()
|
||||
TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
|
||||
#undef CONTEXT_ARGS
|
||||
|
||||
} else {
|
||||
#define CONTEXT_ARGS \
|
||||
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
|
||||
parallelize_by_sharding_dim_only, std::move(done))
|
||||
TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
|
||||
#undef CONTEXT_ARGS
|
||||
}
|
||||
internal::tensor_contraction_dispatch(
|
||||
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
|
||||
EIGEN_IF_CONSTEXPR(IsEvalInSyncMode) {
|
||||
EvalParallelContext<NoCallback, lhs_c(), rhs_c(), rhs_r(), Alignment> ctx(
|
||||
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
|
||||
parallel_pack, parallelize_by_sharding_dim_only, NoCallback());
|
||||
ctx.run();
|
||||
}
|
||||
else {
|
||||
auto* ctx = new EvalParallelContext<DoneCallback, lhs_c(), rhs_c(), rhs_r(), Alignment>(
|
||||
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
|
||||
parallel_pack, parallelize_by_sharding_dim_only, std::move(done));
|
||||
ctx->run();
|
||||
}
|
||||
},
|
||||
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
@@ -432,11 +436,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
|
||||
|
||||
Index num_blocks = num_worker_threads * gn_;
|
||||
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/0, //
|
||||
/*num_rhs=*/num_blocks, //
|
||||
/*num_slices=*/1, //
|
||||
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/0, //
|
||||
/*num_rhs=*/num_blocks, //
|
||||
/*num_slices=*/1, //
|
||||
/*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
|
||||
|
||||
} else {
|
||||
@@ -444,11 +448,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
|
||||
|
||||
Index num_blocks = num_worker_threads * gm_;
|
||||
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/num_blocks, //
|
||||
/*num_rhs=*/0, //
|
||||
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
|
||||
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/num_blocks, //
|
||||
/*num_rhs=*/0, //
|
||||
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
|
||||
/*rhs_blocks=*/nullptr);
|
||||
}
|
||||
}
|
||||
@@ -461,7 +465,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
kernel_.deallocate(device_, packed_mem_);
|
||||
if (parallelize_by_sharding_dim_only_) {
|
||||
kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
|
||||
kernel_.deallocate(device_, thread_local_pre_allocated_mem_);
|
||||
delete[] can_use_thread_local_packed_;
|
||||
}
|
||||
}
|
||||
@@ -585,7 +589,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// different from the thread that was used for packing.
|
||||
|
||||
// Handle for pre-allocated thread local memory buffers.
|
||||
BlockMemHandle thread_local_pre_alocated_mem_;
|
||||
BlockMemHandle thread_local_pre_allocated_mem_;
|
||||
|
||||
// Only one of these will be initialized depending on shard_by_col value
|
||||
// (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
|
||||
@@ -648,7 +652,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
template <typename BlockType, bool is_rhs>
|
||||
class ThreadLocalBlocksInitialize {
|
||||
static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
|
||||
static const bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
|
||||
static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
|
||||
static_assert(kIsLhs || kIsRhs, "Unknown block type");
|
||||
|
||||
using Blocks = ThreadLocalBlocks<BlockType>;
|
||||
@@ -668,10 +672,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
|
||||
private:
|
||||
// NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
|
||||
// TensorContractionKernel::allocateSlices into template specializations.
|
||||
// Also explicit specializations are not allowed at class scope in C++03,
|
||||
// EvalCtx type parameter is just a workaround for that limitation.
|
||||
// Explicit specializations are not allowed at class scope, so EvalCtx is
|
||||
// a dummy template parameter to make these partial specializations.
|
||||
template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
|
||||
struct ThreadLocalBlocksAllocator;
|
||||
|
||||
@@ -684,7 +686,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
/*num_rhs=*/ctx.gn_,
|
||||
/*num_slices=*/1,
|
||||
/*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
|
||||
|
||||
blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
|
||||
}
|
||||
|
||||
@@ -703,7 +704,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
/*num_rhs=*/0,
|
||||
/*num_slices=*/1,
|
||||
/*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
|
||||
|
||||
blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
|
||||
}
|
||||
|
||||
@@ -1015,10 +1015,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
void operator=(const EvalParallelContext&) = delete;
|
||||
};
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_reordered, Alignment>;
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
|
||||
// EvalShardedByInnerDimContext orchestrates sync/async contraction
|
||||
@@ -1086,11 +1082,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
private:
|
||||
// The underlying GEMM kernel assumes that k is a multiple of
|
||||
// the packet size and subtle breakage occurs if this is violated.
|
||||
static const Index packet_size = internal::packet_traits<RhsScalar>::size;
|
||||
static constexpr Index packet_size = internal::packet_traits<RhsScalar>::size;
|
||||
|
||||
const Self* evaluator; // TensorContraction evaluator
|
||||
|
||||
// These fields required fromTENSOR_CONTRACTION_DISPATCH macro.
|
||||
// These fields cache values from the evaluator for use in processBlock dispatch.
|
||||
bool m_lhs_inner_dim_contiguous;
|
||||
bool m_rhs_inner_dim_contiguous;
|
||||
bool m_rhs_inner_dim_reordered;
|
||||
@@ -1131,7 +1127,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
//
|
||||
// TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
|
||||
// sense only if number of threads >= ~128?
|
||||
static const Index l0_size = 4;
|
||||
static constexpr Index l0_size = 4;
|
||||
Index l0_ranges;
|
||||
|
||||
// Keep count of pending gemm tasks for each l0 range.
|
||||
@@ -1144,9 +1140,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
void processBlock(Index block_idx, Index begin, Index end) {
|
||||
Scalar* buf = block_buffers[block_idx];
|
||||
|
||||
TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
|
||||
(buf, begin, end,
|
||||
/*num_threads=*/internal::convert_index<int>(num_blocks)));
|
||||
internal::tensor_contraction_dispatch(
|
||||
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
|
||||
evaluator->template evalGemmPartialWithoutOutputKernel<lhs_c(), rhs_c(), rhs_r(), Alignment>(
|
||||
buf, begin, end, /*num_threads=*/internal::convert_index<int>(num_blocks));
|
||||
},
|
||||
m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered);
|
||||
|
||||
// Check if it was the last task in l0 range.
|
||||
const Index l0_index = block_idx / l0_size;
|
||||
|
||||
Reference in New Issue
Block a user