|
|
|
|
@@ -30,6 +30,9 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
|
|
|
|
|
const Index base_m = 64 * m_block_idx;
|
|
|
|
|
const Index base_n = 64 * n_block_idx;
|
|
|
|
|
const Index thread_x = threadIdx.x;
|
|
|
|
|
const Index thread_y = threadIdx.y;
|
|
|
|
|
const Index thread_z = threadIdx.z;
|
|
|
|
|
|
|
|
|
|
// declare and initialize 64 registers for output 8x8 block
|
|
|
|
|
|
|
|
|
|
@@ -66,8 +69,8 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
// conflicts on writes and also none on reads.
|
|
|
|
|
|
|
|
|
|
// storage indices
|
|
|
|
|
const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
|
|
|
|
|
const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
|
|
|
|
|
const Index lhs_store_idx_base = thread_y * 72 + thread_x * 9 + thread_z;
|
|
|
|
|
const Index rhs_store_idx_base = thread_y * 72 + thread_z * 8 + thread_x;
|
|
|
|
|
|
|
|
|
|
const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
|
|
|
|
|
const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
|
|
|
|
|
@@ -88,151 +91,151 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
|
|
|
|
|
|
|
|
|
|
// in the loading code, the following variables are important:
|
|
|
|
|
// threadIdx.x: the vertical position in an 8x8 block
|
|
|
|
|
// threadIdx.y: the vertical index of the 8x8 block in the grid
|
|
|
|
|
// threadIdx.z: the horizontal position in an 8x8 block
|
|
|
|
|
// thread_x: the vertical position in an 8x8 block
|
|
|
|
|
// thread_y: the vertical index of the 8x8 block in the grid
|
|
|
|
|
// thread_z: the horizontal position in an 8x8 block
|
|
|
|
|
// k: the horizontal index of the 8x8 block in the grid
|
|
|
|
|
//
|
|
|
|
|
// The k parameter is implicit (it was the loop counter for a loop that went
|
|
|
|
|
// from 0 to <8), but now that loop is unrolled in the below code.
|
|
|
|
|
|
|
|
|
|
const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
|
|
|
|
|
const Index load_idx_vert = thread_x + 8 * thread_y;
|
|
|
|
|
const Index lhs_vert = base_m + load_idx_vert;
|
|
|
|
|
|
|
|
|
|
#define prefetchIntoRegisters(base_k) \
|
|
|
|
|
{ \
|
|
|
|
|
lhs_pf0 = conv(0); \
|
|
|
|
|
lhs_pf1 = conv(0); \
|
|
|
|
|
lhs_pf2 = conv(0); \
|
|
|
|
|
lhs_pf3 = conv(0); \
|
|
|
|
|
lhs_pf4 = conv(0); \
|
|
|
|
|
lhs_pf5 = conv(0); \
|
|
|
|
|
lhs_pf6 = conv(0); \
|
|
|
|
|
lhs_pf7 = conv(0); \
|
|
|
|
|
\
|
|
|
|
|
rhs_pf0 = conv(0); \
|
|
|
|
|
rhs_pf1 = conv(0); \
|
|
|
|
|
rhs_pf2 = conv(0); \
|
|
|
|
|
rhs_pf3 = conv(0); \
|
|
|
|
|
rhs_pf4 = conv(0); \
|
|
|
|
|
rhs_pf5 = conv(0); \
|
|
|
|
|
rhs_pf6 = conv(0); \
|
|
|
|
|
rhs_pf7 = conv(0); \
|
|
|
|
|
\
|
|
|
|
|
if (!needs_edge_check || lhs_vert < m_size) { \
|
|
|
|
|
const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
|
|
|
|
|
const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
|
|
|
|
|
const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
|
|
|
|
|
const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
|
|
|
|
|
const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
|
|
|
|
|
const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
|
|
|
|
|
const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
|
|
|
|
|
const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
|
|
|
|
|
\
|
|
|
|
|
if (!needs_edge_check || lhs_horiz_7 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
|
|
|
|
|
lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
|
|
|
|
|
} else if (lhs_horiz_6 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
|
|
|
|
|
} else if (lhs_horiz_5 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
} else if (lhs_horiz_4 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
} else if (lhs_horiz_3 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
} else if (lhs_horiz_2 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
} else if (lhs_horiz_1 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
} else if (lhs_horiz_0 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
} \
|
|
|
|
|
} \
|
|
|
|
|
\
|
|
|
|
|
const Index rhs_vert = base_k + load_idx_vert; \
|
|
|
|
|
if (!needs_edge_check || rhs_vert < k_size) { \
|
|
|
|
|
const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
|
|
|
|
|
const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
|
|
|
|
|
const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
|
|
|
|
|
const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
|
|
|
|
|
const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
|
|
|
|
|
const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
|
|
|
|
|
const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
|
|
|
|
|
const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
|
|
|
|
|
\
|
|
|
|
|
if (rhs_horiz_7 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
|
|
|
|
|
rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
|
|
|
|
|
} else if (rhs_horiz_6 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
|
|
|
|
|
} else if (rhs_horiz_5 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
} else if (rhs_horiz_4 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
} else if (rhs_horiz_3 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
} else if (rhs_horiz_2 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
} else if (rhs_horiz_1 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
} else if (rhs_horiz_0 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
} \
|
|
|
|
|
} \
|
|
|
|
|
#define prefetchIntoRegisters(base_k) \
|
|
|
|
|
{ \
|
|
|
|
|
lhs_pf0 = conv(0); \
|
|
|
|
|
lhs_pf1 = conv(0); \
|
|
|
|
|
lhs_pf2 = conv(0); \
|
|
|
|
|
lhs_pf3 = conv(0); \
|
|
|
|
|
lhs_pf4 = conv(0); \
|
|
|
|
|
lhs_pf5 = conv(0); \
|
|
|
|
|
lhs_pf6 = conv(0); \
|
|
|
|
|
lhs_pf7 = conv(0); \
|
|
|
|
|
\
|
|
|
|
|
rhs_pf0 = conv(0); \
|
|
|
|
|
rhs_pf1 = conv(0); \
|
|
|
|
|
rhs_pf2 = conv(0); \
|
|
|
|
|
rhs_pf3 = conv(0); \
|
|
|
|
|
rhs_pf4 = conv(0); \
|
|
|
|
|
rhs_pf5 = conv(0); \
|
|
|
|
|
rhs_pf6 = conv(0); \
|
|
|
|
|
rhs_pf7 = conv(0); \
|
|
|
|
|
\
|
|
|
|
|
if (!needs_edge_check || lhs_vert < m_size) { \
|
|
|
|
|
const Index lhs_horiz_0 = base_k + thread_z + 0 * 8; \
|
|
|
|
|
const Index lhs_horiz_1 = base_k + thread_z + 1 * 8; \
|
|
|
|
|
const Index lhs_horiz_2 = base_k + thread_z + 2 * 8; \
|
|
|
|
|
const Index lhs_horiz_3 = base_k + thread_z + 3 * 8; \
|
|
|
|
|
const Index lhs_horiz_4 = base_k + thread_z + 4 * 8; \
|
|
|
|
|
const Index lhs_horiz_5 = base_k + thread_z + 5 * 8; \
|
|
|
|
|
const Index lhs_horiz_6 = base_k + thread_z + 6 * 8; \
|
|
|
|
|
const Index lhs_horiz_7 = base_k + thread_z + 7 * 8; \
|
|
|
|
|
\
|
|
|
|
|
if (!needs_edge_check || lhs_horiz_7 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
|
|
|
|
|
lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
|
|
|
|
|
} else if (lhs_horiz_6 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
|
|
|
|
|
} else if (lhs_horiz_5 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
|
|
|
|
|
} else if (lhs_horiz_4 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
|
|
|
|
|
} else if (lhs_horiz_3 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
|
|
|
|
|
} else if (lhs_horiz_2 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
|
|
|
|
|
} else if (lhs_horiz_1 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
|
|
|
|
|
} else if (lhs_horiz_0 < k_size) { \
|
|
|
|
|
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
|
|
|
|
|
} \
|
|
|
|
|
} \
|
|
|
|
|
\
|
|
|
|
|
const Index rhs_vert = base_k + load_idx_vert; \
|
|
|
|
|
if (!needs_edge_check || rhs_vert < k_size) { \
|
|
|
|
|
const Index rhs_horiz_0 = base_n + thread_z + 0 * 8; \
|
|
|
|
|
const Index rhs_horiz_1 = base_n + thread_z + 1 * 8; \
|
|
|
|
|
const Index rhs_horiz_2 = base_n + thread_z + 2 * 8; \
|
|
|
|
|
const Index rhs_horiz_3 = base_n + thread_z + 3 * 8; \
|
|
|
|
|
const Index rhs_horiz_4 = base_n + thread_z + 4 * 8; \
|
|
|
|
|
const Index rhs_horiz_5 = base_n + thread_z + 5 * 8; \
|
|
|
|
|
const Index rhs_horiz_6 = base_n + thread_z + 6 * 8; \
|
|
|
|
|
const Index rhs_horiz_7 = base_n + thread_z + 7 * 8; \
|
|
|
|
|
\
|
|
|
|
|
if (rhs_horiz_7 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
|
|
|
|
|
rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
|
|
|
|
|
} else if (rhs_horiz_6 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
|
|
|
|
|
} else if (rhs_horiz_5 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
|
|
|
|
|
} else if (rhs_horiz_4 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
|
|
|
|
|
} else if (rhs_horiz_3 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
|
|
|
|
|
} else if (rhs_horiz_2 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
|
|
|
|
|
} else if (rhs_horiz_1 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
|
|
|
|
|
} else if (rhs_horiz_0 < n_size) { \
|
|
|
|
|
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
|
|
|
|
|
} \
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define writeRegToShmem() \
|
|
|
|
|
@@ -321,8 +324,8 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
Scalar rrow(7);
|
|
|
|
|
|
|
|
|
|
// Now x corresponds to k, y to m, and z to n
|
|
|
|
|
const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
|
|
|
|
|
const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
|
|
|
|
|
const Scalar* lhs_block = &lhs_shmem[thread_x + 9 * thread_y];
|
|
|
|
|
const Scalar* rhs_block = &rhs_shmem[thread_x + 8 * thread_z];
|
|
|
|
|
|
|
|
|
|
#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
|
|
|
|
|
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
|
|
|
|
|
@@ -441,7 +444,7 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
// wait for shared mem to be out of use
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
#define writeResultShmem(i, j) lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);
|
|
|
|
|
#define writeResultShmem(i, j) lhs_shmem[i + 8 * thread_y + 64 * thread_z + 512 * j] = res(i, j);
|
|
|
|
|
|
|
|
|
|
#define writeRow(i) \
|
|
|
|
|
writeResultShmem(i, 0); \
|
|
|
|
|
@@ -453,7 +456,7 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
writeResultShmem(i, 6); \
|
|
|
|
|
writeResultShmem(i, 7);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
if (thread_x == 0) {
|
|
|
|
|
writeRow(0);
|
|
|
|
|
writeRow(1);
|
|
|
|
|
writeRow(2);
|
|
|
|
|
@@ -466,34 +469,34 @@ __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapp
|
|
|
|
|
#undef writeResultShmem
|
|
|
|
|
#undef writeRow
|
|
|
|
|
|
|
|
|
|
const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
|
|
|
|
|
const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
|
|
|
|
|
const int max_i_write = numext::mini((int)((m_size - base_m - thread_y + 7) / 8), 8);
|
|
|
|
|
const int max_j_write = numext::mini((int)((n_size - base_n - thread_z + 7) / 8), 8);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x < max_i_write) {
|
|
|
|
|
if (thread_x < max_i_write) {
|
|
|
|
|
if (max_j_write == 8) {
|
|
|
|
|
// TODO: Can we trade bank conflicts for coalesced writes?
|
|
|
|
|
Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
|
|
|
|
|
Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
|
|
|
|
|
Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
|
|
|
|
|
Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
|
|
|
|
|
Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
|
|
|
|
|
Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
|
|
|
|
|
Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
|
|
|
|
|
Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
|
|
|
|
|
Scalar val0 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 0];
|
|
|
|
|
Scalar val1 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 1];
|
|
|
|
|
Scalar val2 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 2];
|
|
|
|
|
Scalar val3 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 3];
|
|
|
|
|
Scalar val4 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 4];
|
|
|
|
|
Scalar val5 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 5];
|
|
|
|
|
Scalar val6 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 6];
|
|
|
|
|
Scalar val7 = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * 7];
|
|
|
|
|
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 0) = val0;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 1) = val1;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 2) = val2;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 3) = val3;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 4) = val4;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 5) = val5;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 6) = val6;
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * 7) = val7;
|
|
|
|
|
} else {
|
|
|
|
|
#pragma unroll 7
|
|
|
|
|
for (int j = 0; j < max_j_write; j++) {
|
|
|
|
|
Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
|
|
|
|
|
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
|
|
|
|
|
Scalar val = lhs_shmem[thread_x + 8 * thread_y + 64 * thread_z + 512 * j];
|
|
|
|
|
output(base_m + thread_y + 8 * thread_x, base_n + thread_z + 8 * j) = val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -539,6 +542,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
float4 lhs_pf0, rhs_pf0;
|
|
|
|
|
|
|
|
|
|
float4 results[4];
|
|
|
|
|
const Index thread_x = threadIdx.x;
|
|
|
|
|
const Index thread_y = threadIdx.y;
|
|
|
|
|
for (int i = 0; i < 4; i++) {
|
|
|
|
|
results[i].x = results[i].y = results[i].z = results[i].w = 0;
|
|
|
|
|
}
|
|
|
|
|
@@ -565,17 +570,17 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Index lhs_vert = base_m + threadIdx.x * 4;
|
|
|
|
|
Index lhs_vert = base_m + thread_x * 4;
|
|
|
|
|
|
|
|
|
|
for (Index k = 0; k < k_size; k += 16) {
|
|
|
|
|
lhs_pf0 = internal::pset1<float4>(0);
|
|
|
|
|
rhs_pf0 = internal::pset1<float4>(0);
|
|
|
|
|
|
|
|
|
|
Index lhs_horiz = threadIdx.y + k;
|
|
|
|
|
Index lhs_horiz = thread_y + k;
|
|
|
|
|
prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
|
|
|
|
|
|
|
|
|
|
Index rhs_vert = k + (threadIdx.x % 4) * 4;
|
|
|
|
|
Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
|
|
|
|
|
Index rhs_vert = k + (thread_x % 4) * 4;
|
|
|
|
|
Index rhs_horiz0 = (thread_x >> 2) + thread_y * 4 + base_n;
|
|
|
|
|
|
|
|
|
|
if (!CHECK_RHS_BOUNDARY) {
|
|
|
|
|
if ((rhs_vert + 3) < k_size) {
|
|
|
|
|
@@ -610,7 +615,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
}
|
|
|
|
|
float x1, x2;
|
|
|
|
|
// TODO: The following can be a bitwise operation.
|
|
|
|
|
if ((threadIdx.x % 8) < 4) {
|
|
|
|
|
if ((thread_x % 8) < 4) {
|
|
|
|
|
x1 = rhs_pf0.y;
|
|
|
|
|
x2 = rhs_pf0.w;
|
|
|
|
|
} else {
|
|
|
|
|
@@ -624,7 +629,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
|
|
|
|
|
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
|
|
|
|
|
#endif
|
|
|
|
|
if ((threadIdx.x % 8) < 4) {
|
|
|
|
|
if ((thread_x % 8) < 4) {
|
|
|
|
|
rhs_pf0.y = x1;
|
|
|
|
|
rhs_pf0.w = x2;
|
|
|
|
|
} else {
|
|
|
|
|
@@ -639,8 +644,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
// Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
|
|
|
|
|
// Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
|
|
|
|
|
// ...
|
|
|
|
|
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
|
|
|
|
|
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
|
|
|
|
|
rhs_shmem2[(thread_x >> 3) + thread_y * 2][thread_x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
|
|
|
|
|
rhs_shmem2[(thread_x >> 3) + thread_y * 2 + 32][thread_x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
|
|
|
|
|
|
|
|
|
|
// Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
|
|
|
|
|
// Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
|
|
|
|
|
@@ -649,8 +654,8 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
// Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
|
|
|
|
|
// ...
|
|
|
|
|
|
|
|
|
|
lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
|
|
|
|
|
lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
|
|
|
|
|
lhs_shmem2[thread_y][thread_x] = make_float2(lhs_pf0.x, lhs_pf0.y);
|
|
|
|
|
lhs_shmem2[thread_y + 16][thread_x] = make_float2(lhs_pf0.z, lhs_pf0.w);
|
|
|
|
|
|
|
|
|
|
#define add_vals(fl1, fl2, fr1, fr2) \
|
|
|
|
|
results[0].x += fl1.x * fr1.x; \
|
|
|
|
|
@@ -679,10 +684,10 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int koff = 0; koff < 16; koff++) {
|
|
|
|
|
// 32 x threads.
|
|
|
|
|
float2 fl1 = lhs_shmem2[koff][threadIdx.x];
|
|
|
|
|
float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
|
|
|
|
|
float2 fl1 = lhs_shmem2[koff][thread_x];
|
|
|
|
|
float2 fl2 = lhs_shmem2[koff + 16][thread_x];
|
|
|
|
|
|
|
|
|
|
int start_feature = threadIdx.y * 4;
|
|
|
|
|
int start_feature = thread_y * 4;
|
|
|
|
|
float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
|
|
|
|
|
float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
|
|
|
|
|
|
|
|
|
|
@@ -694,7 +699,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const L
|
|
|
|
|
#undef prefetch_lhs
|
|
|
|
|
#undef add_vals
|
|
|
|
|
|
|
|
|
|
Index horiz_base = threadIdx.y * 4 + base_n;
|
|
|
|
|
Index horiz_base = thread_y * 4 + base_n;
|
|
|
|
|
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
|
|
|
|
|
for (int i = 0; i < 4; i++) {
|
|
|
|
|
output(lhs_vert, horiz_base + i) = results[i].x;
|
|
|
|
|
@@ -770,11 +775,15 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
float4 rhs_pf0, rhs_pf1;
|
|
|
|
|
|
|
|
|
|
float4 results[8];
|
|
|
|
|
const Index thread_x = threadIdx.x;
|
|
|
|
|
const Index thread_y = threadIdx.y;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 8; i++) {
|
|
|
|
|
results[i].x = results[i].y = results[i].z = results[i].w = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
|
|
|
|
|
Index lhs_vert = base_m + thread_x * 4 + (thread_y % 4) * 32;
|
|
|
|
|
|
|
|
|
|
for (Index k = 0; k < k_size; k += 32) {
|
|
|
|
|
lhs_pf0 = internal::pset1<float4>(0);
|
|
|
|
|
lhs_pf1 = internal::pset1<float4>(0);
|
|
|
|
|
@@ -785,123 +794,123 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
rhs_pf1 = internal::pset1<float4>(0);
|
|
|
|
|
|
|
|
|
|
if (!CHECK_LHS_BOUNDARY) {
|
|
|
|
|
if ((threadIdx.y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
if ((thread_y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 24));
|
|
|
|
|
} else if ((thread_y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
|
|
|
|
|
} else if ((thread_y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
} else if ((thread_y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// just CHECK_LHS_BOUNDARY
|
|
|
|
|
if (lhs_vert + 3 < m_size) {
|
|
|
|
|
if ((threadIdx.y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
|
|
|
|
|
if ((thread_y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 24));
|
|
|
|
|
} else if ((thread_y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 16));
|
|
|
|
|
} else if ((thread_y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k + 8));
|
|
|
|
|
} else if ((thread_y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (thread_y / 4 + k));
|
|
|
|
|
}
|
|
|
|
|
} else if (lhs_vert + 2 < m_size) {
|
|
|
|
|
if ((threadIdx.y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
|
|
|
|
|
if ((thread_y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 24));
|
|
|
|
|
} else if ((thread_y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 16));
|
|
|
|
|
} else if ((thread_y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.z = lhs(lhs_vert + 2, (thread_y / 4 + k + 8));
|
|
|
|
|
} else if ((thread_y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.z = lhs(lhs_vert + 2, (thread_y / 4 + k));
|
|
|
|
|
}
|
|
|
|
|
} else if (lhs_vert + 1 < m_size) {
|
|
|
|
|
if ((threadIdx.y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
|
|
|
|
|
if ((thread_y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
|
|
|
|
|
lhs_pf3.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 24));
|
|
|
|
|
} else if ((thread_y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf2.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 16));
|
|
|
|
|
} else if ((thread_y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf1.y = lhs(lhs_vert + 1, (thread_y / 4 + k + 8));
|
|
|
|
|
} else if ((thread_y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf0.y = lhs(lhs_vert + 1, (thread_y / 4 + k));
|
|
|
|
|
}
|
|
|
|
|
} else if (lhs_vert < m_size) {
|
|
|
|
|
if ((threadIdx.y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
|
|
|
|
|
} else if ((threadIdx.y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
|
|
|
|
|
if ((thread_y / 4 + k + 24) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
lhs_pf3.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 24));
|
|
|
|
|
} else if ((thread_y / 4 + k + 16) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
lhs_pf2.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 16));
|
|
|
|
|
} else if ((thread_y / 4 + k + 8) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
lhs_pf1.x = lhs(lhs_vert + 0, (thread_y / 4 + k + 8));
|
|
|
|
|
} else if ((thread_y / 4 + k) < k_size) {
|
|
|
|
|
lhs_pf0.x = lhs(lhs_vert + 0, (thread_y / 4 + k));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
Index rhs_vert = k + threadIdx.x * 4;
|
|
|
|
|
Index rhs_horiz0 = threadIdx.y * 2 + base_n;
|
|
|
|
|
Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
|
|
|
|
|
Index rhs_vert = k + thread_x * 4;
|
|
|
|
|
Index rhs_horiz0 = thread_y * 2 + base_n;
|
|
|
|
|
Index rhs_horiz1 = thread_y * 2 + 1 + base_n;
|
|
|
|
|
if (!CHECK_RHS_BOUNDARY) {
|
|
|
|
|
if ((rhs_vert + 3) < k_size) {
|
|
|
|
|
// just CHECK_RHS_BOUNDARY
|
|
|
|
|
@@ -938,12 +947,12 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
|
|
|
|
|
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
|
|
|
|
|
rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
|
|
|
|
|
} else if (k + threadIdx.x * 4 + 1 < k_size) {
|
|
|
|
|
} else if (k + thread_x * 4 + 1 < k_size) {
|
|
|
|
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
|
|
|
|
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
|
|
|
|
|
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
|
|
|
|
|
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
|
|
|
|
|
} else if (k + threadIdx.x * 4 < k_size) {
|
|
|
|
|
} else if (k + thread_x * 4 < k_size) {
|
|
|
|
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
|
|
|
|
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
|
|
|
|
|
}
|
|
|
|
|
@@ -970,17 +979,17 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
// Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
|
|
|
|
|
// ..
|
|
|
|
|
// Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
|
|
|
|
|
rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
|
|
|
|
|
rhs_shmem2[thread_y][thread_x] = make_float2(rhs_pf0.x, rhs_pf1.x);
|
|
|
|
|
// Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
|
|
|
|
|
// Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
|
|
|
|
|
// ..
|
|
|
|
|
rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
|
|
|
|
|
rhs_shmem2[thread_y + 32][thread_x] = make_float2(rhs_pf0.y, rhs_pf1.y);
|
|
|
|
|
// Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
|
|
|
|
|
// Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
|
|
|
|
|
rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
|
|
|
|
|
rhs_shmem2[thread_y + 64][thread_x] = make_float2(rhs_pf0.z, rhs_pf1.z);
|
|
|
|
|
// Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
|
|
|
|
|
// Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
|
|
|
|
|
rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
|
|
|
|
|
rhs_shmem2[thread_y + 96][thread_x] = make_float2(rhs_pf0.w, rhs_pf1.w);
|
|
|
|
|
|
|
|
|
|
// LHS.
|
|
|
|
|
// Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
|
|
|
|
|
@@ -1026,26 +1035,26 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
results[6].w += a_feat2.y * f4.x; \
|
|
|
|
|
results[7].w += a_feat2.y * f4.y;
|
|
|
|
|
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
|
|
|
|
|
lhs_shmem2[thread_y / 4][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 8][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 16][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 24][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
|
|
|
|
|
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
|
|
|
|
|
lhs_shmem2[threadIdx.y / 4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 32][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 40][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 48][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
|
|
|
|
|
lhs_shmem2[thread_y / 4 + 56][thread_x + (thread_y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
// Do the multiplies.
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int koff = 0; koff < 32; koff++) {
|
|
|
|
|
float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
|
|
|
|
|
float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
|
|
|
|
|
float2 a3 = lhs_shmem2[koff][thread_x + (thread_y % 4) * 8];
|
|
|
|
|
float2 a4 = lhs_shmem2[koff + 32][thread_x + (thread_y % 4) * 8];
|
|
|
|
|
|
|
|
|
|
// first feature is at (threadIdx.y/4) * 8 last is at start + 8.
|
|
|
|
|
int start_feature = (threadIdx.y / 4) * 8;
|
|
|
|
|
// first feature is at (thread_y/4) * 8 last is at start + 8.
|
|
|
|
|
int start_feature = (thread_y / 4) * 8;
|
|
|
|
|
|
|
|
|
|
float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
|
|
|
|
|
float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
|
|
|
|
|
@@ -1060,7 +1069,7 @@ __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMap
|
|
|
|
|
#undef add_vals
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
|
|
|
|
|
Index horiz_base = (thread_y / 4) * 8 + base_n;
|
|
|
|
|
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
|
|
|
|
|
for (int i = 0; i < 8; i++) {
|
|
|
|
|
output(lhs_vert, horiz_base + i) = results[i].x;
|
|
|
|
|
|