Fix CUDA+Clang build warnings.

libeigen/eigen!2241
Antonio Sánchez
2026-03-04 09:41:01 +00:00
committed by Rasmus Munk Larsen
parent 0269c017aa
commit abc3d6014d
11 changed files with 356 additions and 334 deletions
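Most of the changes below follow one pattern: Clang (here acting as the CUDA compiler) flags implicit conversions when unsigned quantities — sizeof products, size_t cache sizes, CUDA's threadIdx fields — flow into Eigen's signed Index. A minimal stand-alone sketch of that warning class and the static_cast fix applied throughout (assuming -Wconversion/-Wsign-conversion-style diagnostics; the function and names are illustrative, not from the commit):

#include <cstddef>
using Index = std::ptrdiff_t;  // Eigen::Index defaults to std::ptrdiff_t

Index cache_block(std::size_t l1, Index k) {
  // Index bad = l1 / (sizeof(float) * k);              // unsigned -> signed, clang warns
  return static_cast<Index>(l1 / (sizeof(float) * k));  // conversion made explicit
}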

View File

@@ -414,6 +414,7 @@ if (EIGEN_BUILD_TESTING)
ei_add_cxx_compiler_flag("-Wno-psabi")
ei_add_cxx_compiler_flag("-Wno-variadic-macros")
ei_add_cxx_compiler_flag("-Wno-long-long")
ei_add_cxx_compiler_flag("-Wno-pass-failed") # disable clang's warning for unrolling when the loop count is dynamic.
ei_add_cxx_compiler_flag("-fno-common")
ei_add_cxx_compiler_flag("-fstrict-aliasing")
ei_add_cxx_compiler_flag("-wd981") # disable ICC's "operands are evaluated in unspecified order" remark

View File

@@ -151,13 +151,13 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
// increasing the value of k, so we'll cap it at 320 (value determined
// experimentally).
// To avoid that k vanishes, we make k_cache at least as big as kr
const Index k_cache = numext::maxi<Index>(kr, (numext::mini<Index>)((l1 - ksub) / kdiv, 320));
const Index k_cache = numext::maxi<Index>(kr, (numext::mini<Index>)(static_cast<Index>((l1 - ksub) / kdiv), 320));
if (k_cache < k) {
k = k_cache - (k_cache % kr);
eigen_internal_assert(k > 0);
}
const Index n_cache = (l2 - l1) / (nr * sizeof(RhsScalar) * k);
const Index n_cache = static_cast<Index>((l2 - l1) / (nr * sizeof(RhsScalar) * k));
const Index n_per_thread = numext::div_ceil(n, num_threads);
if (n_cache <= n_per_thread) {
// Don't exceed the capacity of the l2 cache.
@@ -170,7 +170,7 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
if (l3 > l2) {
// l3 is shared between all cores, so we'll give each thread its own chunk of l3.
const Index m_cache = (l3 - l2) / (sizeof(LhsScalar) * k * num_threads);
const Index m_cache = static_cast<Index>((l3 - l2) / (sizeof(LhsScalar) * k * num_threads));
const Index m_per_thread = numext::div_ceil(m, num_threads);
if (m_cache < m_per_thread && m_cache >= static_cast<Index>(mr)) {
m = m_cache - (m_cache % mr);
@@ -208,7 +208,7 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
// We also include a register-level block of the result (mx x nr).
// (In an ideal world only the lhs panel would stay in L1)
// Moreover, kc has to be a multiple of 8 to be compatible with loop peeling, leading to a maximum blocking size of:
const Index max_kc = numext::maxi<Index>(((l1 - k_sub) / k_div) & (~(k_peeling - 1)), 1);
const Index max_kc = numext::maxi<Index>(static_cast<Index>(((l1 - k_sub) / k_div) & (~(k_peeling - 1))), 1);
const Index old_k = k;
if (k > max_kc) {
// We are really blocking on the third dimension:
@@ -242,9 +242,9 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
// that spills to L3 but remains accessible with low latency. This matches
// the empirically-tuned constant (1.5MB) previously used when L2 was 1MB.
#ifdef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS
const Index actual_l2 = l3;
const Index actual_l2 = static_cast<Index>(l3);
#else
const Index actual_l2 = l2 * 3 / 2;
const Index actual_l2 = static_cast<Index>(l2 * 3 / 2);
#endif
// Here, nc is chosen such that a block of kc x nc of the rhs fits within half of L2.
@@ -255,7 +255,7 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
// and it becomes fruitful to keep the packed rhs blocks in L1 if there is enough remaining space.
Index max_nc;
const Index lhs_bytes = m * k * sizeof(LhsScalar);
const Index remaining_l1 = l1 - k_sub - lhs_bytes;
const Index remaining_l1 = static_cast<Index>(l1 - k_sub - lhs_bytes);
if (remaining_l1 >= Index(Traits::nr * sizeof(RhsScalar)) * k) {
// L1 blocking
max_nc = remaining_l1 / (k * sizeof(RhsScalar));
@@ -282,11 +282,11 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
if (problem_size <= 1024) {
// problem is small enough to keep in L1
// Let's choose m such that lhs's block fits in 1/3 of L1
actual_lm = l1;
actual_lm = static_cast<Index>(l1);
} else if (l3 != 0 && problem_size <= 32768) {
// we have both L2 and L3, and problem is small enough to be kept in L2
// Let's choose m such that lhs's block fits in 1/3 of L2
actual_lm = l2;
actual_lm = static_cast<Index>(l2);
max_mc = (numext::mini<Index>)(576, max_mc);
}
Index mc = (numext::mini<Index>)(actual_lm / (3 * k * sizeof(LhsScalar)), max_mc);

View File

@@ -68,7 +68,7 @@ struct MultiplyKernel {
};
template <typename T1, typename T2, typename T3>
void test_multiply(const T1& type1, const T2& type2, const T3& type3) {
void test_multiply(const T1& type1, const T2& type2, const T3& /*type3*/) {
const T1 A = T1::Random(type1.rows(), type1.cols());
const T2 B = T2::Random(type2.rows(), type2.cols());
T3 C;
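The signature change above keeps the third parameter for template-argument deduction but comments out its name, the usual way to silence Clang's -Wunused-parameter. A hypothetical minimal case:

template <typename T>
void use_type_only(const T& /*unused*/) {}  // T is still deduced; no -Wunused-parameter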

View File

@@ -30,6 +30,9 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
const Index base_m = 64 * m_block_idx;
const Index base_n = 64 * n_block_idx;
const Index thread_x = threadIdx.x;
const Index thread_y = threadIdx.y;
const Index thread_z = threadIdx.z;
// declare and initialize 64 registers for output 8x8 block
@@ -66,8 +69,8 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
// conflicts on writes and also none on reads.
// storage indices
const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
const Index lhs_store_idx_base = thread_y * 72 + thread_x * 9 + thread_z;
const Index rhs_store_idx_base = thread_y * 72 + thread_z * 8 + thread_x;
const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
@@ -88,151 +91,151 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
// in the loading code, the following variables are important:
// threadIdx.x: the vertical position in an 8x8 block
// threadIdx.y: the vertical index of the 8x8 block in the grid
// threadIdx.z: the horizontal position in an 8x8 block
// thread_x: the vertical position in an 8x8 block
// thread_y: the vertical index of the 8x8 block in the grid
// thread_z: the horizontal position in an 8x8 block
// k: the horizontal index of the 8x8 block in the grid
//
// The k parameter is implicit (it was the loop counter for a loop that went
// from 0 to <8, but that loop is now unrolled in the code below).
const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
const Index load_idx_vert = thread_x + 8 * thread_y;
const Index lhs_vert = base_m + load_idx_vert;
#define prefetchIntoRegisters(base_k) \
{ \
lhs_pf0 = conv(0); \
lhs_pf1 = conv(0); \
lhs_pf2 = conv(0); \
lhs_pf3 = conv(0); \
lhs_pf4 = conv(0); \
lhs_pf5 = conv(0); \
lhs_pf6 = conv(0); \
lhs_pf7 = conv(0); \
\
rhs_pf0 = conv(0); \
rhs_pf1 = conv(0); \
rhs_pf2 = conv(0); \
rhs_pf3 = conv(0); \
rhs_pf4 = conv(0); \
rhs_pf5 = conv(0); \
rhs_pf6 = conv(0); \
rhs_pf7 = conv(0); \
\
if (!needs_edge_check || lhs_vert < m_size) { \
const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
\
if (!needs_edge_check || lhs_horiz_7 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
} else if (lhs_horiz_6 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
} else if (lhs_horiz_5 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
} else if (lhs_horiz_4 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
} else if (lhs_horiz_3 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
} else if (lhs_horiz_2 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
} else if (lhs_horiz_1 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
} else if (lhs_horiz_0 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
} \
} \
\
const Index rhs_vert = base_k + load_idx_vert; \
if (!needs_edge_check || rhs_vert < k_size) { \
const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
\
if (rhs_horiz_7 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
} else if (rhs_horiz_6 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
} else if (rhs_horiz_5 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
} else if (rhs_horiz_4 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
} else if (rhs_horiz_3 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
} else if (rhs_horiz_2 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
} else if (rhs_horiz_1 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
} else if (rhs_horiz_0 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
} \
} \
#define prefetchIntoRegisters(base_k) \
{ \
lhs_pf0 = conv(0); \
lhs_pf1 = conv(0); \
lhs_pf2 = conv(0); \
lhs_pf3 = conv(0); \
lhs_pf4 = conv(0); \
lhs_pf5 = conv(0); \
lhs_pf6 = conv(0); \
lhs_pf7 = conv(0); \
\
rhs_pf0 = conv(0); \
rhs_pf1 = conv(0); \
rhs_pf2 = conv(0); \
rhs_pf3 = conv(0); \
rhs_pf4 = conv(0); \
rhs_pf5 = conv(0); \
rhs_pf6 = conv(0); \
rhs_pf7 = conv(0); \
\
if (!needs_edge_check || lhs_vert < m_size) { \
const Index lhs_horiz_0 = base_k + thread_z + 0 * 8; \
const Index lhs_horiz_1 = base_k + thread_z + 1 * 8; \
const Index lhs_horiz_2 = base_k + thread_z + 2 * 8; \
const Index lhs_horiz_3 = base_k + thread_z + 3 * 8; \
const Index lhs_horiz_4 = base_k + thread_z + 4 * 8; \
const Index lhs_horiz_5 = base_k + thread_z + 5 * 8; \
const Index lhs_horiz_6 = base_k + thread_z + 6 * 8; \
const Index lhs_horiz_7 = base_k + thread_z + 7 * 8; \
\
if (!needs_edge_check || lhs_horiz_7 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
} else if (lhs_horiz_6 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
} else if (lhs_horiz_5 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
} else if (lhs_horiz_4 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
} else if (lhs_horiz_3 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
} else if (lhs_horiz_2 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
} else if (lhs_horiz_1 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
} else if (lhs_horiz_0 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
} \
} \
\
const Index rhs_vert = base_k + load_idx_vert; \
if (!needs_edge_check || rhs_vert < k_size) { \
const Index rhs_horiz_0 = base_n + thread_z + 0 * 8; \
const Index rhs_horiz_1 = base_n + thread_z + 1 * 8; \
const Index rhs_horiz_2 = base_n + thread_z + 2 * 8; \
const Index rhs_horiz_3 = base_n + thread_z + 3 * 8; \
const Index rhs_horiz_4 = base_n + thread_z + 4 * 8; \
const Index rhs_horiz_5 = base_n + thread_z + 5 * 8; \
const Index rhs_horiz_6 = base_n + thread_z + 6 * 8; \
const Index rhs_horiz_7 = base_n + thread_z + 7 * 8; \
\
if (rhs_horiz_7 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
} else if (rhs_horiz_6 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
} else if (rhs_horiz_5 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
} else if (rhs_horiz_4 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
} else if (rhs_horiz_3 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
} else if (rhs_horiz_2 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
} else if (rhs_horiz_1 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
} else if (rhs_horiz_0 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
} \
} \
}
#define writeRegToShmem() \
@@ -321,8 +324,8 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
Scalar rrow(7);
// Now x corresponds to k, y to m, and z to n
const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
const Scalar* lhs_block = &lhs_shmem[thread_x + 9 * thread_y];
const Scalar* rhs_block = &rhs_shmem[thread_x + 8 * thread_z];
#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
@@ -441,7 +444,7 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
// wait for shared mem to be out of use
__syncthreads();
#define writeResultShmem(i, j) lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);
#define writeResultShmem(i, j) lhs_shmem[i + 8 * thread_y + 64 * thread_z + 512 * j] = res(i, j);
#define writeRow(i) \
writeResultShmem(i, 0); \
@@ -453,7 +456,7 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
writeResultShmem(i, 6); \
writeResultShmem(i, 7);
if (threadIdx.x == 0) {
if (thread_x == 0) {
writeRow(0);
writeRow(1);
writeRow(2);
@@ -466,34 +469,34 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
#undef writeResultShmem
#undef writeRow
const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
const int max_i_write = numext::mini((int)((m_size - base_m - thread_y + 7) / 8), 8);
const int max_j_write = numext::mini((int)((n_size - base_n - thread_z + 7) / 8), 8);
if (threadIdx.x < max_i_write) {
if (thread_x < max_i_write) {
if (max_j_write == 8) {
// TODO: Can we trade bank conflicts for coalesced writes?
Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
Scalar val0 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 0];
Scalar val1 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 1];
Scalar val2 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 2];
Scalar val3 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 3];
Scalar val4 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 4];
Scalar val5 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 5];
Scalar val6 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 6];
Scalar val7 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 7];
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 0) = val0;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 1) = val1;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 2) = val2;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 3) = val3;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 4) = val4;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 5) = val5;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 6) = val6;
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 7) = val7;
} else {
#pragma unroll 7
for (int j = 0; j < max_j_write; j++) {
Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
Scalar val = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * j];
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * j) = val;
}
}
}
@@ -539,6 +542,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
float4 lhs_pf0, rhs_pf0;
float4 results[4];
const Index thread_x = threadIdx.x;
const Index thread_y = threadIdx.y;
for (int i = 0; i < 4; i++) {
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
@@ -565,17 +570,17 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
} \
}
Index lhs_vert = base_m + threadIdx.x * 4;
Index lhs_vert = base_m + thread_x * 4;
for (Index k = 0; k < k_size; k += 16) {
lhs_pf0 = internal::pset1<float4>(0);
rhs_pf0 = internal::pset1<float4>(0);
Index lhs_horiz = threadIdx.y + k;
Index lhs_horiz = thread_y + k;
prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
Index rhs_vert = k + (threadIdx.x % 4) * 4;
Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
Index rhs_vert = k + (thread_x % 4) * 4;
Index rhs_horiz0 = (thread_x >> 2) + thread_y * 4 + base_n;
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
@@ -610,7 +615,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
}
float x1, x2;
// TODO: The following can be a bitwise operation.
if ((threadIdx.x % 8) < 4) {
if ((thread_x % 8) < 4) {
x1 = rhs_pf0.y;
x2 = rhs_pf0.w;
} else {
@@ -624,7 +629,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
if ((threadIdx.x % 8) < 4) {
if ((thread_x % 8) < 4) {
rhs_pf0.y = x1;
rhs_pf0.w = x2;
} else {
@@ -639,8 +644,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
// Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
// Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
// ...
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
rhs_shmem2[(thread_x >> 3) + thread_y * 2][thread_x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
rhs_shmem2[(thread_x >> 3) + thread_y * 2 + 32][thread_x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
// Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
// Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
@@ -649,8 +654,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
// Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
// ...
lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
lhs_shmem2[thread_y][thread_x] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[thread_y + 16][thread_x] = make_float2(lhs_pf0.z, lhs_pf0.w);
#define add_vals(fl1, fl2, fr1, fr2) \
results[0].x += fl1.x * fr1.x; \
@@ -679,10 +684,10 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
#pragma unroll
for (int koff = 0; koff < 16; koff++) {
// 32 x threads.
float2 fl1 = lhs_shmem2[koff][threadIdx.x];
float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
float2 fl1 = lhs_shmem2[koff][thread_x];
float2 fl2 = lhs_shmem2[koff + 16][thread_x];
int start_feature = threadIdx.y * 4;
int start_feature = thread_y * 4;
float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
@@ -694,7 +699,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
#undef prefetch_lhs
#undef add_vals
Index horiz_base = threadIdx.y * 4 + base_n;
Index horiz_base = thread_y * 4 + base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
@@ -770,11 +775,15 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
float4 rhs_pf0, rhs_pf1;
float4 results[8];
const Index thread_x = threadIdx.x;
const Index thread_y = threadIdx.y;
for (int i = 0; i < 8; i++) {
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
Index lhs_vert = base_m + thread_x * 4 + (thread_y % 4) * 32;
for (Index k = 0; k < k_size; k += 32) {
lhs_pf0 = internal::pset1<float4>(0);
lhs_pf1 = internal::pset1<float4>(0);
@@ -785,123 +794,123 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
rhs_pf1 = internal::pset1<float4>(0);
if (!CHECK_LHS_BOUNDARY) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
if ((thread_y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 24));
} else if ((thread_y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
} else if ((thread_y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
} else if ((thread_y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
}
} else {
// just CHECK_LHS_BOUNDARY
if (lhs_vert + 3 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
if ((thread_y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 24));
} else if ((thread_y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
} else if ((thread_y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
} else if ((thread_y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
}
} else if (lhs_vert + 2 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
if ((thread_y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 24));
lhs_pf3.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 24));
} else if ((thread_y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 16));
} else if ((thread_y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
} else if ((thread_y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
}
} else if (lhs_vert + 1 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
if ((thread_y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 24));
} else if ((thread_y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
} else if ((thread_y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
} else if ((thread_y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
}
} else if (lhs_vert < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
if ((thread_y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
} else if ((thread_y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
} else if ((thread_y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
} else if ((thread_y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
}
}
}
__syncthreads();
Index rhs_vert = k + threadIdx.x * 4;
Index rhs_horiz0 = threadIdx.y * 2 + base_n;
Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
Index rhs_vert = k + thread_x * 4;
Index rhs_horiz0 = thread_y * 2 + base_n;
Index rhs_horiz1 = thread_y * 2 + 1 + base_n;
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
@@ -938,12 +947,12 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
} else if (k + threadIdx.x * 4 + 1 < k_size) {
} else if (k + thread_x * 4 + 1 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
} else if (k + threadIdx.x * 4 < k_size) {
} else if (k + thread_x * 4 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
}
@@ -970,17 +979,17 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
// Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
// ..
// Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
rhs_shmem2[thread_y][thread_x] = make_float2(rhs_pf0.x, rhs_pf1.x);
// Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
// Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
// ..
rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
rhs_shmem2[thread_y + 32][thread_x] = make_float2(rhs_pf0.y, rhs_pf1.y);
// Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
// Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
rhs_shmem2[thread_y + 64][thread_x] = make_float2(rhs_pf0.z, rhs_pf1.z);
// Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
// Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
rhs_shmem2[thread_y + 96][thread_x] = make_float2(rhs_pf0.w, rhs_pf1.w);
// LHS.
// Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
@@ -1026,26 +1035,26 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
results[6].w += a_feat2.y * f4.x; \
results[7].w += a_feat2.y * f4.y;
lhs_shmem2[threadIdx.y / 4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[threadIdx.y / 4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
lhs_shmem2[threadIdx.y / 4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
lhs_shmem2[threadIdx.y / 4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
lhs_shmem2[thread_y / 4][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[thread_y / 4 + 8][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
lhs_shmem2[thread_y / 4 + 16][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
lhs_shmem2[thread_y / 4 + 24][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
lhs_shmem2[threadIdx.y / 4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
lhs_shmem2[threadIdx.y / 4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
lhs_shmem2[threadIdx.y / 4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
lhs_shmem2[threadIdx.y / 4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
lhs_shmem2[thread_y / 4 + 32][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
lhs_shmem2[thread_y / 4 + 40][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
lhs_shmem2[thread_y / 4 + 48][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
lhs_shmem2[thread_y / 4 + 56][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
__syncthreads();
// Do the multiplies.
#pragma unroll
for (int koff = 0; koff < 32; koff++) {
float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
float2 a3 = lhs_shmem2[koff][thread_x + (thread_y % 4) * 8];
float2 a4 = lhs_shmem2[koff + 32][thread_x + (thread_y % 4) * 8];
// first feature is at (threadIdx.y/4) * 8; the last is at start + 8.
int start_feature = (threadIdx.y / 4) * 8;
// first feature is at (thread_y/4) * 8; the last is at start + 8.
int start_feature = (thread_y / 4) * 8;
float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
@@ -1060,7 +1069,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
#undef add_vals
__syncthreads();
Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
Index horiz_base = (thread_y / 4) * 8 + base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
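The bulk of this file's change is mechanical: each unsigned threadIdx.{x,y,z} field is read once into a signed const Index local, so all later index arithmetic and bound checks stay in signed arithmetic and the unsigned-to-signed conversion happens in exactly one place. A hypothetical distillation (Index standing in for Eigen::Index):

__device__ void kernel_sketch(float* out, Index m_size) {
  const Index thread_x = threadIdx.x;  // single unsigned -> signed conversion
  if (thread_x < m_size) {             // signed comparison, no -Wsign-compare
    out[thread_x] = 0.0f;
  }
}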

View File

@@ -666,8 +666,8 @@ template <typename InputEvaluator, typename Index, typename InputDims>
__global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void EigenConvolutionKernel3D(
InputEvaluator eval, const internal::IndexMapper<Index, InputDims, 3, InputEvaluator::Layout> indexMapper,
const float* __restrict kernel, const size_t numPlanes, const size_t numX, const size_t maxX, const size_t numY,
const size_t maxY, const size_t numZ, const size_t maxZ, const size_t kernelSizeX, const size_t kernelSizeY,
const size_t kernelSizeZ, float* buffer) {
const size_t maxY, const size_t numZ, const size_t maxZ, const int kernelSizeX, const int kernelSizeY,
const int kernelSizeZ, float* buffer) {
#if defined(EIGEN_HIPCC)
HIP_DYNAMIC_SHARED(float, s)
#else
@@ -675,19 +675,25 @@ __global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void EigenConvolutionKernel3D(
#endif
// Load inputs to shared memory
const int first_x = blockIdx.x * maxX;
const int last_x = (first_x + maxX < numX ? first_x + maxX : numX) - 1;
const int first_x = blockIdx.x * static_cast<int>(maxX);
const int last_x = (first_x + static_cast<int>(maxX) < static_cast<int>(numX) ? first_x + static_cast<int>(maxX)
: static_cast<int>(numX)) -
1;
const int num_x_input = last_x - first_x + kernelSizeX;
const int first_y = blockIdx.y * maxY;
const int last_y = (first_y + maxY < numY ? first_y + maxY : numY) - 1;
const int first_y = blockIdx.y * static_cast<int>(maxY);
const int last_y = (first_y + static_cast<int>(maxY) < static_cast<int>(numY) ? first_y + static_cast<int>(maxY)
: static_cast<int>(numY)) -
1;
const int num_y_input = last_y - first_y + kernelSizeY;
const int first_z = blockIdx.z * maxZ;
const int last_z = (first_z + maxZ < numZ ? first_z + maxZ : numZ) - 1;
const int first_z = blockIdx.z * static_cast<int>(maxZ);
const int last_z = (first_z + static_cast<int>(maxZ) < static_cast<int>(numZ) ? first_z + static_cast<int>(maxZ)
: static_cast<int>(numZ)) -
1;
const int num_z_input = last_z - first_z + kernelSizeZ;
for (int p = 0; p < numPlanes; ++p) {
for (int p = 0; p < static_cast<int>(numPlanes); ++p) {
const int plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(p);
const int plane_kernel_offset = 0;
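Here the kernel-size parameters move from size_t to int, and the remaining size_t extents are cast at their use sites, because expressions such as first_x + maxX < numX compared int against size_t and drew Clang's -Wsign-compare. A hypothetical minimal case:

__global__ void conv_sketch(const float* in, float* out, std::size_t numX) {
  const int first_x = static_cast<int>(blockIdx.x) * 32;
  // if (first_x < numX) { ... }            // warns: comparison of int and size_t
  if (first_x < static_cast<int>(numX)) {   // cast once, compare within one signedness
    out[first_x] = in[first_x];
  }
}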

View File

@@ -91,7 +91,7 @@ class TensorExecutor {
TensorEvaluator<Expression, Device> evaluator(expr, device);
const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL);
if (needs_assign) {
const StorageIndex size = array_prod(evaluator.dimensions());
const StorageIndex size = static_cast<StorageIndex>(array_prod(evaluator.dimensions()));
for (StorageIndex i = 0; i < size; ++i) {
evaluator.evalScalar(i);
}
@@ -120,7 +120,7 @@ class TensorExecutor<Expression, DefaultDevice, /*Vectorizable=*/true,
TensorEvaluator<Expression, DefaultDevice> evaluator(expr, device);
const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL);
if (needs_assign) {
const StorageIndex size = array_prod(evaluator.dimensions());
const StorageIndex size = static_cast<StorageIndex>(array_prod(evaluator.dimensions()));
const int PacketSize =
unpacket_traits<typename TensorEvaluator<Expression, DefaultDevice>::PacketReturnType>::size;
@@ -302,7 +302,7 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, Tiling> {
Evaluator evaluator(expr, device);
const bool needs_assign = evaluator.evalSubExprsIfNeeded(nullptr);
if (needs_assign) {
const StorageIndex size = array_prod(evaluator.dimensions());
const StorageIndex size = static_cast<StorageIndex>(array_prod(evaluator.dimensions()));
device.parallelFor(
size, evaluator.costPerCoeff(Vectorizable), EvalRange::alignBlockSize,
[&evaluator](StorageIndex firstIdx, StorageIndex lastIdx) { EvalRange::run(&evaluator, firstIdx, lastIdx); });
@@ -375,7 +375,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, Vectorizab
}
typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange;
const StorageIndex size = array_prod(ctx->evaluator.dimensions());
const StorageIndex size = static_cast<StorageIndex>(array_prod(ctx->evaluator.dimensions()));
device.parallelForAsync(
size, ctx->evaluator.costPerCoeff(Vectorizable), EvalRange::alignBlockSize,
[ctx](StorageIndex firstIdx, StorageIndex lastIdx) { EvalRange::run(&ctx->evaluator, firstIdx, lastIdx); },
@@ -583,7 +583,7 @@ EIGEN_STRONG_INLINE void TensorExecutor<Expression, GpuDevice, Vectorizable, Til
numext::mini<int64_t>(device.getNumGpuMultiProcessors() * device.maxGpuThreadsPerMultiProcessor(),
NumTraits<StorageIndex>::highest()) /
block_size);
const StorageIndex size = array_prod(evaluator.dimensions());
const StorageIndex size = static_cast<StorageIndex>(array_prod(evaluator.dimensions()));
// Create at least one block to ensure we don't crash with tensors of size 0.
const int num_blocks = numext::maxi<int>(
numext::mini<int>(max_blocks, static_cast<int>(numext::div_ceil<StorageIndex>(size, block_size))), 1);
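array_prod presumably yields Eigen's ptrdiff_t-based Index, while an evaluator's StorageIndex can be narrower (int, for instance), so the assignments above truncated implicitly and Clang flagged the narrowing; the casts make it deliberate. A hypothetical minimal case:

#include <cstddef>
using StorageIndex = int;  // assumption: a 32-bit storage index
StorageIndex total_size(std::ptrdiff_t n_elems) {
  return static_cast<StorageIndex>(n_elems);  // explicit 64 -> 32 narrowing
}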

View File

@@ -313,7 +313,8 @@ namespace internal {
// FIXME: Figure out the exact threshold.
template <typename Index, typename Device, bool BlockAccess>
struct MemcpyTriggerForSlicing {
EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) {}
EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device)
: threshold_(static_cast<Index>(2 * device.numThreads())) {}
EIGEN_DEVICE_FUNC bool operator()(Index total, Index contiguous) const {
const bool prefer_block_evaluation = BlockAccess && total > 32 * 1024;
return !prefer_block_evaluation && contiguous > threshold_;

View File

@@ -693,7 +693,7 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M
if (internal::InnerReducer<Self, Op, Device>::HasOptimizedImplementation &&
(reducing_inner_dims || ReducingInnerMostDims)) {
const Index num_values_to_reduce = internal::array_prod(m_reducedDims);
const Index num_coeffs_to_preserve = internal::array_prod(m_dimensions);
const Index num_coeffs_to_preserve = static_cast<Index>(internal::array_prod(m_dimensions));
if (!data) {
if ((num_coeffs_to_preserve < 1024 && num_values_to_reduce > num_coeffs_to_preserve &&
num_values_to_reduce > 128) ||
@@ -729,7 +729,7 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M
}
if (internal::OuterReducer<Self, Op, Device>::HasOptimizedImplementation && preserving_inner_dims) {
const Index num_values_to_reduce = internal::array_prod(m_reducedDims);
const Index num_coeffs_to_preserve = internal::array_prod(m_dimensions);
const Index num_coeffs_to_preserve = static_cast<Index>(internal::array_prod(m_dimensions));
if (!data) {
if ((num_coeffs_to_preserve < 1024 && num_values_to_reduce > num_coeffs_to_preserve &&
num_values_to_reduce > 32) ||
@@ -759,7 +759,7 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M
// must break into two subexpression and use the SYCL generic Reducer on the device.
if (RunningOnSycl) {
const Index num_values_to_reduce = internal::array_prod(m_reducedDims);
const Index num_coeffs_to_preserve = internal::array_prod(m_dimensions);
const Index num_coeffs_to_preserve = static_cast<Index>(internal::array_prod(m_dimensions));
if (!data) {
data = static_cast<EvaluatorPointerType>(
m_device.get((CoeffReturnType*)m_device.allocate_temp(sizeof(CoeffReturnType) * num_coeffs_to_preserve)));

View File

@@ -556,6 +556,11 @@ __global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void InnerReductionKernel(Reducer reduce
}
}
#else // EIGEN_CUDA_ARCH >= 300
EIGEN_UNUSED_VARIABLE(reducer);
EIGEN_UNUSED_VARIABLE(input);
EIGEN_UNUSED_VARIABLE(num_coeffs_to_reduce);
EIGEN_UNUSED_VARIABLE(num_preserved_coeffs);
EIGEN_UNUSED_VARIABLE(output);
gpu_assert(0 && "Shouldn't be called on unsupported device");
#endif // EIGEN_CUDA_ARCH >= 300
}
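When the architecture check compiles the kernel body away, every parameter becomes unused and Clang warns; EIGEN_UNUSED_VARIABLE marks each one as intentionally unused in the dead branch. A hypothetical sketch of the shape, with a stand-in macro rather than Eigen's:

#define MARK_USED(x) ((void)(x))  // stand-in for EIGEN_UNUSED_VARIABLE
__device__ void reduce_sketch(float* output, int num_coeffs) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // ... real reduction body ...
#else
  MARK_USED(output);  // silence -Wunused-parameter on the fallback path
  MARK_USED(num_coeffs);
#endif
}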

View File

@@ -35,8 +35,8 @@ void test_cuda_complex_cwise_ops() {
Eigen::TensorMap<Eigen::Tensor<std::complex<T>, 1, 0, int>, Eigen::Aligned> gpu_in2(d_in2, kNumItems);
Eigen::TensorMap<Eigen::Tensor<std::complex<T>, 1, 0, int>, Eigen::Aligned> gpu_out(d_out, kNumItems);
const std::complex<T> a(3.14f, 2.7f);
const std::complex<T> b(-10.6f, 1.4f);
const std::complex<T> a(static_cast<T>(3.14), static_cast<T>(2.7));
const std::complex<T> b(static_cast<T>(-10.6), static_cast<T>(1.4));
gpu_in1.device(gpu_device) = gpu_in1.constant(a);
gpu_in2.device(gpu_device) = gpu_in2.constant(b);

View File

@@ -176,7 +176,7 @@ void test_3d_convolution(Context* context) {
// Helper method to synchronize device.
template <typename Device>
void synchronize(Device& device) { /*nothing*/
void synchronize(Device& /*device*/) { /*nothing*/
}
template <>
void synchronize(Eigen::GpuDevice& device) {
@@ -197,7 +197,7 @@ void test_device_memory(const TensorDevice& device) {
device.memcpyDeviceToHost(host.data(), device_data, count * sizeof(DataType));
synchronize(device);
memset(expected.data(), byte_value, count * sizeof(DataType));
for (size_t i = 0; i < count; i++) {
for (Index i = 0; i < count; i++) {
VERIFY_IS_EQUAL(host(i), expected(i));
}