From e6e5b5c4c83194c83f8eb7e8142602ce489d5a4f Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen <4643818-rmlarsen1@users.noreply.gitlab.com> Date: Mon, 16 Feb 2026 15:30:31 -0800 Subject: [PATCH] Fix pexp_complex for `complex<double>` (issue #3022) libeigen/eigen!2140 Closes #3022 Co-authored-by: Rasmus Munk Larsen --- Eigen/src/Core/arch/AVX/Complex.h | 6 ++++++ Eigen/src/Core/arch/AVX512/Complex.h | 6 ++++++ .../arch/Default/GenericPacketMathFunctions.h | 14 +++++++++++--- Eigen/src/Core/arch/NEON/Complex.h | 6 ++++++ Eigen/src/Core/arch/SSE/Complex.h | 8 +++++++- Eigen/src/Core/arch/clang/Complex.h | 1 - test/packetmath.cpp | 18 +++++++++++++++--- 7 files changed, 51 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index bf19df79a..407300289 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -245,6 +245,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -451,6 +452,11 @@ EIGEN_STRONG_INLINE Packet4cf plog(const Packet4cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet2cd pexp(const Packet2cd& a) { + return pexp_complex(a); +} + template <> EIGEN_STRONG_INLINE Packet4cf pexp(const Packet4cf& a) { return pexp_complex(a); diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index ba15f41db..ab426ac09 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -226,6 +226,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -462,6 +463,11 @@ EIGEN_STRONG_INLINE Packet8cf plog(const Packet8cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet4cd pexp(const Packet4cd& a) { + return pexp_complex(a); +} + template <> EIGEN_STRONG_INLINE Packet8cf pexp(const Packet8cf& a) 
{ return pexp_complex(a); diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index d38233a43..febd29542 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1117,8 +1117,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin)); } else if (Func == TrigFunction::SinCos) { Packet peven = peven_mask(x); - sign_bit = pselect((s), sign_sin, sign_cos); - sFinalRes = pselect(pxor(peven, poly_mask), ssin, scos); + sign_bit = pselect(peven, sign_sin, sign_cos); + sFinalRes = pselect(pxor(peven, poly_mask), scos, ssin); } sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit sFinalRes = pxor(sFinalRes, sign_bit); @@ -1608,7 +1608,6 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Pa template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) { - // FIXME(rmlarsen): This does not work correctly for Packets of std::complex. typedef typename unpacket_traits::as_real RealPacket; typedef typename unpacket_traits::type Scalar; typedef typename Scalar::value_type RealScalar; @@ -1642,6 +1641,15 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Pa // is (+/-inf, NaN), where the signs are undetermined (take the sign of y). RealPacket y_sign = por(pandnot(y, pabs(y)), pset1(RealScalar(1))); cisy = pselect(pand(pcmp_eq(x, cst_pos_inf), pisnan(cisy)), pand(y_sign, even_mask), cisy); + + // If exp(x) is +inf and y is finite, replace cisy with copysign(1, cisy) to + // prevent inf * 0 = NaN. The vectorized sincos may compute exact zero + // for near-zero values like cos(pi/2), and inf * +-1 = +-inf is correct. + // The y=0 case is handled separately below. 
+ RealPacket cisy_sign_one = por(pand(cisy, pset1(RealScalar(-0.0))), pset1(RealScalar(1))); + RealPacket expx_inf_y_finite = pand(pcmp_eq(expx, cst_pos_inf), pcmp_lt(pabs(y), cst_pos_inf)); + cisy = pselect(expx_inf_y_finite, cisy_sign_one, cisy); + Packet result = Packet(pmul(expx, cisy)); // If y is +/- 0, the input is real, so take the real result for consistency. diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h index b8655c80d..9b476e051 100644 --- a/Eigen/src/Core/arch/NEON/Complex.h +++ b/Eigen/src/Core/arch/NEON/Complex.h @@ -516,6 +516,7 @@ struct packet_traits> : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -723,6 +724,11 @@ EIGEN_STRONG_INLINE Packet1cd plog(const Packet1cd& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet1cd pexp(const Packet1cd& a) { + return pexp_complex(a); +} + #endif // EIGEN_ARCH_ARM64 } // end namespace internal diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h index 4002c1612..484ec4c71 100644 --- a/Eigen/src/Core/arch/SSE/Complex.h +++ b/Eigen/src/Core/arch/SSE/Complex.h @@ -244,7 +244,8 @@ struct packet_traits > : default_packet_traits { HasAbs2 = 0, HasMin = 0, HasMax = 0, - HasSetLinear = 0 + HasSetLinear = 0, + HasExp = 1 }; }; #endif @@ -432,6 +433,11 @@ EIGEN_STRONG_INLINE Packet2cf plog(const Packet2cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet1cd pexp(const Packet1cd& a) { + return pexp_complex(a); +} + template <> EIGEN_STRONG_INLINE Packet2cf pexp(const Packet2cf& a) { return pexp_complex(a); diff --git a/Eigen/src/Core/arch/clang/Complex.h b/Eigen/src/Core/arch/clang/Complex.h index d6cc435a6..2a62f0b85 100644 --- a/Eigen/src/Core/arch/clang/Complex.h +++ b/Eigen/src/Core/arch/clang/Complex.h @@ -81,7 +81,6 @@ struct packet_traits> : generic_complex_packet_traits { using half = Packet4cd; enum { size = 4, - HasExp = 0, // 
FIXME(rmlarsen): pexp_complex is broken for double. }; }; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index e43c110b7..12ba95d88 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -1485,10 +1485,10 @@ struct exp_complex_test_impl { } // Verify equality with signed zero. - static bool is_exactly_equal(const Scalar& a, const Scalar& b) { + static bool is_exactly_equal(const Scalar& a, const Scalar& b, bool quiet = false) { bool result = is_exactly_equal(numext::real_ref(a), numext::real_ref(b)) && is_exactly_equal(numext::imag_ref(a), numext::imag_ref(b)); - if (!result) { + if (!result && !quiet) { std::cout << a << " != " << b << std::endl; } return result; @@ -1512,6 +1512,13 @@ struct exp_complex_test_impl { if (numext::real_ref(z) == +inf && (numext::isnan)(numext::imag_ref(z))) { return true; } + // If exp(x) overflows to inf and y is finite nonzero, the result involves inf * cos(y) and + // inf * sin(y). When cos(y) or sin(y) is near a zero crossing (e.g., cos(pi/2)), different + // trig implementations may produce different signs, so the signs of the result are unspecified. + if (!(numext::isinf)(numext::imag_ref(z)) && !(numext::isnan)(numext::imag_ref(z)) && numext::imag_ref(z) != 0 && + (numext::isinf)(std::exp(numext::real_ref(z)))) { + return true; + } return false; } @@ -1558,7 +1565,12 @@ struct exp_complex_test_impl { Scalar(numext::abs(numext::real_ref(expected)), numext::abs(numext::imag_ref(expected))); VERIFY(is_exactly_equal(abs_w, abs_expected)); } else { - VERIFY(is_exactly_equal(w, numext::exp(z))); + Scalar expected = numext::exp(z); + // First try exact equality (handles NaN, signed zeros correctly). + // Fall back to approximate comparison to allow for small differences + // in trig functions near zero crossings (e.g., vectorized sincos may + // compute cos(pi/2) = 0 while scalar std::exp gives ~6.12e-17). + VERIFY(is_exactly_equal(w, expected, /*quiet=*/true) || verifyIsApprox(w, expected)); } } }