Fix real x complex GEMM for backends where half == full packet size

libeigen/eigen!2150

Closes #3028

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-18 12:32:24 -08:00
parent 073190be04
commit f69745b678
2 changed files with 7 additions and 3 deletions

View File

@@ -397,6 +397,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4cd, 2>&
detail::ptranspose_impl(kernel);
}
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf, Packet16f)
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd, Packet8d)
} // end namespace internal
} // end namespace Eigen

View File

@@ -2582,7 +2582,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Pa
EIGEN_UNUSED_VARIABLE(stride);
EIGEN_UNUSED_VARIABLE(offset);
eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
-eigen_assert(((Pack1 % PacketSize) == 0 && Pack1 <= 4 * PacketSize) || (Pack1 <= 4));
+eigen_assert(((Pack1 % PacketSize) == 0 && Pack1 <= 4 * PacketSize) || (Pack1 <= 4) || (Pack1 < PacketSize));
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
Index count = 0;
@@ -2594,7 +2594,8 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Pa
const Index peeled_mc_half =
Pack1 >= HalfPacketSize ? peeled_mc1 + ((rows - peeled_mc1) / (HalfPacketSize)) * (HalfPacketSize) : 0;
const Index peeled_mc_quarter = Pack1 >= QuarterPacketSize ? (rows / (QuarterPacketSize)) * (QuarterPacketSize) : 0;
-const Index last_lhs_progress = rows > peeled_mc_quarter ? (rows - peeled_mc_quarter) & ~1 : 0;
+const Index last_lhs_progress =
+rows > peeled_mc_quarter ? (Pack2 > 1 ? Pack2 : ((rows - peeled_mc_quarter) & ~1)) : 0;
const Index peeled_mc0 = Pack2 >= PacketSize ? peeled_mc_quarter
: Pack2 > 1 && last_lhs_progress ? (rows / last_lhs_progress) * last_lhs_progress
: 0;
@@ -2810,7 +2811,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Pa
// last peeling loop at this point (for the rhs).
if (Pack2 < PacketSize && !gone_last) {
gone_last = true;
-psize = pack = left & ~1;
+psize = pack = Pack2;
}
}
}