Fix pexp_complex for complex<double> (issue #3022)

libeigen/eigen!2140

Closes #3022

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-16 15:30:31 -08:00
parent 2b561f9284
commit e6e5b5c4c8
7 changed files with 51 additions and 8 deletions

View File

@@ -245,6 +245,7 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
HasNegate = 1,
HasSqrt = 1,
HasLog = 1,
HasExp = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -451,6 +452,11 @@ EIGEN_STRONG_INLINE Packet4cf plog<Packet4cf>(const Packet4cf& a) {
return plog_complex<Packet4cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cd pexp<Packet2cd>(const Packet2cd& a) {
return pexp_complex<Packet2cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cf pexp<Packet4cf>(const Packet4cf& a) {
return pexp_complex<Packet4cf>(a);

View File

@@ -226,6 +226,7 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
HasNegate = 1,
HasSqrt = 1,
HasLog = 1,
HasExp = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -462,6 +463,11 @@ EIGEN_STRONG_INLINE Packet8cf plog<Packet8cf>(const Packet8cf& a) {
return plog_complex<Packet8cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cd pexp<Packet4cd>(const Packet4cd& a) {
return pexp_complex<Packet4cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet8cf pexp<Packet8cf>(const Packet8cf& a) {
return pexp_complex<Packet8cf>(a);

View File

@@ -1117,8 +1117,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin));
} else if (Func == TrigFunction::SinCos) {
Packet peven = peven_mask(x);
sign_bit = pselect((s), sign_sin, sign_cos);
sFinalRes = pselect(pxor(peven, poly_mask), ssin, scos);
sign_bit = pselect(peven, sign_sin, sign_cos);
sFinalRes = pselect(pxor(peven, poly_mask), scos, ssin);
}
sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
sFinalRes = pxor(sFinalRes, sign_bit);
@@ -1608,7 +1608,6 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Pa
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) {
// FIXME(rmlarsen): This does not work correctly for Packets of std::complex<double>.
typedef typename unpacket_traits<Packet>::as_real RealPacket;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
@@ -1642,6 +1641,15 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Pa
// is (+/-inf, NaN), where the signs are undetermined (take the sign of y).
RealPacket y_sign = por(pandnot(y, pabs(y)), pset1<RealPacket>(RealScalar(1)));
cisy = pselect(pand(pcmp_eq(x, cst_pos_inf), pisnan(cisy)), pand(y_sign, even_mask), cisy);
// If exp(x) is +inf and y is finite, replace cisy with copysign(1, cisy) to
// prevent inf * 0 = NaN. The vectorized sincos may compute exact zero
// for near-zero values like cos(pi/2), and inf * +-1 = +-inf is correct.
// The y=0 case is handled separately below.
RealPacket cisy_sign_one = por(pand(cisy, pset1<RealPacket>(RealScalar(-0.0))), pset1<RealPacket>(RealScalar(1)));
RealPacket expx_inf_y_finite = pand(pcmp_eq(expx, cst_pos_inf), pcmp_lt(pabs(y), cst_pos_inf));
cisy = pselect(expx_inf_y_finite, cisy_sign_one, cisy);
Packet result = Packet(pmul(expx, cisy));
// If y is +/- 0, the input is real, so take the real result for consistency.

View File

@@ -516,6 +516,7 @@ struct packet_traits<std::complex<double>> : default_packet_traits {
HasNegate = 1,
HasSqrt = 1,
HasLog = 1,
HasExp = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -723,6 +724,11 @@ EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
return pexp_complex<Packet1cd>(a);
}
#endif // EIGEN_ARCH_ARM64
} // end namespace internal

View File

@@ -244,7 +244,8 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0
HasSetLinear = 0,
HasExp = 1
};
};
#endif
@@ -432,6 +433,11 @@ EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
return pexp_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex<Packet2cf>(a);

View File

@@ -81,7 +81,6 @@ struct packet_traits<std::complex<double>> : generic_complex_packet_traits {
using half = Packet4cd;
enum {
size = 4,
HasExp = 0, // FIXME(rmlarsen): pexp_complex is broken for double.
};
};

View File

@@ -1485,10 +1485,10 @@ struct exp_complex_test_impl {
}
// Verify equality with signed zero.
static bool is_exactly_equal(const Scalar& a, const Scalar& b) {
static bool is_exactly_equal(const Scalar& a, const Scalar& b, bool quiet = false) {
bool result = is_exactly_equal(numext::real_ref(a), numext::real_ref(b)) &&
is_exactly_equal(numext::imag_ref(a), numext::imag_ref(b));
if (!result) {
if (!result && !quiet) {
std::cout << a << " != " << b << std::endl;
}
return result;
@@ -1512,6 +1512,13 @@ struct exp_complex_test_impl {
if (numext::real_ref(z) == +inf && (numext::isnan)(numext::imag_ref(z))) {
return true;
}
// If exp(x) overflows to inf and y is finite nonzero, the result involves inf * cos(y) and
// inf * sin(y). When cos(y) or sin(y) is near a zero crossing (e.g., cos(pi/2)), different
// trig implementations may produce different signs, so the signs of the result are unspecified.
if (!(numext::isinf)(numext::imag_ref(z)) && !(numext::isnan)(numext::imag_ref(z)) && numext::imag_ref(z) != 0 &&
(numext::isinf)(std::exp(numext::real_ref(z)))) {
return true;
}
return false;
}
@@ -1558,7 +1565,12 @@ struct exp_complex_test_impl {
Scalar(numext::abs(numext::real_ref(expected)), numext::abs(numext::imag_ref(expected)));
VERIFY(is_exactly_equal(abs_w, abs_expected));
} else {
VERIFY(is_exactly_equal(w, numext::exp(z)));
Scalar expected = numext::exp(z);
// First try exact equality (handles NaN, signed zeros correctly).
// Fall back to approximate comparison to allow for small differences
// in trig functions near zero crossings (e.g., vectorized sincos may
// compute cos(pi/2) = 0 while scalar std::exp gives ~6.12e-17).
VERIFY(is_exactly_equal(w, expected, /*quiet=*/true) || verifyIsApprox(w, expected));
}
}
}