Fix GEBP half/quarter-packet loops for nr>=8 RHS packing on ARM64

On ARM64 (and LoongArch64), the GEBP kernel uses nr=8, so the RHS is
packed in 8-column blocks. The half-packet and quarter-packet row
processing loops were iterating columns 4 at a time starting from j2=0,
misindexing into the 8-column packed RHS buffer. This produced
completely wrong results for float GEMM when the number of rows was
smaller than the SIMD packet size (e.g. 2x10 * 10x8 float).

Add the missing nr>=8 column iteration blocks to both loops, matching
the pattern already present in the 3x, 2x, 1x, and scalar remainder
sections.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-25 19:03:11 -08:00
parent e567151ce3
commit 888d708dcd

View File

@@ -1496,7 +1496,16 @@ EIGEN_DONT_INLINE void gebp_kernel<LhsScalar, RhsScalar, Index, DataMapper, mr,
EIGEN_IF_CONSTEXPR((LhsProgressHalf < LhsProgress) && mr >= LhsProgressHalf) {
HalfTraits half_traits;
for (Index i = peeled_mc1; i < peeled_mc_half; i += LhsProgressHalf) {
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
EIGEN_IF_CONSTEXPR(nr >= 8) {
for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
gebp_micro_panel_impl<1, 8, HalfTraits, LhsScalar, RhsScalar, ResScalar, Index, DataMapper, LinearMapper,
LhsPacket>(half_traits, res, blockA, blockB, alpha, i, j2, depth, strideA, strideB,
offsetA, offsetB, prefetch_res_offset, peeled_kc, pk);
}
}
#endif
for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
gebp_micro_panel_impl<1, 4, HalfTraits, LhsScalar, RhsScalar, ResScalar, Index, DataMapper, LinearMapper,
LhsPacket>(half_traits, res, blockA, blockB, alpha, i, j2, depth, strideA, strideB,
offsetA, offsetB, prefetch_res_offset, peeled_kc, pk);
@@ -1513,7 +1522,16 @@ EIGEN_DONT_INLINE void gebp_kernel<LhsScalar, RhsScalar, Index, DataMapper, mr,
EIGEN_IF_CONSTEXPR((LhsProgressQuarter < LhsProgressHalf) && mr >= LhsProgressQuarter) {
QuarterTraits quarter_traits;
for (Index i = peeled_mc_half; i < peeled_mc_quarter; i += LhsProgressQuarter) {
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
EIGEN_IF_CONSTEXPR(nr >= 8) {
for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
gebp_micro_panel_impl<1, 8, QuarterTraits, LhsScalar, RhsScalar, ResScalar, Index, DataMapper, LinearMapper,
LhsPacket>(quarter_traits, res, blockA, blockB, alpha, i, j2, depth, strideA, strideB,
offsetA, offsetB, prefetch_res_offset, peeled_kc, pk);
}
}
#endif
for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
gebp_micro_panel_impl<1, 4, QuarterTraits, LhsScalar, RhsScalar, ResScalar, Index, DataMapper, LinearMapper,
LhsPacket>(quarter_traits, res, blockA, blockB, alpha, i, j2, depth, strideA, strideB,
offsetA, offsetB, prefetch_res_offset, peeled_kc, pk);