Fix pdiv for complex packets involving infinites.

libeigen/eigen!2131
This commit is contained in:
Antonio Sánchez
2026-02-15 01:47:32 +00:00
committed by Rasmus Munk Larsen
parent 9b709e8269
commit 1a2b80727c
4 changed files with 89 additions and 69 deletions

View File

@@ -2102,6 +2102,57 @@ struct conj_impl<std::complex<T>, true> {
};
#endif
// Complex multiply and division operators.
// Note that these do not handle the case if inf+NaNi, which is considered an infinity.
// This is for consistency with our standard pmul, pdiv implementations.
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_multiply(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
return std::complex<T>(a_real * b_real - a_imag * b_imag, a_imag * b_real + a_real * b_imag);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide_fast(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
const T norm = (b_real * b_real + b_imag * b_imag);
return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm, (a_imag * b_real - a_real * b_imag) / norm);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide_smith(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
// guards against over/under-flow.
const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
const T rscale = scale_imag ? T(1) : b_real / b_imag;
const T iscale = scale_imag ? b_imag / b_real : T(1);
const T denominator = b_real * rscale + b_imag * iscale;
return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
(a_imag * rscale - a_real * iscale) / denominator);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::complex<T>& a,
const std::complex<T>& b) {
#if EIGEN_FAST_MATH
return complex_divide_fast(a, b);
#else
return complex_divide_smith(a, b);
#endif
}
} // end namespace internal
} // end namespace Eigen

View File

@@ -1553,21 +1553,20 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_double(const P
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) {
typedef typename unpacket_traits<Packet>::as_real RealPacket;
typedef typename unpacket_traits<RealPacket>::type RealScalar;
// In the following we annotate the code for the case where the inputs
// are a pair length-2 SIMD vectors representing a single pair of complex
// numbers x = a + i*b, y = c + i*d.
const RealPacket y_abs = pabs(y.v); // |c|, |d|
const RealPacket y_abs_flip = pcplxflip(Packet(y_abs)).v; // |d|, |c|
const RealPacket y_max = pmax(y_abs, y_abs_flip); // max(|c|, |d|), max(|c|, |d|)
const RealPacket y_scaled = pdiv(y.v, y_max); // c / max(|c|, |d|), d / max(|c|, |d|)
// Compute scaled denominator.
const RealPacket y_scaled_sq = pmul(y_scaled, y_scaled); // c'**2, d'**2
const RealPacket denom = padd(y_scaled_sq, pcplxflip(Packet(y_scaled_sq)).v);
Packet result_scaled = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c
// Divide elementwise by denom.
result_scaled = Packet(pdiv(result_scaled.v, denom));
// Rescale result
return Packet(pdiv(result_scaled.v, y_max));
static const RealPacket one = pset1<RealPacket>(RealScalar(1));
const RealPacket y_flip = pcplxflip(y).v;
// We need to avoid dividing by Inf/Inf, so use a mask to carefully
// apply the scale.
const RealPacket mask = pcmp_lt(pabs(y.v), pabs(y_flip)); // |c| < |d|
const RealPacket y_scaled = pselect(mask, pdiv(y.v, y_flip), one);
RealPacket denom = pmul(y.v, y_scaled);
denom = padd(denom, pcplxflip(Packet(denom)).v); // c * c' + d * d'
Packet num = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c
return Packet(pdiv(num.v, denom));
}
template <typename Packet>

View File

@@ -62,54 +62,6 @@ namespace Eigen {
// Specialized std::complex overloads.
namespace complex_operator_detail {
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_multiply(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
return std::complex<T>(a_real * b_real - a_imag * b_imag, a_imag * b_real + a_real * b_imag);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide_fast(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
const T norm = (b_real * b_real + b_imag * b_imag);
return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm, (a_imag * b_real - a_real * b_imag) / norm);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide_stable(const std::complex<T>& a,
const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
// guards against over/under-flow.
const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
const T rscale = scale_imag ? T(1) : b_real / b_imag;
const T iscale = scale_imag ? b_imag / b_real : T(1);
const T denominator = b_real * rscale + b_imag * iscale;
return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
(a_imag * rscale - a_real * iscale) / denominator);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::complex<T>& a,
const std::complex<T>& b) {
#if EIGEN_FAST_MATH
return complex_divide_fast(a, b);
#else
return complex_divide_stable(a, b);
#endif
}
// NOTE: We cannot specialize compound assignment operators with Scalar T,
// (i.e. operator@=(const T&), for @=+,-,*,/)
// since they are already specialized for float/double/long double within
@@ -151,7 +103,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator*(const std::complex<T>& a, \
const std::complex<T>& b) { \
return complex_multiply(a, b); \
return internal::complex_multiply(a, b); \
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator*(const std::complex<T>& a, const T& b) { \
@@ -164,7 +116,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator/(const std::complex<T>& a, \
const std::complex<T>& b) { \
return complex_divide(a, b); \
return internal::complex_divide(a, b); \
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator/(const std::complex<T>& a, const T& b) { \
@@ -172,7 +124,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator/(const T& a, const std::complex<T>& b) { \
return complex_divide(std::complex<T>(a, 0), b); \
return internal::complex_divide(std::complex<T>(a, 0), b); \
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T>& operator+=(std::complex<T>& a, const std::complex<T>& b) { \
@@ -188,12 +140,12 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> complex_divide(const std::
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T>& operator*=(std::complex<T>& a, const std::complex<T>& b) { \
a = complex_multiply(a, b); \
a = internal::complex_multiply(a, b); \
return a; \
} \
\
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T>& operator/=(std::complex<T>& a, const std::complex<T>& b) { \
a = complex_divide(a, b); \
a = internal::complex_divide(a, b); \
return a; \
} \
\

View File

@@ -1607,6 +1607,28 @@ void packetmath_complex() {
VERIFY(test::areApprox(ref, pval, PacketSize) && "pcplxflip");
}
const RealScalar zero = RealScalar(0);
const RealScalar one = RealScalar(1);
const RealScalar inf = std::numeric_limits<RealScalar>::infinity();
const RealScalar nan = std::numeric_limits<RealScalar>::quiet_NaN();
// Multiplication and Division.
{
std::array<RealScalar, 8> special_values = {zero, one, inf, nan, -zero, -one, -inf, -nan};
for (RealScalar a : special_values) {
for (RealScalar b : special_values) {
for (RealScalar c : special_values) {
for (RealScalar d : special_values) {
data1[0] = Scalar(a, b);
data2[0] = Scalar(c, d);
CHECK_CWISE2_IF(PacketTraits::HasMul, internal::complex_multiply, internal::pmul);
CHECK_CWISE2_IF(PacketTraits::HasDiv, internal::complex_divide, internal::pdiv);
}
}
}
}
}
if (PacketTraits::HasSqrt) {
for (int i = 0; i < size; ++i) {
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
@@ -1615,10 +1637,6 @@ void packetmath_complex() {
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
// Test misc. corner cases.
const RealScalar zero = RealScalar(0);
const RealScalar one = RealScalar(1);
const RealScalar inf = std::numeric_limits<RealScalar>::infinity();
const RealScalar nan = std::numeric_limits<RealScalar>::quiet_NaN();
data1[0] = Scalar(zero, zero);
data1[1] = Scalar(-zero, zero);
data1[2] = Scalar(one, zero);