Fix mixed-type compilation error in row-major GEMV small-cols path

libeigen/eigen!2160

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-19 19:54:16 -08:00
parent 4141d1fd2d
commit 4fdc82d695

View File

@@ -512,16 +512,17 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
// of compiler unrolling heuristics.
template <int K, int N>
struct gemv_small_cols_unroller {
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
const Packet& b0, ConjHelper& pcj) {
gemv_small_cols_unroller<K - 1, N>::template madd<Packet, Alignment>(acc, lhs, i, j, b0, pcj);
acc[K] = pcj.pmadd(lhs.template load<Packet, Alignment>(i + K, j), b0, acc[K]);
// Packet multiply-accumulate step for accumulator slot K.
// First delegates to the K - 1 instantiation with the same arguments, then loads an LhsPacket from
// lhs at (i + K, j) and stores pcj.pmadd(loaded, b0, acc[K]) back into acc[K].
// LhsPacket, AccPacket and RhsType are independent template parameters, so the loaded packet, the
// accumulator element and b0 are not forced to be one and the same type.
template <typename LhsPacket, typename AccPacket, int Alignment, typename RhsType, typename ConjHelper,
typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
const RhsType& b0, ConjHelper& pcj) {
// Only LhsPacket, AccPacket and Alignment are passed explicitly; the remaining template arguments
// are deduced from the function arguments.
gemv_small_cols_unroller<K - 1, N>::template madd<LhsPacket, AccPacket, Alignment>(acc, lhs, i, j, b0, pcj);
acc[K] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i + K, j), b0, acc[K]);
}
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
const Scalar& b0, ConjHelper& cj) {
// Scalar multiply-accumulate step for slot K.
// First delegates to the K - 1 instantiation with the same arguments, then adds
// cj.pmul(lhs(i + K, j), b0) to cc[K].
// ResScalar (element type of cc) and RhsScalar (type of b0) are deduced separately, so they may differ.
template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j,
const RhsScalar& b0, ConjHelper& cj) {
gemv_small_cols_unroller<K - 1, N>::scalar_madd(cc, lhs, i, j, b0, cj);
cc[K] += cj.pmul(lhs(i + K, j), b0);
}
@@ -548,15 +549,16 @@ struct gemv_small_cols_unroller {
template <int N>
struct gemv_small_cols_unroller<0, N> {
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
const Packet& b0, ConjHelper& pcj) {
acc[0] = pcj.pmadd(lhs.template load<Packet, Alignment>(i, j), b0, acc[0]);
// K == 0 case of the packet multiply-accumulate step: no further delegation.
// Loads an LhsPacket from lhs at (i, j) and stores pcj.pmadd(loaded, b0, acc[0]) back into acc[0].
// LhsPacket, AccPacket and RhsType are independent template parameters, so the loaded packet, the
// accumulator element and b0 are not forced to be one and the same type.
template <typename LhsPacket, typename AccPacket, int Alignment, typename RhsType, typename ConjHelper,
typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
const RhsType& b0, ConjHelper& pcj) {
acc[0] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i, j), b0, acc[0]);
}
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
const Scalar& b0, ConjHelper& cj) {
// K == 0 case of the scalar multiply-accumulate step: no further delegation.
// Adds cj.pmul(lhs(i, j), b0) to cc[0].
// ResScalar (element type of cc) and RhsScalar (type of b0) are deduced separately, so they may differ.
template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j,
const RhsScalar& b0, ConjHelper& cj) {
cc[0] += cj.pmul(lhs(i, j), b0);
}
@@ -608,7 +610,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
Unroll::init_zero(h);
for (Index j = 0; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
Unroll::template madd<ResPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
Unroll::template madd<LhsPacketHalf, ResPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
}
Unroll::predux_accum(cc, h);
}
@@ -617,7 +619,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
Unroll::init_zero(q);
for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
Unroll::template madd<ResPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
Unroll::template madd<LhsPacketQuarter, ResPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
}
Unroll::predux_accum(cc, q);
}