mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Optimize GEMV kernels: row-major small-cols and template deduplication
libeigen/eigen!2151 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -99,8 +99,81 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, Conj
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC inline static void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
|
EIGEN_DEVICE_FUNC inline static void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
|
||||||
ResScalar* res, Index resIncr, RhsScalar alpha);
|
ResScalar* res, Index resIncr, RhsScalar alpha);
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows(
|
||||||
|
Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res,
|
||||||
|
const ResPacket& palpha, conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>& pcj);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Recursive template unroller for col-major GEMV full-packet row blocks.
|
||||||
|
// Unrolls the packet dimension (K = 0..N-1) at compile time, guaranteeing
|
||||||
|
// that each accumulator lives in its own register variable.
|
||||||
|
template <int K, int N>
|
||||||
|
struct gemv_colmajor_unroller {
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) {
|
||||||
|
gemv_colmajor_unroller<K - 1, N>::init_zero(c);
|
||||||
|
c[K] = pzero(Packet{});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename LhsPacket, int LhsStride, int Alignment, typename AccPacket, typename RhsPacket,
|
||||||
|
typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const RhsPacket& b0, ConjHelper& pcj) {
|
||||||
|
gemv_colmajor_unroller<K - 1, N>::template madd<LhsPacket, LhsStride, Alignment>(c, lhs, i, j, b0, pcj);
|
||||||
|
c[K] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i + LhsStride * K, j), b0, c[K]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ResPacket, int ResStride, typename ResScalar>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i,
|
||||||
|
const ResPacket& palpha) {
|
||||||
|
gemv_colmajor_unroller<K - 1, N>::template store<ResPacket, ResStride>(c, res, i, palpha);
|
||||||
|
pstoreu(res + i + ResStride * K, pmadd(c[K], palpha, ploadu<ResPacket>(res + i + ResStride * K)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
struct gemv_colmajor_unroller<0, N> {
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) {
|
||||||
|
c[0] = pzero(Packet{});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename LhsPacket, int LhsStride, int Alignment, typename AccPacket, typename RhsPacket,
|
||||||
|
typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const RhsPacket& b0, ConjHelper& pcj) {
|
||||||
|
c[0] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i, j), b0, c[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ResPacket, int ResStride, typename ResScalar>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i,
|
||||||
|
const ResPacket& palpha) {
|
||||||
|
pstoreu(res + i, pmadd(c[0], palpha, ploadu<ResPacket>(res + i)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
||||||
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
|
template <int N>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void general_matrix_vector_product<
|
||||||
|
Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
|
||||||
|
Version>::process_rows(Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res,
|
||||||
|
const ResPacket& palpha,
|
||||||
|
conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>& pcj) {
|
||||||
|
enum { LhsAlignment = Unaligned, LhsPacketSize = Traits::LhsPacketSize, ResPacketSize = Traits::ResPacketSize };
|
||||||
|
using Unroller = gemv_colmajor_unroller<N - 1, N>;
|
||||||
|
|
||||||
|
ResPacket c[N];
|
||||||
|
Unroller::init_zero(c);
|
||||||
|
for (Index j = j2; j < jend; ++j) {
|
||||||
|
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
||||||
|
Unroller::template madd<LhsPacket, LhsPacketSize, LhsAlignment>(c, lhs, i, j, b0, pcj);
|
||||||
|
}
|
||||||
|
Unroller::template store<ResPacket, ResPacketSize>(c, res, i, palpha);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
||||||
typename RhsMapper, bool ConjugateRhs, int Version>
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
EIGEN_DEVICE_FUNC inline void
|
EIGEN_DEVICE_FUNC inline void
|
||||||
@@ -140,7 +213,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLh
|
|||||||
const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1;
|
const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1;
|
||||||
|
|
||||||
// TODO: improve the following heuristic:
|
// TODO: improve the following heuristic:
|
||||||
const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 32000 ? 16 : 4);
|
const Index block_cols = cols < 128 ? cols : (lhsStride * Index(sizeof(LhsScalar)) < 32000 ? Index(16) : Index(4));
|
||||||
ResPacket palpha = pset1<ResPacket>(alpha);
|
ResPacket palpha = pset1<ResPacket>(alpha);
|
||||||
ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
|
ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
|
||||||
ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
|
ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
|
||||||
@@ -148,81 +221,21 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLh
|
|||||||
for (Index j2 = 0; j2 < cols; j2 += block_cols) {
|
for (Index j2 = 0; j2 < cols; j2 += block_cols) {
|
||||||
Index jend = numext::mini(j2 + block_cols, cols);
|
Index jend = numext::mini(j2 + block_cols, cols);
|
||||||
Index i = 0;
|
Index i = 0;
|
||||||
for (; i < n8; i += ResPacketSize * 8) {
|
for (; i < n8; i += ResPacketSize * 8) process_rows<8>(i, j2, jend, lhs, rhs, res, palpha, pcj);
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{}),
|
|
||||||
c4 = pzero(ResPacket{}), c5 = pzero(ResPacket{}), c6 = pzero(ResPacket{}), c7 = pzero(ResPacket{});
|
|
||||||
|
|
||||||
for (Index j = j2; j < jend; j += 1) {
|
|
||||||
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
|
||||||
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
|
|
||||||
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
|
|
||||||
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
|
|
||||||
c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
|
|
||||||
c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 4, j), b0, c4);
|
|
||||||
c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 5, j), b0, c5);
|
|
||||||
c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 6, j), b0, c6);
|
|
||||||
c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 7, j), b0, c7);
|
|
||||||
}
|
|
||||||
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 4, pmadd(c4, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 4)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 5, pmadd(c5, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 5)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 6, pmadd(c6, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 6)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 7, pmadd(c7, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 7)));
|
|
||||||
}
|
|
||||||
if (i < n4) {
|
if (i < n4) {
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{});
|
process_rows<4>(i, j2, jend, lhs, rhs, res, palpha, pcj);
|
||||||
|
|
||||||
for (Index j = j2; j < jend; j += 1) {
|
|
||||||
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
|
||||||
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
|
|
||||||
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
|
|
||||||
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
|
|
||||||
c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
|
|
||||||
}
|
|
||||||
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
|
|
||||||
|
|
||||||
i += ResPacketSize * 4;
|
i += ResPacketSize * 4;
|
||||||
}
|
}
|
||||||
if (i < n3) {
|
if (i < n3) {
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{});
|
process_rows<3>(i, j2, jend, lhs, rhs, res, palpha, pcj);
|
||||||
|
|
||||||
for (Index j = j2; j < jend; j += 1) {
|
|
||||||
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
|
||||||
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
|
|
||||||
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
|
|
||||||
c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
|
|
||||||
}
|
|
||||||
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
|
|
||||||
|
|
||||||
i += ResPacketSize * 3;
|
i += ResPacketSize * 3;
|
||||||
}
|
}
|
||||||
if (i < n2) {
|
if (i < n2) {
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{});
|
process_rows<2>(i, j2, jend, lhs, rhs, res, palpha, pcj);
|
||||||
|
|
||||||
for (Index j = j2; j < jend; j += 1) {
|
|
||||||
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
|
||||||
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
|
|
||||||
c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
|
|
||||||
}
|
|
||||||
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
|
|
||||||
pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
|
|
||||||
i += ResPacketSize * 2;
|
i += ResPacketSize * 2;
|
||||||
}
|
}
|
||||||
if (i < n1) {
|
if (i < n1) {
|
||||||
ResPacket c0 = pzero(ResPacket{});
|
process_rows<1>(i, j2, jend, lhs, rhs, res, palpha, pcj);
|
||||||
for (Index j = j2; j < jend; j += 1) {
|
|
||||||
RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
|
|
||||||
c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
|
|
||||||
}
|
|
||||||
pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
|
|
||||||
i += ResPacketSize;
|
i += ResPacketSize;
|
||||||
}
|
}
|
||||||
if (HasHalf && i < n_half) {
|
if (HasHalf && i < n_half) {
|
||||||
@@ -287,6 +300,22 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, Conj
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC static inline void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
|
EIGEN_DEVICE_FUNC static inline void run(Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs,
|
||||||
ResScalar* res, Index resIncr, ResScalar alpha);
|
ResScalar* res, Index resIncr, ResScalar alpha);
|
||||||
|
|
||||||
|
// Specialized path for when cols < full packet size. Kept noinline to avoid
|
||||||
|
// bloating the main run() function and causing icache pressure.
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run_small_cols(Index rows, Index cols, const LhsMapper& lhs,
|
||||||
|
const RhsMapper& rhs, ResScalar* res, Index resIncr,
|
||||||
|
ResScalar alpha);
|
||||||
|
|
||||||
|
// Templated helper that processes N rows in run_small_cols. N is a compile-time
|
||||||
|
// constant; row-dimension unrolling is done via recursive templates to guarantee
|
||||||
|
// full unrolling regardless of compiler heuristics.
|
||||||
|
template <int N>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs,
|
||||||
|
const RhsMapper& rhs, ResScalar* res,
|
||||||
|
Index resIncr, ResScalar alpha,
|
||||||
|
Index halfColBlockEnd,
|
||||||
|
Index quarterColBlockEnd);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
||||||
@@ -295,6 +324,23 @@ EIGEN_DEVICE_FUNC inline void
|
|||||||
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
|
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
|
||||||
Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
|
Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
|
||||||
ResScalar* res, Index resIncr, ResScalar alpha) {
|
ResScalar* res, Index resIncr, ResScalar alpha) {
|
||||||
|
// When cols < full packet size, the main vectorized loops are empty.
|
||||||
|
// Dispatch to a separate noinline function to avoid polluting the icache.
|
||||||
|
// Only dispatch when cols is large enough that half or quarter packets can be used;
|
||||||
|
// otherwise the helper would just do scalar work with extra function call overhead.
|
||||||
|
enum {
|
||||||
|
LhsPacketSize_ = Traits::LhsPacketSize,
|
||||||
|
MinUsefulCols_ =
|
||||||
|
((int)QuarterTraits::LhsPacketSize < (int)HalfTraits::LhsPacketSize)
|
||||||
|
? (int)QuarterTraits::LhsPacketSize
|
||||||
|
: (((int)HalfTraits::LhsPacketSize < (int)Traits::LhsPacketSize) ? (int)HalfTraits::LhsPacketSize
|
||||||
|
: (int)Traits::LhsPacketSize)
|
||||||
|
};
|
||||||
|
if (cols >= MinUsefulCols_ && cols < LhsPacketSize_) {
|
||||||
|
run_small_cols(rows, cols, alhs, rhs, res, resIncr, alpha);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// The following copy tells the compiler that lhs's attributes are not modified outside this function
|
// The following copy tells the compiler that lhs's attributes are not modified outside this function
|
||||||
// This helps GCC to generate proper code.
|
// This helps GCC to generate proper code.
|
||||||
LhsMapper lhs(alhs);
|
LhsMapper lhs(alhs);
|
||||||
@@ -376,7 +422,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
|||||||
res[(i + 6) * resIncr] += alpha * cc6;
|
res[(i + 6) * resIncr] += alpha * cc6;
|
||||||
res[(i + 7) * resIncr] += alpha * cc7;
|
res[(i + 7) * resIncr] += alpha * cc7;
|
||||||
}
|
}
|
||||||
for (; i < n4; i += 4) {
|
if (i < n4) {
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{});
|
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{});
|
||||||
|
|
||||||
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
|
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
|
||||||
@@ -404,8 +450,9 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
|||||||
res[(i + 1) * resIncr] += alpha * cc1;
|
res[(i + 1) * resIncr] += alpha * cc1;
|
||||||
res[(i + 2) * resIncr] += alpha * cc2;
|
res[(i + 2) * resIncr] += alpha * cc2;
|
||||||
res[(i + 3) * resIncr] += alpha * cc3;
|
res[(i + 3) * resIncr] += alpha * cc3;
|
||||||
|
i += 4;
|
||||||
}
|
}
|
||||||
for (; i < n2; i += 2) {
|
if (i < n2) {
|
||||||
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{});
|
ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{});
|
||||||
|
|
||||||
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
|
for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
|
||||||
@@ -425,8 +472,9 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
|||||||
}
|
}
|
||||||
res[(i + 0) * resIncr] += alpha * cc0;
|
res[(i + 0) * resIncr] += alpha * cc0;
|
||||||
res[(i + 1) * resIncr] += alpha * cc1;
|
res[(i + 1) * resIncr] += alpha * cc1;
|
||||||
|
i += 2;
|
||||||
}
|
}
|
||||||
for (; i < rows; ++i) {
|
if (i < rows) {
|
||||||
ResPacket c0 = pzero(ResPacket{});
|
ResPacket c0 = pzero(ResPacket{});
|
||||||
ResPacketHalf c0_h = pzero(ResPacketHalf{});
|
ResPacketHalf c0_h = pzero(ResPacketHalf{});
|
||||||
ResPacketQuarter c0_q = pzero(ResPacketQuarter{});
|
ResPacketQuarter c0_q = pzero(ResPacketQuarter{});
|
||||||
@@ -457,6 +505,165 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Recursive template unroller for process_rows_small_cols.
|
||||||
|
// Unrolls the row dimension (K = 0..N-1) at compile time, guaranteeing
|
||||||
|
// that each accumulator lives in its own register variable regardless
|
||||||
|
// of compiler unrolling heuristics.
|
||||||
|
template <int K, int N>
|
||||||
|
struct gemv_small_cols_unroller {
|
||||||
|
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const Packet& b0, ConjHelper& pcj) {
|
||||||
|
gemv_small_cols_unroller<K - 1, N>::template madd<Packet, Alignment>(acc, lhs, i, j, b0, pcj);
|
||||||
|
acc[K] = pcj.pmadd(lhs.template load<Packet, Alignment>(i + K, j), b0, acc[K]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const Scalar& b0, ConjHelper& cj) {
|
||||||
|
gemv_small_cols_unroller<K - 1, N>::scalar_madd(cc, lhs, i, j, b0, cj);
|
||||||
|
cc[K] += cj.pmul(lhs(i + K, j), b0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) {
|
||||||
|
gemv_small_cols_unroller<K - 1, N>::predux_accum(cc, acc);
|
||||||
|
cc[K] += predux(acc[K]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) {
|
||||||
|
gemv_small_cols_unroller<K - 1, N>::init_zero(acc);
|
||||||
|
acc[K] = pzero(Packet{});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha,
|
||||||
|
const Scalar* cc) {
|
||||||
|
gemv_small_cols_unroller<K - 1, N>::write_result(res, resIncr, i, alpha, cc);
|
||||||
|
res[(i + K) * resIncr] += alpha * cc[K];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
struct gemv_small_cols_unroller<0, N> {
|
||||||
|
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const Packet& b0, ConjHelper& pcj) {
|
||||||
|
acc[0] = pcj.pmadd(lhs.template load<Packet, Alignment>(i, j), b0, acc[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||||
|
const Scalar& b0, ConjHelper& cj) {
|
||||||
|
cc[0] += cj.pmul(lhs(i, j), b0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) {
|
||||||
|
cc[0] += predux(acc[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) {
|
||||||
|
acc[0] = pzero(Packet{});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename Index>
|
||||||
|
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha,
|
||||||
|
const Scalar* cc) {
|
||||||
|
res[i * resIncr] += alpha * cc[0];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
||||||
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
|
template <int N>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void
|
||||||
|
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
|
||||||
|
Version>::process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs,
|
||||||
|
const RhsMapper& rhs, ResScalar* res, Index resIncr,
|
||||||
|
ResScalar alpha, Index halfColBlockEnd,
|
||||||
|
Index quarterColBlockEnd) {
|
||||||
|
conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
|
||||||
|
conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
|
||||||
|
conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
LhsAlignment = Unaligned,
|
||||||
|
ResPacketSizeHalf = HalfTraits::ResPacketSize,
|
||||||
|
ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
|
||||||
|
LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
|
||||||
|
LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
|
||||||
|
HasHalf = (int)ResPacketSizeHalf < (int)Traits::ResPacketSize,
|
||||||
|
HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
|
||||||
|
};
|
||||||
|
|
||||||
|
using Unroll = gemv_small_cols_unroller<N - 1, N>;
|
||||||
|
|
||||||
|
ResScalar cc[N] = {};
|
||||||
|
if (HasHalf) {
|
||||||
|
ResPacketHalf h[N];
|
||||||
|
Unroll::init_zero(h);
|
||||||
|
for (Index j = 0; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
|
||||||
|
RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
|
||||||
|
Unroll::template madd<ResPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
|
||||||
|
}
|
||||||
|
Unroll::predux_accum(cc, h);
|
||||||
|
}
|
||||||
|
if (HasQuarter) {
|
||||||
|
ResPacketQuarter q[N];
|
||||||
|
Unroll::init_zero(q);
|
||||||
|
for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
|
||||||
|
RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
|
||||||
|
Unroll::template madd<ResPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
|
||||||
|
}
|
||||||
|
Unroll::predux_accum(cc, q);
|
||||||
|
}
|
||||||
|
for (Index j = quarterColBlockEnd; j < cols; ++j) {
|
||||||
|
RhsScalar b0 = rhs(j, 0);
|
||||||
|
Unroll::scalar_madd(cc, lhs, i, j, b0, cj);
|
||||||
|
}
|
||||||
|
Unroll::write_result(res, resIncr, i, alpha, cc);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
|
||||||
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
|
||||||
|
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
|
||||||
|
Version>::run_small_cols(Index rows, Index cols, const LhsMapper& alhs,
|
||||||
|
const RhsMapper& rhs, ResScalar* res, Index resIncr,
|
||||||
|
ResScalar alpha) {
|
||||||
|
LhsMapper lhs(alhs);
|
||||||
|
eigen_internal_assert(rhs.stride() == 1);
|
||||||
|
|
||||||
|
enum {
|
||||||
|
LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
|
||||||
|
LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
|
||||||
|
};
|
||||||
|
|
||||||
|
using UnsignedIndex = std::make_unsigned_t<Index>;
|
||||||
|
const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(cols) / LhsPacketSizeHalf);
|
||||||
|
const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(cols) / LhsPacketSizeQuarter);
|
||||||
|
|
||||||
|
const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? 0 : rows - 7;
|
||||||
|
const Index n4 = rows - 3;
|
||||||
|
const Index n2 = rows - 1;
|
||||||
|
|
||||||
|
Index i = 0;
|
||||||
|
for (; i < n8; i += 8)
|
||||||
|
process_rows_small_cols<8>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
|
||||||
|
if (i < n4) {
|
||||||
|
process_rows_small_cols<4>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
|
||||||
|
i += 4;
|
||||||
|
}
|
||||||
|
if (i < n2) {
|
||||||
|
process_rows_small_cols<2>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
|
||||||
|
i += 2;
|
||||||
|
}
|
||||||
|
if (i < rows) process_rows_small_cols<1>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd);
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|||||||
Reference in New Issue
Block a user