mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Add support for custom packed Lhs/Rhs blocks in tensor contractions
This commit is contained in:
@@ -280,6 +280,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
|
||||
TensorContractionKernel;
|
||||
|
||||
typedef typename TensorContractionKernel::LhsBlock LhsBlock;
|
||||
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
|
||||
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
|
||||
|
||||
Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
|
||||
Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
|
||||
Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
|
||||
@@ -311,7 +315,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
gm_(gm),
|
||||
gn_(gn),
|
||||
nm0_(nm0),
|
||||
nn0_(nn0)
|
||||
nn0_(nn0),
|
||||
kernel_(m_, k_, n_, bm_, bk_, bn_)
|
||||
{
|
||||
// These two options are mutually exclusive.
|
||||
eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
|
||||
@@ -342,26 +347,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
|
||||
// Allocate memory for packed rhs/lhs matrices.
|
||||
size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
|
||||
size_t lhs_size =
|
||||
divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
|
||||
size_t rhs_size =
|
||||
divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
|
||||
packed_mem_ = static_cast<char*>(device_.allocate(
|
||||
(nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
|
||||
char* mem = static_cast<char*>(packed_mem_);
|
||||
for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
|
||||
packed_lhs_[x].resize(nm0_);
|
||||
for (Index m = 0; m < nm0_; m++) {
|
||||
packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
|
||||
mem += lhs_size;
|
||||
}
|
||||
packed_rhs_[x].resize(nn0_);
|
||||
for (Index n = 0; n < nn0_; n++) {
|
||||
packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
|
||||
mem += rhs_size;
|
||||
}
|
||||
}
|
||||
packed_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/nm0_, //
|
||||
/*num_rhs=*/nn0_, //
|
||||
/*num_slices=*/std::min<Index>(nk_, P - 1), //
|
||||
packed_lhs_, packed_rhs_);
|
||||
|
||||
if (parallelize_by_sharding_dim_only_) {
|
||||
const int num_worker_threads = device_.numThreadsInPool();
|
||||
@@ -373,14 +364,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
std::memory_order_relaxed);
|
||||
|
||||
Index num_blocks = num_worker_threads * gn_;
|
||||
thread_local_packed_mem_ = device_.allocate(num_blocks * rhs_size);
|
||||
mem = static_cast<char*>(thread_local_packed_mem_);
|
||||
thread_local_packed_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/0, //
|
||||
/*num_rhs=*/num_blocks, //
|
||||
/*num_slices=*/1, //
|
||||
/*lhs_blocks=*/nullptr, &thread_local_packed_rhs_);
|
||||
|
||||
thread_local_packed_rhs_.resize(num_blocks, nullptr);
|
||||
for (Index i = 0; i < num_blocks; ++i) {
|
||||
thread_local_packed_rhs_[i] = reinterpret_cast<RhsScalar*>(mem);
|
||||
mem += rhs_size;
|
||||
}
|
||||
} else {
|
||||
can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
|
||||
for (int i = 0; i < nm_; ++i)
|
||||
@@ -388,14 +378,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
std::memory_order_relaxed);
|
||||
|
||||
Index num_blocks = num_worker_threads * gm_;
|
||||
thread_local_packed_mem_ = device_.allocate(num_blocks * lhs_size);
|
||||
mem = static_cast<char*>(thread_local_packed_mem_);
|
||||
|
||||
thread_local_packed_lhs_.resize(num_blocks, nullptr);
|
||||
for (Index i = 0; i < num_blocks; ++i) {
|
||||
thread_local_packed_lhs_[i] = reinterpret_cast<LhsScalar*>(mem);
|
||||
mem += lhs_size;
|
||||
}
|
||||
thread_local_packed_mem_ = kernel_.allocateSlices( //
|
||||
device_, //
|
||||
/*num_lhs=*/num_blocks, //
|
||||
/*num_rhs=*/0, //
|
||||
/*num_slices=*/1, &thread_local_packed_lhs_, //
|
||||
/*rhs_blocks=*/nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -405,9 +393,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
|
||||
delete[] state_kernel_[x];
|
||||
}
|
||||
device_.deallocate(packed_mem_);
|
||||
kernel_.deallocate(device_, packed_mem_);
|
||||
if (parallelize_by_sharding_dim_only_) {
|
||||
device_.deallocate(thread_local_packed_mem_);
|
||||
kernel_.deallocate(device_, thread_local_packed_mem_);
|
||||
delete[] can_use_thread_local_packed_;
|
||||
}
|
||||
}
|
||||
@@ -455,6 +443,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// coarsening).
|
||||
const Index nm0_;
|
||||
const Index nn0_;
|
||||
// Tensor contraction kernel.
|
||||
TensorContractionKernel kernel_;
|
||||
|
||||
// Parallelization strategy.
|
||||
//
|
||||
@@ -491,9 +481,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// actively executing + one to track completion of kernels in the second
|
||||
// slice.
|
||||
static const Index P = 3;
|
||||
void* packed_mem_;
|
||||
std::vector<LhsScalar*> packed_lhs_[P - 1];
|
||||
std::vector<RhsScalar*> packed_rhs_[P - 1];
|
||||
|
||||
// Handle to the allocated temporary storage for Lhs/Rhs blocks.
|
||||
BlockMemHandle packed_mem_;
|
||||
std::vector<LhsBlock> packed_lhs_[P - 1];
|
||||
std::vector<RhsBlock> packed_rhs_[P - 1];
|
||||
|
||||
// If we choose to parallelize only by the sharding dimension, each thread
|
||||
// will have it's own "thead local" (not a c++ thread local storage) memory
|
||||
@@ -511,11 +503,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// completion of the K-1 kernel, so we have to allocate "global" packed_lhs_
|
||||
// and packed_rhs_ to allow kernels to be executed later on a thread
|
||||
// different from the thread that was used for packing.
|
||||
void* thread_local_packed_mem_;
|
||||
BlockMemHandle thread_local_packed_mem_;
|
||||
|
||||
// Only one of these will beinitialized depending on shard_by_col value.
|
||||
std::vector<LhsScalar*> thread_local_packed_lhs_;
|
||||
std::vector<RhsScalar*> thread_local_packed_rhs_;
|
||||
// Only one of these will be initialized depending on shard_by_col value.
|
||||
std::vector<LhsBlock> thread_local_packed_lhs_;
|
||||
std::vector<RhsBlock> thread_local_packed_rhs_;
|
||||
|
||||
// After a particular shard for Kth slice missed thread local execution
|
||||
// opportunity (K-1 slice didn't complete kernels execution), we can no
|
||||
@@ -532,7 +524,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
std::atomic<Index> state_packing_ready_[P];
|
||||
std::atomic<Index> state_switch_[P];
|
||||
|
||||
LhsScalar* packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
|
||||
LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
|
||||
if (use_thread_local) {
|
||||
eigen_assert(!shard_by_col_);
|
||||
|
||||
@@ -546,7 +538,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
}
|
||||
|
||||
RhsScalar* packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
|
||||
RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
|
||||
if (use_thread_local) {
|
||||
eigen_assert(shard_by_col_);
|
||||
|
||||
@@ -580,7 +572,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
} else {
|
||||
// If we can't guarantee that all kernels in `k` slice will be
|
||||
// executed sequentially in current thread, it's no longer safe to use
|
||||
// thread local memory in followig slices along the k dimensions.
|
||||
// thread local memory in following slices along the k dimensions.
|
||||
eigen_assert(k > 0);
|
||||
can_use_thread_local_packed_[m].store(false,
|
||||
std::memory_order_relaxed);
|
||||
@@ -589,9 +581,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
|
||||
const Index mend = m * gm_ + gm(m);
|
||||
for (Index m1 = m * gm_; m1 < mend; m1++)
|
||||
TensorContractionKernel::packLhs(packed_lhs(m, k, m1, use_thread_local),
|
||||
lhs_.getSubMapper(m1 * bm_, k * bk_),
|
||||
bk(k), bm(m1));
|
||||
kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
|
||||
lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
|
||||
|
||||
if (!parallel_pack_ && shard_by_col_) {
|
||||
assert(!use_thread_local);
|
||||
@@ -634,9 +625,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// deadlocks.
|
||||
memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
|
||||
}
|
||||
TensorContractionKernel::packRhs(packed_rhs(n, k, n1, use_thread_local),
|
||||
rhs_.getSubMapper(k * bk_, n1 * bn_),
|
||||
bk(k), bn(n1));
|
||||
kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
|
||||
rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
|
||||
}
|
||||
|
||||
if (parallel_pack_ || shard_by_col_) {
|
||||
@@ -661,7 +651,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
for (Index n1 = n * gn_; n1 < nend; n1++) {
|
||||
for (Index m1 = m * gm_; m1 < mend; m1++) {
|
||||
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
|
||||
TensorContractionKernel::invoke(
|
||||
kernel_.invoke(
|
||||
output_mapper,
|
||||
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
|
||||
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
|
||||
@@ -678,7 +668,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
for (Index m1 = m * gm_; m1 < mend; m1++)
|
||||
for (Index n1 = n * gn_; n1 < nend; n1++) {
|
||||
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
|
||||
TensorContractionKernel::invoke(
|
||||
kernel_.invoke(
|
||||
output_mapper,
|
||||
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
|
||||
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
|
||||
|
||||
Reference in New Issue
Block a user