diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index ff2ce23c8..167669245 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -99,8 +99,81 @@ struct general_matrix_vector_product + EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows( + Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res, + const ResPacket& palpha, conj_helper& pcj); }; +// Recursive template unroller for col-major GEMV full-packet row blocks. +// Unrolls the packet dimension (K = 0..N-1) at compile time, guaranteeing +// that each accumulator lives in its own register variable. +template +struct gemv_colmajor_unroller { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) { + gemv_colmajor_unroller::init_zero(c); + c[K] = pzero(Packet{}); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j, + const RhsPacket& b0, ConjHelper& pcj) { + gemv_colmajor_unroller::template madd(c, lhs, i, j, b0, pcj); + c[K] = pcj.pmadd(lhs.template load(i + LhsStride * K, j), b0, c[K]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i, + const ResPacket& palpha) { + gemv_colmajor_unroller::template store(c, res, i, palpha); + pstoreu(res + i + ResStride * K, pmadd(c[K], palpha, ploadu(res + i + ResStride * K))); + } +}; + +template +struct gemv_colmajor_unroller<0, N> { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* c) { + c[0] = pzero(Packet{}); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* c, const LhsMapper& lhs, Index i, Index j, + const RhsPacket& b0, ConjHelper& pcj) { + c[0] = pcj.pmadd(lhs.template load(i, j), b0, c[0]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void store(const ResPacket* c, ResScalar* res, Index i, + const ResPacket& palpha) { + pstoreu(res + i, pmadd(c[0], palpha, ploadu(res + i))); + } +}; + +template +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void general_matrix_vector_product< + Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, + Version>::process_rows(Index i, Index j2, Index jend, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res, + const ResPacket& palpha, + conj_helper& pcj) { + enum { LhsAlignment = Unaligned, LhsPacketSize = Traits::LhsPacketSize, ResPacketSize = Traits::ResPacketSize }; + using Unroller = gemv_colmajor_unroller; + + ResPacket c[N]; + Unroller::init_zero(c); + for (Index j = j2; j < jend; ++j) { + RhsPacket b0 = pset1(rhs(j, 0)); + Unroller::template madd(c, lhs, i, j, b0, pcj); + } + Unroller::template store(c, res, i, palpha); +} + template EIGEN_DEVICE_FUNC inline void @@ -140,7 +213,7 @@ general_matrix_vector_product(alpha); ResPacketHalf palpha_half = pset1(alpha); ResPacketQuarter palpha_quarter = pset1(alpha); @@ -148,81 +221,21 @@ general_matrix_vector_product(rhs(j, 0)); - c0 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 0, j), b0, c0); - c1 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 1, j), b0, c1); - c2 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 2, j), b0, c2); - c3 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 3, j), b0, c3); - c4 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 4, j), b0, c4); - c5 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 5, j), b0, c5); - c6 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 6, j), b0, c6); - c7 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 7, j), b0, c7); - } - pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu(res + i + ResPacketSize * 0))); - pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu(res + i + ResPacketSize * 1))); - pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu(res + i + ResPacketSize * 2))); - pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu(res + i + ResPacketSize * 3))); - pstoreu(res + i + ResPacketSize * 4, pmadd(c4, palpha, ploadu(res + i + ResPacketSize * 4))); - pstoreu(res + i + ResPacketSize * 5, pmadd(c5, palpha, ploadu(res + i + ResPacketSize * 5))); - pstoreu(res + i + ResPacketSize * 6, pmadd(c6, palpha, ploadu(res + i + ResPacketSize * 6))); - pstoreu(res + i + ResPacketSize * 7, pmadd(c7, palpha, ploadu(res + i + ResPacketSize * 7))); - } + for (; i < n8; i += ResPacketSize * 8) process_rows<8>(i, j2, jend, lhs, rhs, res, palpha, pcj); if (i < n4) { - ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}), c3 = pzero(ResPacket{}); - - for (Index j = j2; j < jend; j += 1) { - RhsPacket b0 = pset1(rhs(j, 0)); - c0 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 0, j), b0, c0); - c1 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 1, j), b0, c1); - c2 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 2, j), b0, c2); - c3 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 3, j), b0, c3); - } - pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu(res + i + ResPacketSize * 0))); - pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu(res + i + ResPacketSize * 1))); - pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu(res + i + ResPacketSize * 2))); - pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu(res + i + ResPacketSize * 3))); - + process_rows<4>(i, j2, jend, lhs, rhs, res, palpha, pcj); i += ResPacketSize * 4; } if (i < n3) { - ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}), c2 = pzero(ResPacket{}); - - for (Index j = j2; j < jend; j += 1) { - RhsPacket b0 = pset1(rhs(j, 0)); - c0 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 0, j), b0, c0); - c1 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 1, j), b0, c1); - c2 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 2, j), b0, c2); - } - pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu(res + i + ResPacketSize * 0))); - pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu(res + i + ResPacketSize * 1))); - pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu(res + i + ResPacketSize * 2))); - + process_rows<3>(i, j2, jend, lhs, rhs, res, palpha, pcj); i += ResPacketSize * 3; } if (i < n2) { - ResPacket c0 = pzero(ResPacket{}), c1 = pzero(ResPacket{}); - - for (Index j = j2; j < jend; j += 1) { - RhsPacket b0 = pset1(rhs(j, 0)); - c0 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 0, j), b0, c0); - c1 = pcj.pmadd(lhs.template load(i + LhsPacketSize * 1, j), b0, c1); - } - pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu(res + i + ResPacketSize * 0))); - pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu(res + i + ResPacketSize * 1))); + process_rows<2>(i, j2, jend, lhs, rhs, res, palpha, pcj); i += ResPacketSize * 2; } if (i < n1) { - ResPacket c0 = pzero(ResPacket{}); - for (Index j = j2; j < jend; j += 1) { - RhsPacket b0 = pset1(rhs(j, 0)); - c0 = pcj.pmadd(lhs.template load(i + 0, j), b0, c0); - } - pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu(res + i + ResPacketSize * 0))); + process_rows<1>(i, j2, jend, lhs, rhs, res, palpha, pcj); i += ResPacketSize; } if (HasHalf && i < n_half) { @@ -287,6 +300,22 @@ struct general_matrix_vector_product + EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE void process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, ResScalar* res, + Index resIncr, ResScalar alpha, + Index halfColBlockEnd, + Index quarterColBlockEnd); }; template ::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, ResScalar alpha) { + // When cols < full packet size, the main vectorized loops are empty. + // Dispatch to a separate noinline function to avoid polluting the icache. + // Only dispatch when cols is large enough that half or quarter packets can be used; + // otherwise the helper would just do scalar work with extra function call overhead. + enum { + LhsPacketSize_ = Traits::LhsPacketSize, + MinUsefulCols_ = + ((int)QuarterTraits::LhsPacketSize < (int)HalfTraits::LhsPacketSize) + ? (int)QuarterTraits::LhsPacketSize + : (((int)HalfTraits::LhsPacketSize < (int)Traits::LhsPacketSize) ? (int)HalfTraits::LhsPacketSize + : (int)Traits::LhsPacketSize) + }; + if (cols >= MinUsefulCols_ && cols < LhsPacketSize_) { + run_small_cols(rows, cols, alhs, rhs, res, resIncr, alpha); + return; + } + // The following copy tells the compiler that lhs's attributes are not modified outside this function // This helps GCC to generate proper code. LhsMapper lhs(alhs); @@ -376,7 +422,7 @@ general_matrix_vector_product +struct gemv_small_cols_unroller { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j, + const Packet& b0, ConjHelper& pcj) { + gemv_small_cols_unroller::template madd(acc, lhs, i, j, b0, pcj); + acc[K] = pcj.pmadd(lhs.template load(i + K, j), b0, acc[K]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j, + const Scalar& b0, ConjHelper& cj) { + gemv_small_cols_unroller::scalar_madd(cc, lhs, i, j, b0, cj); + cc[K] += cj.pmul(lhs(i + K, j), b0); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) { + gemv_small_cols_unroller::predux_accum(cc, acc); + cc[K] += predux(acc[K]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) { + gemv_small_cols_unroller::init_zero(acc); + acc[K] = pzero(Packet{}); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha, + const Scalar* cc) { + gemv_small_cols_unroller::write_result(res, resIncr, i, alpha, cc); + res[(i + K) * resIncr] += alpha * cc[K]; + } +}; + +template +struct gemv_small_cols_unroller<0, N> { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j, + const Packet& b0, ConjHelper& pcj) { + acc[0] = pcj.pmadd(lhs.template load(i, j), b0, acc[0]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j, + const Scalar& b0, ConjHelper& cj) { + cc[0] += cj.pmul(lhs(i, j), b0); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void predux_accum(Scalar* cc, const Packet* acc) { + cc[0] += predux(acc[0]); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void init_zero(Packet* acc) { + acc[0] = pzero(Packet{}); + } + + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void write_result(Scalar* res, Index resIncr, Index i, Scalar alpha, + const Scalar* cc) { + res[i * resIncr] += alpha * cc[0]; + } +}; + +template +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void +general_matrix_vector_product::process_rows_small_cols(Index i, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, ResScalar* res, Index resIncr, + ResScalar alpha, Index halfColBlockEnd, + Index quarterColBlockEnd) { + conj_helper cj; + conj_helper pcj_half; + conj_helper pcj_quarter; + + enum { + LhsAlignment = Unaligned, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSizeHalf = HalfTraits::LhsPacketSize, + LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)Traits::ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; + + using Unroll = gemv_small_cols_unroller; + + ResScalar cc[N] = {}; + if (HasHalf) { + ResPacketHalf h[N]; + Unroll::init_zero(h); + for (Index j = 0; j < halfColBlockEnd; j += LhsPacketSizeHalf) { + RhsPacketHalf b0 = rhs.template load(j, 0); + Unroll::template madd(h, lhs, i, j, b0, pcj_half); + } + Unroll::predux_accum(cc, h); + } + if (HasQuarter) { + ResPacketQuarter q[N]; + Unroll::init_zero(q); + for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) { + RhsPacketQuarter b0 = rhs.template load(j, 0); + Unroll::template madd(q, lhs, i, j, b0, pcj_quarter); + } + Unroll::predux_accum(cc, q); + } + for (Index j = quarterColBlockEnd; j < cols; ++j) { + RhsScalar b0 = rhs(j, 0); + Unroll::scalar_madd(cc, lhs, i, j, b0, cj); + } + Unroll::write_result(res, resIncr, i, alpha, cc); +} + +template +EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void +general_matrix_vector_product::run_small_cols(Index rows, Index cols, const LhsMapper& alhs, + const RhsMapper& rhs, ResScalar* res, Index resIncr, + ResScalar alpha) { + LhsMapper lhs(alhs); + eigen_internal_assert(rhs.stride() == 1); + + enum { + LhsPacketSizeHalf = HalfTraits::LhsPacketSize, + LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize, + }; + + using UnsignedIndex = std::make_unsigned_t; + const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(cols) / LhsPacketSizeHalf); + const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(cols) / LhsPacketSizeQuarter); + + const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? 0 : rows - 7; + const Index n4 = rows - 3; + const Index n2 = rows - 1; + + Index i = 0; + for (; i < n8; i += 8) + process_rows_small_cols<8>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd); + if (i < n4) { + process_rows_small_cols<4>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd); + i += 4; + } + if (i < n2) { + process_rows_small_cols<2>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd); + i += 2; + } + if (i < rows) process_rows_small_cols<1>(i, cols, lhs, rhs, res, resIncr, alpha, halfColBlockEnd, quarterColBlockEnd); +} + } // end namespace internal } // end namespace Eigen