Modernize tensor contraction code: bug fixes, dead code removal, and cleanup
libeigen/eigen!2248
Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
@@ -40,7 +40,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern
std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>;

enum { Flags = 0 };
static constexpr int Flags = 0;
};

template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
@@ -168,7 +168,7 @@ template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename S
struct TensorContractionKernel {
// True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
// (otherwise beta should be always equal to 1).
enum { HasBeta = false };
static constexpr bool HasBeta = false;

EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_,
StorageIndex bk_, StorageIndex bn_)
@@ -247,8 +247,59 @@ struct TensorContractionKernel {
const StorageIndex bn;
};

// Dispatches a contraction operation over all 8 combinations of the three
// runtime boolean parameters (lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
// rhs_inner_dim_reordered), passing them as compile-time bool_constant
// tags to the callable `fn`.
template <typename Func>
EIGEN_STRONG_INLINE void tensor_contraction_dispatch(Func&& fn, bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered) {
if (lhs_inner_dim_contiguous) {
if (rhs_inner_dim_contiguous) {
if (rhs_inner_dim_reordered)
fn(bool_constant<true>{}, bool_constant<true>{}, bool_constant<true>{});
else
fn(bool_constant<true>{}, bool_constant<true>{}, bool_constant<false>{});
} else {
if (rhs_inner_dim_reordered)
fn(bool_constant<true>{}, bool_constant<false>{}, bool_constant<true>{});
else
fn(bool_constant<true>{}, bool_constant<false>{}, bool_constant<false>{});
}
} else {
if (rhs_inner_dim_contiguous) {
if (rhs_inner_dim_reordered)
fn(bool_constant<false>{}, bool_constant<true>{}, bool_constant<true>{});
else
fn(bool_constant<false>{}, bool_constant<true>{}, bool_constant<false>{});
} else {
if (rhs_inner_dim_reordered)
fn(bool_constant<false>{}, bool_constant<false>{}, bool_constant<true>{});
else
fn(bool_constant<false>{}, bool_constant<false>{}, bool_constant<false>{});
}
}
}
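For illustration only (not part of this patch): the bool_constant tags can parameterize templates inside the generic lambda because calling a bool_constant is a constant expression. A minimal self-contained sketch of the same trick for a single flag, with hypothetical names `pick` and `dispatch_one`:

#include <type_traits>

template <bool Flag>
int pick() { return Flag ? 1 : 0; }

template <typename Func>
void dispatch_one(Func&& fn, bool flag) {
  // Turn one runtime bool into a compile-time tag, as tensor_contraction_dispatch does for three.
  if (flag)
    fn(std::bool_constant<true>{});
  else
    fn(std::bool_constant<false>{});
}

// Inside the generic lambda, tag() is a constant expression, so it selects the template:
//   dispatch_one([](auto tag) { int r = pick<tag()>(); (void)r; }, runtime_flag);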

} // end namespace internal

// Legacy macros kept for backward compatibility with code that overrides them
// (e.g. TensorFlow Lite restricts template instantiations for binary size).
// New Eigen code should use internal::tensor_contraction_dispatch() instead.
#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
::Eigen::internal::tensor_contraction_dispatch( \
[&](auto lhs_c, auto rhs_c, auto rhs_r) { METHOD<lhs_c(), rhs_c(), rhs_r(), ALIGNMENT> ARGS; }, \
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered)
#endif

#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
::Eigen::internal::tensor_contraction_dispatch( \
[&](auto lhs_c, auto rhs_c, auto rhs_r) { (new METHOD<DONE, lhs_c(), rhs_c(), rhs_r(), ALIGNMENT> ARGS)->FN; }, \
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered)
#endif
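As a sketch of how the compatibility macro is consumed (the method name `evalMyProduct` and the `buffer` argument are hypothetical stand-ins), a call site inside an evaluator that exposes the three `m_*_inner_dim_*` members could read:

// Expands to ::Eigen::internal::tensor_contraction_dispatch(...) and instantiates
// evalMyProduct<bool, bool, bool, Unaligned> for the combination selected at runtime.
TENSOR_CONTRACTION_DISPATCH(this->template evalMyProduct, Unaligned, (buffer));

Because of the #ifndef guards, a project may also predefine these macros before including the Tensor headers, which is how the instantiation-restricting overrides mentioned above hook in.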

// Tensor contraction params that make it possible to map from 2-dimensional
// output matrix coordinates to the output tensor dimensions.
struct TensorContractionParams {
@@ -281,16 +332,8 @@ struct NoOpOutputKernel {
* \param[in] num_cols Number of available columns
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
const TensorContractionParams& params, Index i, Index j, Index num_rows,
Index num_cols) const {
EIGEN_UNUSED_VARIABLE(output_mapper);
EIGEN_UNUSED_VARIABLE(params);
EIGEN_UNUSED_VARIABLE(i);
EIGEN_UNUSED_VARIABLE(j);
EIGEN_UNUSED_VARIABLE(num_rows);
EIGEN_UNUSED_VARIABLE(num_cols);
}
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>&,
const TensorContractionParams&, Index, Index, Index, Index) const {}
};
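To make the output-kernel interface above concrete, here is an illustrative custom kernel (a sketch, not part of this patch); it assumes the blas_data_mapper's (row, col) accessor returns a mutable reference and applies a ReLU to the finished output block:

struct ReluOutputKernel {
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
                                      const TensorContractionParams& /*params*/, Index /*i*/, Index /*j*/,
                                      Index num_rows, Index num_cols) const {
    for (Index col = 0; col < num_cols; ++col) {
      for (Index row = 0; row < num_rows; ++row) {
        Scalar& v = output_mapper(row, col);  // coefficient of the finished block
        v = numext::maxi(v, Scalar(0));
      }
    }
  }
};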

/** Tensor contraction class.
@@ -350,14 +393,12 @@ struct TensorContractionEvaluatorBase {
using EvaluatorPointerType = typename Storage::Type;

static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
enum {
IsAligned = true,
PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
BlockAccess = false,
PreferBlockAccess = false,
CoordAccess = false, // to be implemented
RawAccess = true
};
static constexpr bool IsAligned = true;
static constexpr bool PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1);
static constexpr bool BlockAccess = false;
static constexpr bool PreferBlockAccess = false;
static constexpr bool CoordAccess = false; // to be implemented
static constexpr bool RawAccess = true;

//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
using TensorBlock = internal::TensorBlockNotImplemented;
@@ -397,7 +438,7 @@ struct TensorContractionEvaluatorBase {
device),
m_device(device),
m_output_kernel(op.outputKernel()),
m_result(NULL) {
m_result(nullptr) {
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -581,8 +622,8 @@ struct TensorContractionEvaluatorBase {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
m_leftImpl.evalSubExprsIfNeeded(NULL);
m_rightImpl.evalSubExprsIfNeeded(NULL);
m_leftImpl.evalSubExprsIfNeeded(nullptr);
m_rightImpl.evalSubExprsIfNeeded(nullptr);
if (data) {
evalTo(data);
return false;
@@ -609,72 +650,6 @@ struct TensorContractionEvaluatorBase {
}
#endif // EIGEN_USE_THREADS

#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
if (this->m_lhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, true, true, ALIGNMENT> ARGS; \
} else { \
METHOD<true, true, false, ALIGNMENT> ARGS; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, false, true, ALIGNMENT> ARGS; \
} else { \
METHOD<true, false, false, ALIGNMENT> ARGS; \
} \
} \
} else { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, true, true, ALIGNMENT> ARGS; \
} else { \
METHOD<false, true, false, ALIGNMENT> ARGS; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, false, true, ALIGNMENT> ARGS; \
} else { \
METHOD<false, false, false, ALIGNMENT> ARGS; \
} \
} \
}
#endif

#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
if (this->m_lhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
} \
} \
} else { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
} \
} else { \
if (this->m_rhs_inner_dim_reordered) { \
(new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
} else { \
(new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
} \
} \
}
#endif

EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
}
@@ -867,9 +842,9 @@ struct TensorContractionEvaluatorBase {
m_leftImpl.cleanup();
m_rightImpl.cleanup();

if (m_result != NULL) {
if (m_result != nullptr) {
m_device.deallocate(m_result);
m_result = NULL;
m_result = nullptr;
}
}

@@ -957,7 +932,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT

template <int Alignment>
void evalProduct(Scalar* buffer) const {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Alignment>(buffer);
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
}
};
@@ -16,7 +16,8 @@
namespace Eigen {
namespace internal {

enum { ShardByRow = 0, ShardByCol = 1 };
constexpr int ShardByRow = 0;
constexpr int ShardByCol = 1;

// Default Blocking Strategy
template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex,

@@ -17,7 +17,8 @@ namespace Eigen {

namespace internal {

enum { Rhs = 0, Lhs = 1 };
constexpr int Rhs = 0;
constexpr int Lhs = 1;

/*
* Implementation of the Eigen blas_data_mapper class for tensors.
@@ -34,7 +35,7 @@ class BaseTensorContractionMapper;

template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader {
enum { DirectOffsets = false };
static constexpr bool DirectOffsets = false;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) {}

@@ -44,7 +45,7 @@ struct CoeffLoader {

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
eigen_assert(false && "unsupported");
return NULL;
return nullptr;
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
@@ -62,7 +63,7 @@ struct CoeffLoader {

template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
enum { DirectOffsets = true };
static constexpr bool DirectOffsets = true;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

@@ -100,7 +101,7 @@ class SimpleTensorContractionMapper {
m_contract_strides(contract_strides),
m_k_strides(k_strides) {}

enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };
static constexpr bool DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
m_tensor.offsetBuffer(offset);
@@ -123,7 +124,6 @@ class SimpleTensorContractionMapper {

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
const bool left = (side == Lhs);
EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val = left ? row : col;
Index linidx = 0;
EIGEN_UNROLL_LOOP
@@ -164,7 +164,6 @@ class SimpleTensorContractionMapper {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col,
const Index distance) const {
const bool left = (side == Lhs);
EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
Index linidx[2] = {0, 0};
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
@@ -363,12 +362,10 @@ class TensorContractionSubMapper {
using LinearMapper = Self;
using SubMapper = Self;

enum {
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
// TODO: we should also enable direct offsets for the Rhs case.
UseDirectOffsets =
ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
};
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
// TODO: we should also enable direct offsets for the Rhs case.
static constexpr bool UseDirectOffsets =
ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0);

EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
@@ -429,8 +426,9 @@ class TensorContractionSubMapper {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
if (UseDirectOffsets) {
m_base_mapper.storePacket(i, 0, p);
} else {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@@ -451,7 +449,7 @@ class TensorContractionSubMapper {

template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
static_assert(std::is_same<PacketT, PacketT>::value, "YOU_MADE_A_PROGRAMMING_MISTAKE");
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);
@@ -99,7 +99,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
// and temporary buffers, required for executing the tensor contraction.
// They are responsible for cleaning it up after contraction is done.
static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
static constexpr bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;

const Index m = this->m_i_size;
const Index n = this->m_j_size;
@@ -176,7 +176,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (n == 1) num_threads = 1;

if (num_threads == 1) {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
this->template evalProductSequential<lhs_c(), rhs_c(), rhs_r(), Unaligned>(buffer);
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
if (!IsEvalInSyncMode) done();
return;
}
@@ -258,22 +262,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// optimization.
if (parallelize_by_sharding_dim_only) parallel_pack = false;

// TODO(ezhulnev): With if constexpr we don't need SyncEvalParallelContext.
if (IsEvalInSyncMode) {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, NoCallback()) \
.run()
TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS

} else {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
parallelize_by_sharding_dim_only, std::move(done))
TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
}
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
EIGEN_IF_CONSTEXPR(IsEvalInSyncMode) {
EvalParallelContext<NoCallback, lhs_c(), rhs_c(), rhs_r(), Alignment> ctx(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, NoCallback());
ctx.run();
}
else {
auto* ctx = new EvalParallelContext<DoneCallback, lhs_c(), rhs_c(), rhs_r(), Alignment>(
this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col,
parallel_pack, parallelize_by_sharding_dim_only, std::move(done));
ctx->run();
}
},
this->m_lhs_inner_dim_contiguous, this->m_rhs_inner_dim_contiguous, this->m_rhs_inner_dim_reordered);
}

// ------------------------------------------------------------------------ //
@@ -432,11 +436,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);

Index num_blocks = num_worker_threads * gn_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/0, //
/*num_rhs=*/num_blocks, //
/*num_slices=*/1, //
/*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);

} else {
@@ -444,11 +448,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);

Index num_blocks = num_worker_threads * gm_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
thread_local_pre_allocated_mem_ = kernel_.allocateSlices( //
device_, //
/*num_lhs=*/num_blocks, //
/*num_rhs=*/0, //
/*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
/*rhs_blocks=*/nullptr);
}
}
@@ -461,7 +465,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
kernel_.deallocate(device_, packed_mem_);
if (parallelize_by_sharding_dim_only_) {
kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
kernel_.deallocate(device_, thread_local_pre_allocated_mem_);
delete[] can_use_thread_local_packed_;
}
}
@@ -585,7 +589,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// different from the thread that was used for packing.

// Handle for pre-allocated thread local memory buffers.
BlockMemHandle thread_local_pre_alocated_mem_;
BlockMemHandle thread_local_pre_allocated_mem_;

// Only one of these will be initialized depending on shard_by_col value
// (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
@@ -648,7 +652,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <typename BlockType, bool is_rhs>
class ThreadLocalBlocksInitialize {
static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
static const bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static_assert(kIsLhs || kIsRhs, "Unknown block type");

using Blocks = ThreadLocalBlocks<BlockType>;
@@ -668,10 +672,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}

private:
// NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
// TensorContractionKernel::allocateSlices into template specializations.
// Also explicit specializations are not allowed at class scope in C++03,
// EvalCtx type parameter is just a workaround for that limitation.
// Explicit specializations are not allowed at class scope, so EvalCtx is
// a dummy template parameter to make these partial specializations.
template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
struct ThreadLocalBlocksAllocator;

@@ -684,7 +686,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/ctx.gn_,
/*num_slices=*/1,
/*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);

blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
}

@@ -703,7 +704,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
/*num_rhs=*/0,
/*num_slices=*/1,
/*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);

blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
}

@@ -1015,10 +1015,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void operator=(const EvalParallelContext&) = delete;
};

template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>;

// ------------------------------------------------------------------------ //

// EvalShardedByInnerDimContext orchestrates sync/async contraction
@@ -1086,11 +1082,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
private:
// The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
static const Index packet_size = internal::packet_traits<RhsScalar>::size;
static constexpr Index packet_size = internal::packet_traits<RhsScalar>::size;
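A brief illustration of the constraint above (a sketch, not part of the class): when the contraction dimension is split into per-thread ranges, each interior split point has to remain a multiple of packet_size so that only the final range carries a non-multiple tail, e.g.

// Hypothetical helper: round a proposed split point down to a multiple of packet_size.
static Index RoundDownToPacketMultiple(Index split) { return (split / packet_size) * packet_size; }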

const Self* evaluator; // TensorContraction evaluator
// These fields required from TENSOR_CONTRACTION_DISPATCH macro.
// These fields cache values from the evaluator for use in processBlock dispatch.
bool m_lhs_inner_dim_contiguous;
bool m_rhs_inner_dim_contiguous;
bool m_rhs_inner_dim_reordered;
@@ -1131,7 +1127,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
//
// TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
// sense only if number of threads >= ~128?
static const Index l0_size = 4;
static constexpr Index l0_size = 4;
Index l0_ranges;

// Keep count of pending gemm tasks for each l0 range.
@@ -1144,9 +1140,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];

TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, begin, end,
/*num_threads=*/internal::convert_index<int>(num_blocks)));
internal::tensor_contraction_dispatch(
[&](auto lhs_c, auto rhs_c, auto rhs_r) {
evaluator->template evalGemmPartialWithoutOutputKernel<lhs_c(), rhs_c(), rhs_r(), Alignment>(
buf, begin, end, /*num_threads=*/internal::convert_index<int>(num_blocks));
},
m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered);

// Check if it was the last task in l0 range.
const Index l0_index = block_idx / l0_size;

@@ -131,7 +131,7 @@ static void ContractionSizes(::benchmark::Benchmark* b) {

static void ThreadPoolSizes(::benchmark::Benchmark* b) {
for (int size : {64, 256, 512, 1024}) {
for (int threads : {2, 4, 8}) {
for (int threads : {1, 2, 4, 8, 16}) {
b->Args({size, size, size, threads});
}
}
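For context, argument generators like ThreadPoolSizes are typically attached with a Google-Benchmark-style Apply(); the benchmark name below is a hypothetical stand-in, and the exact registration macros in this benchmark harness may differ:

// Sweeps (size, size, size, threads) over the tuples produced by ThreadPoolSizes above.
BENCHMARK(BM_Contraction)->Apply(ThreadPoolSizes);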