Optimize GEMV kernels: row-major small-cols and template deduplication

libeigen/eigen!2151

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-19 15:06:24 -08:00
parent 9c63d26dec
commit 53e3408cb7

View File

@@ -99,8 +99,81 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, Conj
EIGEN_DEVICE_FUNC inline static void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, EIGEN_DEVICE_FUNC inline static void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
ResScalar* res, Index resIncr, RhsScalar alpha); ResScalar* res, Index resIncr, RhsScalar alpha);
// Handles one block of N consecutive full result packets starting at row i.
// For every column j in [j2, jend) the scalar rhs(j, 0) is broadcast and multiplied
// (through pcj) with N LHS packets loaded from column j; the N accumulated packets
// are then scaled by palpha and added into res starting at res + i.
// Defined out of line below the class.
template <int N>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows(
Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res,
const ResPacket& palpha, conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>& pcj);
}; };
// Compile-time unroller for the full-packet row blocks of the col-major GEMV kernel.
//
// gemv_colmajor_unroller<K, N> handles accumulator slots 0..K: each member first
// recurses into <K - 1, N> and then handles slot K itself. The <0, N> specialization
// below terminates the recursion. Every accumulator is thus indexed by a compile-time
// constant, which is intended to let the compiler keep each one in its own register
// without depending on loop-unrolling heuristics.
template <int K, int N>
struct gemv_colmajor_unroller {
  // Sets c[0..K] to the zero packet.
  template <typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) {
    gemv_colmajor_unroller<K - 1, N>::init_zero(c);
    c[K] = pzero(Packet{});
  }

  // For k = 0..K: c[k] = pcj.pmadd(lhs packet at (i + LhsStride * k, j), b0, c[k]).
  // LhsPacket, LhsStride and Alignment are given explicitly; the remaining template
  // parameters are deduced from the arguments.
  template <typename LhsPacket, int LhsStride, int Alignment, typename AccPacket, typename RhsPacket,
            typename ConjHelper, typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j,
                                                         const RhsPacket& b0, ConjHelper& pcj) {
    gemv_colmajor_unroller<K - 1, N>::template madd<LhsPacket, LhsStride, Alignment>(c, lhs, i, j, b0, pcj);
    c[K] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i + LhsStride * K, j), b0, c[K]);
  }

  // For k = 0..K: adds c[k] * palpha into the packet stored at res + i + ResStride * k
  // (unaligned load and store).
  // The index type is deduced from the argument, like in madd(), so the caller's own
  // Index type is used rather than whatever `Index` names at namespace scope.
  template <typename ResPacket, int ResStride, typename ResScalar, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i,
                                                          const ResPacket& palpha) {
    gemv_colmajor_unroller<K - 1, N>::template store<ResPacket, ResStride>(c, res, i, palpha);
    ResScalar* dst = res + i + ResStride * K;
    pstoreu(dst, pmadd(c[K], palpha, ploadu<ResPacket>(dst)));
  }
};

// Recursion terminator: handles accumulator slot 0 only.
template <int N>
struct gemv_colmajor_unroller<0, N> {
  // Sets c[0] to the zero packet.
  template <typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) {
    c[0] = pzero(Packet{});
  }

  // c[0] = pcj.pmadd(lhs packet at (i, j), b0, c[0]). LhsStride is unused for slot 0.
  template <typename LhsPacket, int LhsStride, int Alignment, typename AccPacket, typename RhsPacket,
            typename ConjHelper, typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j,
                                                         const RhsPacket& b0, ConjHelper& pcj) {
    c[0] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i, j), b0, c[0]);
  }

  // Adds c[0] * palpha into the packet stored at res + i. ResStride is unused for slot 0.
  template <typename ResPacket, int ResStride, typename ResScalar, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i,
                                                          const ResPacket& palpha) {
    ResScalar* dst = res + i;
    pstoreu(dst, pmadd(c[0], palpha, ploadu<ResPacket>(dst)));
  }
};
// Col-major GEMV: processes N consecutive full result packets starting at row i,
// restricted to the column range [j2, jend).
//
// The N packet accumulators start at zero. For each column the RHS coefficient is
// broadcast into a packet and multiply-accumulated (through pcj) with the N LHS
// packets of that column. Finally each accumulator is scaled by palpha and added to
// the corresponding packet of res.
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
template <int N>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void general_matrix_vector_product<
    Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
    Version>::process_rows(Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res,
                           const ResPacket& palpha,
                           conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>& pcj) {
  enum {
    LhsLoadAlignment = Unaligned,            // LHS packets are loaded with the Unaligned tag
    LhsRowStep = Traits::LhsPacketSize,      // row distance between consecutive LHS packets
    ResRowStep = Traits::ResPacketSize       // row distance between consecutive result packets
  };
  typedef gemv_colmajor_unroller<N - 1, N> RowBlock;

  ResPacket acc[N];
  RowBlock::init_zero(acc);

  Index col = j2;
  while (col < jend) {
    // Broadcast the RHS coefficient of this column to all packet lanes.
    const RhsPacket rhs_bcast = pset1<RhsPacket>(rhs(col, 0));
    RowBlock::template madd<LhsPacket, LhsRowStep, LhsLoadAlignment>(acc, lhs, i, col, rhs_bcast, pcj);
    ++col;
  }

  // res[i ...] += palpha * acc
  RowBlock::template store<ResPacket, ResRowStep>(acc, res, i, palpha);
}
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
typename RhsMapper, bool ConjugateRhs, int Version> typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC inline void EIGEN_DEVICE_FUNC inline void
@@ -140,7 +213,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLh
const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1; const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1;
// TODO: improve the following heuristic: // TODO: improve the following heuristic:
const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 32000 ? 16 : 4); const Index block_cols = cols < 128 ? cols : (lhsStride * Index(sizeof(LhsScalar)) < 32000 ? Index(16) : Index(4));
ResPacket palpha = pset1<ResPacket>(alpha); ResPacket palpha = pset1<ResPacket>(alpha);
ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha); ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha); ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
@@ -148,81 +221,21 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLh
for (Index j2 = 0; j2 < cols; j2 += block_cols) { for (Index j2 = 0; j2 < cols; j2 += block_cols) {
Index jend = numext::mini(j2 + block_cols, cols); Index jend = numext::mini(j2 + block_cols, cols);
Index i = 0; Index i = 0;
for (; i < n8; i += ResPacketSize * 8) { for (; i < n8; i += ResPacketSize * 8) process_rows<8>(i, j2, jend, lhs, rhs, res, palpha, pcj);
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{}),
c4 = pzero(ResPacket{}), c5 = pzero(ResPacket{}), c6 = pzero(ResPacket{}), c7 = pzero(ResPacket{});
for (Index j = j2; j < jend; j += 1) {
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 4, j), b0, c4);
c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 5, j), b0, c5);
c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 6, j), b0, c6);
c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 7, j), b0, c7);
}
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
pstoreu(res + i + ResPacketSize * 4, pmadd(c4, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 4)));
pstoreu(res + i + ResPacketSize * 5, pmadd(c5, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 5)));
pstoreu(res + i + ResPacketSize * 6, pmadd(c6, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 6)));
pstoreu(res + i + ResPacketSize * 7, pmadd(c7, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 7)));
}
if (i < n4) { if (i < n4) {
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{}); process_rows<4>(i, j2, jend, lhs, rhs, res, palpha, pcj);
for (Index j = j2; j < jend; j += 1) {
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
}
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
i += ResPacketSize * 4; i += ResPacketSize * 4;
} }
if (i < n3) { if (i < n3) {
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}); process_rows<3>(i, j2, jend, lhs, rhs, res, palpha, pcj);
for (Index j = j2; j < jend; j += 1) {
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
}
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
i += ResPacketSize * 3; i += ResPacketSize * 3;
} }
if (i < n2) { if (i < n2) {
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}); process_rows<2>(i, j2, jend, lhs, rhs, res, palpha, pcj);
for (Index j = j2; j < jend; j += 1) {
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
}
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
i += ResPacketSize * 2; i += ResPacketSize * 2;
} }
if (i < n1) { if (i < n1) {
ResPacket c0 = pzero(ResPacket{}); process_rows<1>(i, j2, jend, lhs, rhs, res, palpha, pcj);
for (Index j = j2; j < jend; j += 1) {
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
}
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
i += ResPacketSize; i += ResPacketSize;
} }
if (HasHalf && i < n_half) { if (HasHalf && i < n_half) {
@@ -287,6 +300,22 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, Conj
EIGEN_DEVICE_FUNC static inline void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, EIGEN_DEVICE_FUNC static inline void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
ResScalar* res, Index resIncr, ResScalar alpha); ResScalar* res, Index resIncr, ResScalar alpha);
// Specialized path for when cols < full packet size. Kept noinline to avoid
// bloating the main run() function and causing icache pressure.
// Takes the same arguments as run(). For each row it accumulates the products of the
// row's lhs coefficients with rhs(j, 0) (through conj_helper) and adds alpha times that
// sum to res[row * resIncr]. The definition asserts, via eigen_internal_assert, that
// rhs.stride() == 1.
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run_small_cols(Index rows, Index cols, const LhsMapper& lhs,
const RhsMapper& rhs, ResScalar* res, Index resIncr,
ResScalar alpha);
// Templated helper that processes N rows in run_small_cols. N is a compile-time
// constant; row-dimension unrolling is done via recursive templates so that the
// unrolling does not depend on compiler heuristics.
//
// Rows i .. i + N - 1 are handled together. Columns [0, halfColBlockEnd) are processed
// with half packets (when a half packet type exists), columns
// [halfColBlockEnd, quarterColBlockEnd) with quarter packets (when a quarter packet
// type exists), and columns [quarterColBlockEnd, cols) one scalar at a time. Each row's
// sum is then scaled by alpha and added to res[(i + k) * resIncr].
template <int N>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs,
const RhsMapper& rhs, ResScalar* res,
Index resIncr, ResScalar alpha,
Index halfColBlockEnd,
Index quarterColBlockEnd);
}; };
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
@@ -295,6 +324,23 @@ EIGEN_DEVICE_FUNC inline void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
ResScalar* res, Index resIncr, ResScalar alpha) { ResScalar* res, Index resIncr, ResScalar alpha) {
// When cols < full packet size, the main vectorized loops are empty.
// Dispatch to a separate noinline function to avoid polluting the icache.
// Only dispatch when cols is large enough that half or quarter packets can be used;
// otherwise the helper would just do scalar work with extra function call overhead.
enum {
LhsPacketSize_ = Traits::LhsPacketSize,
MinUsefulCols_ =
((int)QuarterTraits::LhsPacketSize < (int)HalfTraits::LhsPacketSize)
? (int)QuarterTraits::LhsPacketSize
: (((int)HalfTraits::LhsPacketSize < (int)Traits::LhsPacketSize) ? (int)HalfTraits::LhsPacketSize
: (int)Traits::LhsPacketSize)
};
if (cols >= MinUsefulCols_ && cols < LhsPacketSize_) {
run_small_cols(rows, cols, alhs, rhs, res, resIncr, alpha);
return;
}
// The following copy tells the compiler that lhs's attributes are not modified outside this function // The following copy tells the compiler that lhs's attributes are not modified outside this function
// This helps GCC to generate proper code. // This helps GCC to generate proper code.
LhsMapper lhs(alhs); LhsMapper lhs(alhs);
@@ -376,7 +422,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
res[(i + 6) * resIncr] += alpha * cc6; res[(i + 6) * resIncr] += alpha * cc6;
res[(i + 7) * resIncr] += alpha * cc7; res[(i + 7) * resIncr] += alpha * cc7;
} }
for (; i < n4; i += 4) { if (i < n4) {
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{}); ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{});
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) { for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
@@ -404,8 +450,9 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
res[(i + 1) * resIncr] += alpha * cc1; res[(i + 1) * resIncr] += alpha * cc1;
res[(i + 2) * resIncr] += alpha * cc2; res[(i + 2) * resIncr] += alpha * cc2;
res[(i + 3) * resIncr] += alpha * cc3; res[(i + 3) * resIncr] += alpha * cc3;
i += 4;
} }
for (; i < n2; i += 2) { if (i < n2) {
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}); ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{});
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) { for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
@@ -425,8 +472,9 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
} }
res[(i + 0) * resIncr] += alpha * cc0; res[(i + 0) * resIncr] += alpha * cc0;
res[(i + 1) * resIncr] += alpha * cc1; res[(i + 1) * resIncr] += alpha * cc1;
i += 2;
} }
for (; i < rows; ++i) { if (i < rows) {
ResPacket c0 = pzero(ResPacket{}); ResPacket c0 = pzero(ResPacket{});
ResPacketHalf c0_h = pzero(ResPacketHalf{}); ResPacketHalf c0_h = pzero(ResPacketHalf{});
ResPacketQuarter c0_q = pzero(ResPacketQuarter{}); ResPacketQuarter c0_q = pzero(ResPacketQuarter{});
@@ -457,6 +505,165 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
} }
} }
// Compile-time unroller used by process_rows_small_cols.
//
// gemv_small_cols_unroller<K, N> handles rows 0..K of an N-row block: each member
// first recurses into <K - 1, N> and then handles row K itself. The <0, N>
// specialization below terminates the recursion. Every accumulator is thus indexed by
// a compile-time constant, which is intended to let the compiler keep each one in its
// own register without depending on loop-unrolling heuristics.
//
// The LHS, RHS and accumulator/result types are separate template parameters because
// the kernel's LhsScalar, RhsScalar and ResScalar (and the packet types derived from
// them) need not be the same type.
template <int K, int N>
struct gemv_small_cols_unroller {
  // For k = 0..K: acc[k] = pcj.pmadd(lhs packet at (i + k, j), b0, acc[k]).
  // LhsPacket and Alignment are given explicitly; AccPacket and RhsPacket are deduced.
  template <typename LhsPacket, int Alignment, typename AccPacket, typename RhsPacket, typename ConjHelper,
            typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
                                                         const RhsPacket& b0, ConjHelper& pcj) {
    gemv_small_cols_unroller<K - 1, N>::template madd<LhsPacket, Alignment>(acc, lhs, i, j, b0, pcj);
    acc[K] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i + K, j), b0, acc[K]);
  }

  // For k = 0..K: cc[k] += cj.pmul(lhs(i + k, j), b0).
  // The accumulator and RHS scalar types are deduced independently so that the call
  // also resolves when they differ.
  template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i,
                                                                Index j, const RhsScalar& b0, ConjHelper& cj) {
    gemv_small_cols_unroller<K - 1, N>::scalar_madd(cc, lhs, i, j, b0, cj);
    cc[K] += cj.pmul(lhs(i + K, j), b0);
  }

  // For k = 0..K: adds the horizontal reduction (predux) of acc[k] to cc[k].
  template <typename Scalar, typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) {
    gemv_small_cols_unroller<K - 1, N>::predux_accum(cc, acc);
    cc[K] += predux(acc[K]);
  }

  // Sets acc[0..K] to the zero packet.
  template <typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) {
    gemv_small_cols_unroller<K - 1, N>::init_zero(acc);
    acc[K] = pzero(Packet{});
  }

  // For k = 0..K: res[(i + k) * resIncr] += alpha * cc[k].
  template <typename Scalar, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha,
                                                                 const Scalar* cc) {
    gemv_small_cols_unroller<K - 1, N>::write_result(res, resIncr, i, alpha, cc);
    res[(i + K) * resIncr] += alpha * cc[K];
  }
};

// Recursion terminator: handles row 0 of the block only.
template <int N>
struct gemv_small_cols_unroller<0, N> {
  // acc[0] = pcj.pmadd(lhs packet at (i, j), b0, acc[0]).
  template <typename LhsPacket, int Alignment, typename AccPacket, typename RhsPacket, typename ConjHelper,
            typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
                                                         const RhsPacket& b0, ConjHelper& pcj) {
    acc[0] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i, j), b0, acc[0]);
  }

  // cc[0] += cj.pmul(lhs(i, j), b0).
  template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i,
                                                                Index j, const RhsScalar& b0, ConjHelper& cj) {
    cc[0] += cj.pmul(lhs(i, j), b0);
  }

  // Adds the horizontal reduction (predux) of acc[0] to cc[0].
  template <typename Scalar, typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) {
    cc[0] += predux(acc[0]);
  }

  // Sets acc[0] to the zero packet.
  template <typename Packet>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) {
    acc[0] = pzero(Packet{});
  }

  // res[i * resIncr] += alpha * cc[0].
  template <typename Scalar, typename Index>
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha,
                                                                 const Scalar* cc) {
    res[i * resIncr] += alpha * cc[0];
  }
};

// Row-major GEMV helper for run_small_cols: computes N rows (i .. i + N - 1) at once.
//
// Column ranges and how they are processed:
//   [0, halfColBlockEnd)                  half packets    (only if a half packet type exists)
//   [halfColBlockEnd, quarterColBlockEnd) quarter packets (only if a quarter packet type exists)
//   [quarterColBlockEnd, cols)            one scalar at a time
// The per-row sums are collected in cc[] and finally written as
//   res[(i + k) * resIncr] += alpha * cc[k].
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
template <int N>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs,
                                                                const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                                ResScalar alpha, Index halfColBlockEnd,
                                                                Index quarterColBlockEnd) {
  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;

  enum {
    LhsAlignment = Unaligned,
    ResPacketSizeHalf = HalfTraits::ResPacketSize,
    ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
    LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
    LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
    // A smaller packet type "exists" only if it is strictly smaller than the next larger one.
    HasHalf = (int)ResPacketSizeHalf < (int)Traits::ResPacketSize,
    HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  using Unroll = gemv_small_cols_unroller<N - 1, N>;

  // One scalar accumulator per row, value-initialized.
  ResScalar cc[N] = {};

  if (HasHalf) {
    ResPacketHalf h[N];
    Unroll::init_zero(h);
    for (Index j = 0; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
      RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
      // The LHS is loaded with its own packet type; the accumulator and RHS packet
      // types are deduced from h and b0.
      Unroll::template madd<LhsPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
    }
    Unroll::predux_accum(cc, h);
  }

  if (HasQuarter) {
    ResPacketQuarter q[N];
    Unroll::init_zero(q);
    for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
      RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
      Unroll::template madd<LhsPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
    }
    Unroll::predux_accum(cc, q);
  }

  // Remaining columns that do not fill a packet.
  for (Index j = quarterColBlockEnd; j < cols; ++j) {
    RhsScalar b0 = rhs(j, 0);
    Unroll::scalar_madd(cc, lhs, i, j, b0, cj);
  }

  Unroll::write_result(res, resIncr, i, alpha, cc);
}
// Row-major GEMV for the case where cols is smaller than a full LHS packet.
//
// Rows are consumed in blocks of 8, then 4, then 2, then 1, each block being handled
// by process_rows_small_cols<N>. For every row the result is
//   res[row * resIncr] += alpha * sum_j lhs(row, j) * rhs(j, 0)
// with the conjugations selected by ConjugateLhs / ConjugateRhs applied by conj_helper.
// Requires rhs.stride() == 1 (checked with eigen_internal_assert).
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::run_small_cols(Index rows, Index cols, const LhsMapper& alhs,
                                                       const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                       ResScalar alpha) {
  // Work on a local copy of the mapper, as run() does.
  LhsMapper lhs(alhs);
  eigen_internal_assert(rhs.stride() == 1);

  enum {
    LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
    LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
  };

  // cols rounded down to a multiple of the half / quarter packet size. The division is
  // done on the unsigned counterpart of Index.
  using UnsignedIndex = std::make_unsigned_t<Index>;
  const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(cols) / LhsPacketSizeHalf);
  const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(cols) / LhsPacketSizeQuarter);

  // Heuristic: when consecutive rows are more than 32000 bytes apart, the 8-row block
  // is disabled (n8 == 0) and rows are handled at most 4 at a time.
  // NOTE(review): presumably a cache-footprint consideration - confirm against run().
  // The byte size is computed in Index to avoid mixing signed and unsigned operands.
  const Index n8 = lhs.stride() * Index(sizeof(LhsScalar)) > Index(32000) ? Index(0) : rows - 7;
  const Index n4 = rows - 3;
  const Index n2 = rows - 1;

  Index i = 0;
  for (; i < n8; i += 8)
    process_rows_small_cols<8>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);

  // The tails must be loops rather than single `if`s: when the 8-row block is disabled
  // above, an arbitrary number of rows is still pending here, and single-shot tails
  // would stop after 4 + 2 + 1 rows. When the 8-row loop did run, fewer than 8 rows
  // remain and each loop below executes at most once.
  for (; i < n4; i += 4)
    process_rows_small_cols<4>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
  for (; i < n2; i += 2)
    process_rows_small_cols<2>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
  for (; i < rows; ++i)
    process_rows_small_cols<1>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
}
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen