Fix GEBP asm register constraints for custom scalar types

libeigen/eigen!2258

Closes #3059

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-03-07 07:49:27 -08:00
parent 20fce70e5a
commit 3041ab44af
2 changed files with 215 additions and 51 deletions

View File

@@ -1151,59 +1151,52 @@ struct gebp_micro_step {
gebp_rhs_cols<0, MrPackets, NrCols>::run(traits, blB, Index(NrCols * K), A, rhs_panel, T0, C);
}
};
// Compiler workaround macros used inside gebp_peeled_loop.
#if EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0)
#define EIGEN_GEBP_ARM64_3P_WORKAROUND(MrPackets, A, LhsArray, FullLhsPacket) \
EIGEN_IF_CONSTEXPR( \
(MrPackets == 3 && \
std::is_same<std::remove_all_extents_t<std::remove_reference_t<LhsArray>>, FullLhsPacket>::value)) { \
__asm__("" : "+w,m"(A[0]), "+w,m"(A[1]), "+w,m"(A[2])); \
}
#else
#define EIGEN_GEBP_ARM64_3P_WORKAROUND(MrPackets, A, LhsArray, FullLhsPacket)
#endif
// Compiler register allocation workarounds for the GEBP micro-kernel.
// GCC can fail to keep array-based SIMD values in vector registers, causing
// excessive spilling. These helpers use inline asm constraints to pin values.
// Only applied when the scalar type is actually vectorizable (not custom types).
// See Eigen bugs 935, 1637, and 3059.
// GCC's register allocator can fail to keep array-based accumulators in XMM
// registers, causing excessive spilling to the stack. Pin accumulators using
// inline asm "+x" constraints. The ACC pinning only works for plain SSE
// vector types (not DoublePacket used by complex), so it requires if constexpr
// (C++17) to safely discard the dead branch for complex types.
// See Eigen bugs 935 and 1637.
// Register-allocation hint for the 3-packet LHS path on ARM64 NEON.
//
// Only active when compiling for ARM64 with NEON vectorization on GCC older
// than 9; on every other configuration the function merely marks `A` as used.
// When active, and only if the kernel traits are vectorizable, MrPackets is 3
// and the elements of `A` are exactly FullLhsPacket_, an empty asm statement
// lists A[0..2] as read-write operands. The "+w,m" constraint offers two
// alternatives per operand (register or memory), so the statement emits no
// instructions and stays valid when an operand cannot live in a register.
template <int MrPackets, typename GEBPTraits_, typename FullLhsPacket_, typename LhsArray_>
EIGEN_ALWAYS_INLINE void gebp_neon_3p_workaround(LhsArray_& A) {
#if !(EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0))
  // Nothing to do on this toolchain/target; silence unused-parameter warnings.
  EIGEN_UNUSED_VARIABLE(A);
#else
  using PacketInArray = std::remove_all_extents_t<std::remove_reference_t<LhsArray_>>;
  constexpr bool holds_full_packets = std::is_same<PacketInArray, FullLhsPacket_>::value;
  constexpr bool pin_lhs = GEBPTraits_::Vectorizable && (MrPackets == 3) && holds_full_packets;
  EIGEN_IF_CONSTEXPR(pin_lhs) {
    __asm__("" : "+w,m"(A[0]), "+w,m"(A[1]), "+w,m"(A[2]));
  }
#endif
}
// GCC SSE: prevent register spilling for LHS packets and accumulators.
// C++17: pin accumulators with strict "+x" (if constexpr discards dead branches).
// C++14: pin LHS packets with relaxed "+x,m" (memory fallback for non-SSE types).
template <int MrPackets, int NrCols, typename GEBPTraits_, typename FullLhsPacket_, typename LhsArray_,
typename AccArray_>
EIGEN_ALWAYS_INLINE void gebp_sse_spilling_workaround(LhsArray_& A, AccArray_& ACC) {
EIGEN_UNUSED_VARIABLE(A);
EIGEN_UNUSED_VARIABLE(ACC);
#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE)
using LhsElement = std::remove_all_extents_t<std::remove_reference_t<LhsArray_>>;
constexpr bool apply =
GEBPTraits_::Vectorizable && MrPackets <= 2 && NrCols >= 4 && std::is_same<LhsElement, FullLhsPacket_>::value;
EIGEN_IF_CONSTEXPR(apply) {
#ifdef EIGEN_HAS_CXX17_IFCONSTEXPR
// C++17: pin accumulators when they're plain SSE vectors (sizeof matches FullLhsPacket).
// For complex types, AccPacket is a struct (DoublePacket) and the asm is safely discarded.
#define EIGEN_GEBP_SSE_SPILLING_WORKAROUND(MrPackets, NrCols, A, ACC, LhsArray, FullLhsPacket) \
EIGEN_IF_CONSTEXPR( \
(MrPackets <= 2 && NrCols >= 4 && \
std::is_same<std::remove_all_extents_t<std::remove_reference_t<LhsArray>>, FullLhsPacket>::value && \
sizeof(ACC[0]) == sizeof(FullLhsPacket))) { \
EIGEN_IF_CONSTEXPR(MrPackets == 2 && NrCols == 4) { \
__asm__("" \
: "+x"(ACC[0]), "+x"(ACC[1]), "+x"(ACC[2]), "+x"(ACC[3]), "+x"(ACC[4]), "+x"(ACC[5]), "+x"(ACC[6]), \
"+x"(ACC[7])); \
} \
EIGEN_GEBP_SSE_1P_WORKAROUND(MrPackets, NrCols, A, ACC) \
}
using AccElement = std::decay_t<decltype(ACC[0])>;
constexpr bool pin_acc = std::is_same<AccElement, FullLhsPacket_>::value && MrPackets == 2 && NrCols == 4;
if constexpr (pin_acc) {
__asm__(""
: "+x"(ACC[0]), "+x"(ACC[1]), "+x"(ACC[2]), "+x"(ACC[3]), "+x"(ACC[4]), "+x"(ACC[5]), "+x"(ACC[6]),
"+x"(ACC[7]));
}
#else
// C++14: only pin LHS packets (A), not accumulators, to avoid asm errors with complex types.
#define EIGEN_GEBP_SSE_SPILLING_WORKAROUND(MrPackets, NrCols, A, ACC, LhsArray, FullLhsPacket) \
EIGEN_IF_CONSTEXPR( \
(MrPackets <= 2 && NrCols >= 4 && \
std::is_same<std::remove_all_extents_t<std::remove_reference_t<LhsArray>>, FullLhsPacket>::value)) { \
EIGEN_IF_CONSTEXPR(MrPackets == 2) { __asm__("" : "+x,m"(A[0]), "+x,m"(A[1])); } \
EIGEN_GEBP_SSE_1P_WORKAROUND(MrPackets, NrCols, A, ACC) \
EIGEN_IF_CONSTEXPR(MrPackets == 2) { __asm__("" : "+x,m"(A[0]), "+x,m"(A[1])); }
#endif
}
#endif
#if !(EIGEN_COMP_LCC)
#define EIGEN_GEBP_SSE_1P_WORKAROUND(MrPackets, NrCols, A, ACC) \
EIGEN_IF_CONSTEXPR(MrPackets == 1 && NrCols == 1) { __asm__("" : "+x,m"(A[0])); }
#else
#define EIGEN_GEBP_SSE_1P_WORKAROUND(MrPackets, NrCols, A, ACC)
#endif
#else
#define EIGEN_GEBP_SSE_SPILLING_WORKAROUND(MrPackets, NrCols, A, ACC, LhsArray, FullLhsPacket)
#endif
}
// Unrolled peeled loop body: calls gebp_micro_step for K=0..7, handling
// double-accumulation for 1pX4, prefetches, and compiler workarounds.
@@ -1222,10 +1215,8 @@ struct gebp_peeled_loop {
#define EIGEN_GEBP_DO_STEP(KVAL, ACC) \
do { \
gebp_micro_step<KVAL, MrPackets, NrCols>::run(traits, blA, blB, A, rhs_panel, T0, ACC); \
/* ARM64 NEON register alloc workaround for 3-packet paths */ \
EIGEN_GEBP_ARM64_3P_WORKAROUND(MrPackets, A, LhsArray, FullLhsPacket) \
/* GCC SSE spilling workaround: pin LHS packets and accumulators in registers */ \
EIGEN_GEBP_SSE_SPILLING_WORKAROUND(MrPackets, NrCols, A, ACC, LhsArray, FullLhsPacket) \
gebp_neon_3p_workaround<MrPackets, GEBPTraits, FullLhsPacket>(A); \
gebp_sse_spilling_workaround<MrPackets, NrCols, GEBPTraits, FullLhsPacket>(A, ACC); \
/* LHS prefetch for 2pX4 and 3pX4 */ \
EIGEN_IF_CONSTEXPR((MrPackets == 2 || MrPackets == 3) && NrCols == 4) { \
internal::prefetch(blA + (MrPackets * KVAL + 16) * GEBPTraits::LhsProgress); \

View File

@@ -347,6 +347,178 @@ void bug_1308() {
VERIFY_IS_APPROX(r44.noalias() += Vector4d::Ones() * m44.col(0).transpose(), ones44);
}
// Regression test for issue #3059: GEBP asm register constraints fail
// for custom (non-vectorizable) scalar types. Type T holds an extra member
// with a user-provided destructor, so T is not trivially destructible and
// sizeof(T) > sizeof(double), while type U is a simple wrapper around a
// double. Both must compile and produce correct products.
namespace issue_3059 {

// Tiny helper with a user-provided destructor and one pointer member.
// Embedding it in T makes T larger than a double and not trivially destructible.
class Ptr {
 public:
  ~Ptr() {}
  double* m_ptr = nullptr;
};

// Custom scalar: a double plus a Ptr member. Supports the arithmetic and
// comparison operators below, all of which act on the wrapped double only.
class T {
 public:
  T() = default;
  T(double v) : m_value(v) {}

  // Read access to the wrapped value.
  double value() const { return m_value; }

  // Compound assignment.
  T& operator+=(const T& rhs) {
    m_value += rhs.m_value;
    return *this;
  }
  T& operator-=(const T& rhs) {
    m_value -= rhs.m_value;
    return *this;
  }
  T& operator*=(const T& rhs) {
    m_value *= rhs.m_value;
    return *this;
  }
  T& operator/=(const T& rhs) {
    m_value /= rhs.m_value;
    return *this;
  }

  // Binary and unary arithmetic; each returns a freshly constructed T.
  friend T operator+(const T& lhs, const T& rhs) { return T(lhs.value() + rhs.value()); }
  friend T operator-(const T& lhs, const T& rhs) { return T(lhs.value() - rhs.value()); }
  friend T operator*(const T& lhs, const T& rhs) { return T(lhs.value() * rhs.value()); }
  friend T operator/(const T& lhs, const T& rhs) { return T(lhs.value() / rhs.value()); }
  friend T operator-(const T& operand) { return T(-operand.value()); }

  // Comparisons forward directly to the built-in double comparisons.
  bool operator==(const T& rhs) const { return value() == rhs.value(); }
  bool operator!=(const T& rhs) const { return value() != rhs.value(); }
  bool operator<(const T& rhs) const { return value() < rhs.value(); }
  bool operator<=(const T& rhs) const { return value() <= rhs.value(); }
  bool operator>(const T& rhs) const { return value() > rhs.value(); }
  bool operator>=(const T& rhs) const { return value() >= rhs.value(); }

 private:
  double m_value = 0.0;
  Ptr m_ptr;  // Makes sizeof(T) > sizeof(double)
};

// Math functions for T, forwarding to the double implementations.
T sqrt(const T& x) { return T(std::sqrt(x.value())); }
T abs(const T& x) { return T(std::abs(x.value())); }
T abs2(const T& x) {
  const double v = x.value();
  return T(v * v);
}

// Custom scalar: a bare wrapper around a double with the same operator set as T
// but no additional members.
class U {
 public:
  U() = default;
  U(double v) : m_value(v) {}

  // Read access to the wrapped value.
  double value() const { return m_value; }

  // Compound assignment.
  U& operator+=(const U& rhs) {
    m_value += rhs.m_value;
    return *this;
  }
  U& operator-=(const U& rhs) {
    m_value -= rhs.m_value;
    return *this;
  }
  U& operator*=(const U& rhs) {
    m_value *= rhs.m_value;
    return *this;
  }
  U& operator/=(const U& rhs) {
    m_value /= rhs.m_value;
    return *this;
  }

  // Binary and unary arithmetic; each returns a freshly constructed U.
  friend U operator+(const U& lhs, const U& rhs) { return U(lhs.value() + rhs.value()); }
  friend U operator-(const U& lhs, const U& rhs) { return U(lhs.value() - rhs.value()); }
  friend U operator*(const U& lhs, const U& rhs) { return U(lhs.value() * rhs.value()); }
  friend U operator/(const U& lhs, const U& rhs) { return U(lhs.value() / rhs.value()); }
  friend U operator-(const U& operand) { return U(-operand.value()); }

  // Comparisons forward directly to the built-in double comparisons.
  bool operator==(const U& rhs) const { return value() == rhs.value(); }
  bool operator!=(const U& rhs) const { return value() != rhs.value(); }
  bool operator<(const U& rhs) const { return value() < rhs.value(); }
  bool operator<=(const U& rhs) const { return value() <= rhs.value(); }
  bool operator>(const U& rhs) const { return value() > rhs.value(); }
  bool operator>=(const U& rhs) const { return value() >= rhs.value(); }

 private:
  double m_value = 0.0;
};

// Math functions for U, forwarding to the double implementations.
U sqrt(const U& x) { return U(std::sqrt(x.value())); }
U abs(const U& x) { return U(std::abs(x.value())); }
U abs2(const U& x) {
  const double v = x.value();
  return U(v * v);
}

}  // namespace issue_3059
namespace Eigen {

// Scalar traits for issue_3059::T. Everything not overridden here is inherited
// from NumTraits<double>. The associated scalar types all map back to T itself,
// the type is declared non-complex, and RequireInitialization is set to 1.
// NOTE(review): presumably 1 because T embeds a member with a user-provided
// destructor - confirm against Eigen's NumTraits documentation.
template <>
struct NumTraits<issue_3059::T> : NumTraits<double> {
  typedef issue_3059::T Real;
  typedef issue_3059::T NonInteger;
  typedef issue_3059::T Nested;
  enum {
    IsComplex = 0,
    RequireInitialization = 1
  };
};

// Scalar traits for issue_3059::U. Same overrides as for T, except that
// RequireInitialization is set to 0.
template <>
struct NumTraits<issue_3059::U> : NumTraits<double> {
  typedef issue_3059::U Real;
  typedef issue_3059::U NonInteger;
  typedef issue_3059::U Nested;
  enum {
    IsComplex = 0,
    RequireInitialization = 0
  };
};

}  // namespace Eigen
template <int>
void product_custom_scalar_types() {
using namespace issue_3059;
// Type T: has non-trivial destructor, sizeof(T) > sizeof(double)
{
Matrix<T, Dynamic, Dynamic> A(4, 4), B(4, 4), C(4, 4);
for (int i = 0; i < 4; ++i)
for (int j = 0; j < 4; ++j) {
A(i, j) = T(static_cast<double>(i + 1));
B(i, j) = T(static_cast<double>(j + 1));
}
C.noalias() = A * B;
// A*B: C(i,j) = sum_k (i+1)*(k+1) * ... no, A(i,k)=(i+1), B(k,j)=(j+1)
// so C(i,j) = sum_k (i+1)*(j+1) = 4*(i+1)*(j+1)
for (int i = 0; i < 4; ++i)
for (int j = 0; j < 4; ++j) VERIFY(C(i, j) == T(4.0 * (i + 1) * (j + 1)));
}
// Type U: simple wrapper, sizeof(U) == sizeof(double)
{
Matrix<U, Dynamic, Dynamic> A(4, 4), B(4, 4), C(4, 4);
for (int i = 0; i < 4; ++i)
for (int j = 0; j < 4; ++j) {
A(i, j) = U(static_cast<double>(i + 1));
B(i, j) = U(static_cast<double>(j + 1));
}
C.noalias() = A * B;
for (int i = 0; i < 4; ++i)
for (int j = 0; j < 4; ++j) VERIFY(C(i, j) == U(4.0 * (i + 1) * (j + 1)));
}
// Larger matrices to exercise GEBP blocking.
{
const int n = 33;
Matrix<U, Dynamic, Dynamic> A(n, n), B(n, n), C(n, n);
for (int i = 0; i < n; ++i)
for (int j = 0; j < n; ++j) {
A(i, j) = U(static_cast<double>((i * 7 + j * 3) % 13));
B(i, j) = U(static_cast<double>((i * 5 + j * 11) % 17));
}
C.noalias() = A * B;
// Verify against explicit triple loop.
for (int i = 0; i < n; ++i)
for (int j = 0; j < n; ++j) {
double sum = 0;
for (int k = 0; k < n; ++k) sum += A(i, k).value() * B(k, j).value();
VERIFY(C(i, j) == U(sum));
}
}
}
EIGEN_DECLARE_TEST(product_extra) {
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1(product_extra(
@@ -369,4 +541,5 @@ EIGEN_DECLARE_TEST(product_extra) {
CALL_SUBTEST_7(compute_block_size<double>());
CALL_SUBTEST_7(compute_block_size<std::complex<double> >());
CALL_SUBTEST_8(aliasing_with_resize<void>());
CALL_SUBTEST_9(product_custom_scalar_types<0>());
}