mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix pexp_complex for complex<double> (issue #3022)
libeigen/eigen!2140 Closes #3022 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -245,6 +245,7 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
|
||||
HasNegate = 1,
|
||||
HasSqrt = 1,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasAbs = 0,
|
||||
HasAbs2 = 0,
|
||||
HasMin = 0,
|
||||
@@ -451,6 +452,11 @@ EIGEN_STRONG_INLINE Packet4cf plog<Packet4cf>(const Packet4cf& a) {
|
||||
return plog_complex<Packet4cf>(a);
|
||||
}
|
||||
|
||||
// pexp specialization for Packet2cd: forwards the packet unchanged to the
// generic pexp_complex implementation instantiated for this packet type.
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cd pexp<Packet2cd>(const Packet2cd& a) {
|
||||
return pexp_complex<Packet2cd>(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4cf pexp<Packet4cf>(const Packet4cf& a) {
|
||||
return pexp_complex<Packet4cf>(a);
|
||||
|
||||
@@ -226,6 +226,7 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
|
||||
HasNegate = 1,
|
||||
HasSqrt = 1,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasAbs = 0,
|
||||
HasAbs2 = 0,
|
||||
HasMin = 0,
|
||||
@@ -462,6 +463,11 @@ EIGEN_STRONG_INLINE Packet8cf plog<Packet8cf>(const Packet8cf& a) {
|
||||
return plog_complex<Packet8cf>(a);
|
||||
}
|
||||
|
||||
// pexp specialization for Packet4cd: forwards the packet unchanged to the
// generic pexp_complex implementation instantiated for this packet type.
template <>
|
||||
EIGEN_STRONG_INLINE Packet4cd pexp<Packet4cd>(const Packet4cd& a) {
|
||||
return pexp_complex<Packet4cd>(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8cf pexp<Packet8cf>(const Packet8cf& a) {
|
||||
return pexp_complex<Packet8cf>(a);
|
||||
|
||||
@@ -1117,8 +1117,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin));
|
||||
} else if (Func == TrigFunction::SinCos) {
|
||||
Packet peven = peven_mask(x);
|
||||
sign_bit = pselect((s), sign_sin, sign_cos);
|
||||
sFinalRes = pselect(pxor(peven, poly_mask), ssin, scos);
|
||||
sign_bit = pselect(peven, sign_sin, sign_cos);
|
||||
sFinalRes = pselect(pxor(peven, poly_mask), scos, ssin);
|
||||
}
|
||||
sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
|
||||
sFinalRes = pxor(sFinalRes, sign_bit);
|
||||
@@ -1608,7 +1608,6 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Pa
|
||||
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) {
|
||||
// FIXME(rmlarsen): This does not work correctly for Packets of std::complex<double>.
|
||||
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename Scalar::value_type RealScalar;
|
||||
@@ -1642,6 +1641,15 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Pa
|
||||
// is (+/-inf, NaN), where the signs are undetermined (take the sign of y).
|
||||
RealPacket y_sign = por(pandnot(y, pabs(y)), pset1<RealPacket>(RealScalar(1)));
|
||||
cisy = pselect(pand(pcmp_eq(x, cst_pos_inf), pisnan(cisy)), pand(y_sign, even_mask), cisy);
|
||||
|
||||
// If exp(x) is +inf and y is finite, replace cisy with copysign(1, cisy) to
|
||||
// prevent inf * 0 = NaN. The vectorized sincos may compute exact zero
|
||||
// for near-zero values like cos(pi/2), and inf * +-1 = +-inf is correct.
|
||||
// The y=0 case is handled separately below.
|
||||
RealPacket cisy_sign_one = por(pand(cisy, pset1<RealPacket>(RealScalar(-0.0))), pset1<RealPacket>(RealScalar(1)));
|
||||
RealPacket expx_inf_y_finite = pand(pcmp_eq(expx, cst_pos_inf), pcmp_lt(pabs(y), cst_pos_inf));
|
||||
cisy = pselect(expx_inf_y_finite, cisy_sign_one, cisy);
|
||||
|
||||
Packet result = Packet(pmul(expx, cisy));
|
||||
|
||||
// If y is +/- 0, the input is real, so take the real result for consistency.
|
||||
|
||||
@@ -516,6 +516,7 @@ struct packet_traits<std::complex<double>> : default_packet_traits {
|
||||
HasNegate = 1,
|
||||
HasSqrt = 1,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasAbs = 0,
|
||||
HasAbs2 = 0,
|
||||
HasMin = 0,
|
||||
@@ -723,6 +724,11 @@ EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
|
||||
return plog_complex(a);
|
||||
}
|
||||
|
||||
// pexp specialization for Packet1cd: forwards the packet unchanged to the
// generic pexp_complex implementation instantiated for this packet type.
template <>
|
||||
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
|
||||
return pexp_complex<Packet1cd>(a);
|
||||
}
|
||||
|
||||
#endif // EIGEN_ARCH_ARM64
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
@@ -244,7 +244,8 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
|
||||
HasAbs2 = 0,
|
||||
HasMin = 0,
|
||||
HasMax = 0,
|
||||
HasSetLinear = 0
|
||||
HasSetLinear = 0,
|
||||
HasExp = 1
|
||||
};
|
||||
};
|
||||
#endif
|
||||
@@ -432,6 +433,11 @@ EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
|
||||
return plog_complex<Packet2cf>(a);
|
||||
}
|
||||
|
||||
// pexp specialization for Packet1cd: forwards the packet unchanged to the
// generic pexp_complex implementation instantiated for this packet type.
template <>
|
||||
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
|
||||
return pexp_complex<Packet1cd>(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
|
||||
return pexp_complex<Packet2cf>(a);
|
||||
|
||||
@@ -81,7 +81,6 @@ struct packet_traits<std::complex<double>> : generic_complex_packet_traits {
|
||||
using half = Packet4cd;
|
||||
enum {
|
||||
size = 4,
|
||||
HasExp = 0, // FIXME(rmlarsen): pexp_complex is broken for double.
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -1485,10 +1485,10 @@ struct exp_complex_test_impl {
|
||||
}
|
||||
|
||||
// Verify equality with signed zero.
|
||||
static bool is_exactly_equal(const Scalar& a, const Scalar& b) {
|
||||
static bool is_exactly_equal(const Scalar& a, const Scalar& b, bool quiet = false) {
|
||||
bool result = is_exactly_equal(numext::real_ref(a), numext::real_ref(b)) &&
|
||||
is_exactly_equal(numext::imag_ref(a), numext::imag_ref(b));
|
||||
if (!result) {
|
||||
if (!result && !quiet) {
|
||||
std::cout << a << " != " << b << std::endl;
|
||||
}
|
||||
return result;
|
||||
@@ -1512,6 +1512,13 @@ struct exp_complex_test_impl {
|
||||
if (numext::real_ref(z) == +inf && (numext::isnan)(numext::imag_ref(z))) {
|
||||
return true;
|
||||
}
|
||||
// If exp(x) overflows to inf and y is finite nonzero, the result involves inf * cos(y) and
|
||||
// inf * sin(y). When cos(y) or sin(y) is near a zero crossing (e.g., cos(pi/2)), different
|
||||
// trig implementations may produce different signs, so the signs of the result are unspecified.
|
||||
if (!(numext::isinf)(numext::imag_ref(z)) && !(numext::isnan)(numext::imag_ref(z)) && numext::imag_ref(z) != 0 &&
|
||||
(numext::isinf)(std::exp(numext::real_ref(z)))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1558,7 +1565,12 @@ struct exp_complex_test_impl {
|
||||
Scalar(numext::abs(numext::real_ref(expected)), numext::abs(numext::imag_ref(expected)));
|
||||
VERIFY(is_exactly_equal(abs_w, abs_expected));
|
||||
} else {
|
||||
VERIFY(is_exactly_equal(w, numext::exp(z)));
|
||||
Scalar expected = numext::exp(z);
|
||||
// First try exact equality (handles NaN, signed zeros correctly).
|
||||
// Fall back to approximate comparison to allow for small differences
|
||||
// in trig functions near zero crossings (e.g., vectorized sincos may
|
||||
// compute cos(pi/2) = 0 while scalar std::exp gives ~6.12e-17).
|
||||
VERIFY(is_exactly_equal(w, expected, /*quiet=*/true) || verifyIsApprox(w, expected));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user