Consolidate complex math function boilerplate with shared macros

libeigen/eigen!2201

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-25 07:21:20 -08:00
parent c4c704e5dd
commit d0d70a9527
9 changed files with 45 additions and 208 deletions

View File

@@ -432,25 +432,8 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cd, 2>& kernel) {
kernel.packet[0].v = tmp;
}
template <>
EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) {
return psqrt_complex<Packet2cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) {
return psqrt_complex<Packet4cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cd plog<Packet2cd>(const Packet2cd& a) {
return plog_complex<Packet2cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cf plog<Packet4cf>(const Packet4cf& a) {
return plog_complex<Packet4cf>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(Packet2cd)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet4cf)
template <>
EIGEN_STRONG_INLINE Packet2cd pexp<Packet2cd>(const Packet2cd& a) {
@@ -465,11 +448,6 @@ EIGEN_STRONG_INLINE Packet2cd pexp<Packet2cd>(const Packet2cd& a) {
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4cf pexp<Packet4cf>(const Packet4cf& a) {
return pexp_complex<Packet4cf>(a);
}
#ifdef EIGEN_VECTORIZE_FMA
// std::complex<float>
template <>

View File

@@ -443,35 +443,8 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4cd, 4>& kernel) {
kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0, 2, 0, 2>::mask))); // [a0 b0 c0 d0]
}
template <>
EIGEN_STRONG_INLINE Packet4cd psqrt<Packet4cd>(const Packet4cd& a) {
return psqrt_complex<Packet4cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet8cf psqrt<Packet8cf>(const Packet8cf& a) {
return psqrt_complex<Packet8cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cd plog<Packet4cd>(const Packet4cd& a) {
return plog_complex<Packet4cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet8cf plog<Packet8cf>(const Packet8cf& a) {
return plog_complex<Packet8cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet4cd pexp<Packet4cd>(const Packet4cd& a) {
return pexp_complex<Packet4cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet8cf pexp<Packet8cf>(const Packet8cf& a) {
return pexp_complex<Packet8cf>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet4cd)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet8cf)
} // end namespace internal
} // end namespace Eigen

View File

@@ -361,20 +361,7 @@ EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) {
return Packet2cf(vec_and(eq, vec_perm(eq, eq, p16uc_COMPLEX32_REV)));
}
template <>
EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex<Packet2cf>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf)
//---------- double ----------
#ifdef EIGEN_VECTORIZE_VSX
@@ -621,15 +608,7 @@ EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) {
return Packet1cd(vec_and(eq, eq_swapped));
}
template <>
EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex<Packet1cd>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(Packet1cd)
#endif // __VSX__
} // end namespace internal

View File

@@ -237,6 +237,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET)
// Macro to instantiate complex math function specializations (psqrt, plog, pexp)
// that delegate to the generic implementations. Use in arch-specific Complex.h files.
#define EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(PacketType) \
template <> \
EIGEN_STRONG_INLINE PacketType psqrt<PacketType>(const PacketType& a) { \
return psqrt_complex<PacketType>(a); \
} \
template <> \
EIGEN_STRONG_INLINE PacketType plog<PacketType>(const PacketType& a) { \
return plog_complex<PacketType>(a); \
} \
template <> \
EIGEN_STRONG_INLINE PacketType pexp<PacketType>(const PacketType& a) { \
return pexp_complex<PacketType>(a); \
}
// Variant without pexp, for backends where pexp needs special handling for a given packet type.
#define EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(PacketType) \
template <> \
EIGEN_STRONG_INLINE PacketType psqrt<PacketType>(const PacketType& a) { \
return psqrt_complex<PacketType>(a); \
} \
template <> \
EIGEN_STRONG_INLINE PacketType plog<PacketType>(const PacketType& a) { \
return plog_complex<PacketType>(a); \
}
} // end namespace internal
} // end namespace Eigen

View File

@@ -226,11 +226,6 @@ EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2c
return pdiv_complex(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pzero(const Packet2cf& /* a */) {
__m128 v = {0.0f, 0.0f, 0.0f, 0.0f};
@@ -251,11 +246,6 @@ EIGEN_STRONG_INLINE Packet2cf pmadd<Packet2cf>(const Packet2cf& a, const Packet2
return result;
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex(a);
}
//---------- double ----------
struct Packet1cd {
EIGEN_STRONG_INLINE Packet1cd() {}
@@ -458,20 +448,8 @@ EIGEN_DEVICE_FUNC inline Packet2cf pselect(const Packet2cf& mask, const Packet2c
return res;
}
template <>
EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(Packet1cd)
template <>
EIGEN_STRONG_INLINE Packet1cd pzero<Packet1cd>(const Packet1cd& /* a */) {

View File

@@ -456,35 +456,8 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cf, 2>& kernel) {
kernel.packet[1].v = tmp;
}
template <>
EIGEN_STRONG_INLINE Packet1cf psqrt<Packet1cf>(const Packet1cf& a) {
return psqrt_complex<Packet1cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cf plog<Packet1cf>(const Packet1cf& a) {
return plog_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cf pexp<Packet1cf>(const Packet1cf& a) {
return pexp_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet1cf)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf)
//---------- double ----------
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
@@ -714,20 +687,7 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd, 2>& kernel) {
kernel.packet[1].v = tmp;
}
template <>
EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
return pexp_complex<Packet1cd>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet1cd)
#endif // EIGEN_ARCH_ARM64

View File

@@ -413,35 +413,8 @@ EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) {
return Packet1cd(pand<Packet2d>(eq, vec2d_swizzle1(eq, 1, 0)));
}
template <>
EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd pexp<Packet1cd>(const Packet1cd& a) {
return pexp_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex<Packet2cf>(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet1cd)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf)
#ifdef EIGEN_VECTORIZE_FMA
// std::complex<float>

View File

@@ -20,7 +20,8 @@ namespace internal {
#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
inline Packet4ui p4ui_CONJ_XOR() {
return Packet4ui {0x00000000, 0x80000000, 0x00000000, 0x80000000}; // vec_mergeh((Packet4ui)p4i_ZERO, (Packet4ui)p4f_MZERO);
return Packet4ui{0x00000000, 0x80000000, 0x00000000,
0x80000000}; // vec_mergeh((Packet4ui)p4i_ZERO, (Packet4ui)p4f_MZERO);
}
#endif
@@ -255,29 +256,8 @@ EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1c
return pdiv_complex(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet1cd plog<Packet1cd>(const Packet1cd& a) {
return plog_complex<Packet1cd>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf plog<Packet2cf>(const Packet2cf& a) {
return plog_complex<Packet2cf>(a);
}
template <>
EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
return pexp_complex(a);
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(Packet1cd)
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf)
EIGEN_STRONG_INLINE Packet1cd pcplxflip /*<Packet1cd>*/ (const Packet1cd& x) {
return Packet1cd(preverse(Packet2d(x.v)));

View File

@@ -147,18 +147,7 @@ EIGEN_STRONG_INLINE Packet4cd pset1<Packet4cd>(const std::complex<double>& from)
EIGEN_STRONG_INLINE unpacket_traits<PACKET_TYPE>::type pfirst<PACKET_TYPE>(const PACKET_TYPE& a) { \
return a[0]; \
} \
template <> \
EIGEN_STRONG_INLINE PACKET_TYPE pexp<PACKET_TYPE>(const PACKET_TYPE& a) { \
return pexp_complex(a); \
} \
template <> \
EIGEN_STRONG_INLINE PACKET_TYPE plog<PACKET_TYPE>(const PACKET_TYPE& a) { \
return plog_complex(a); \
} \
template <> \
EIGEN_STRONG_INLINE PACKET_TYPE psqrt<PACKET_TYPE>(const PACKET_TYPE& a) { \
return psqrt_complex(a); \
}
EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(PACKET_TYPE)
EIGEN_CLANG_COMPLEX_UNARY_CWISE_OPS(Packet8cf);
EIGEN_CLANG_COMPLEX_UNARY_CWISE_OPS(Packet4cd);