Fix x86 complex vectorized fma

This commit is contained in:
Charles Schlosser
2025-03-12 17:06:32 +00:00
parent 464c1d0978
commit 10e62ccd22
3 changed files with 34 additions and 48 deletions

View File

@@ -92,6 +92,26 @@ inline T REF_LDEXP(const T& x, const T& exp) {
return static_cast<T>(ldexp(x, static_cast<int>(exp)));
}
// provides a convenient function to take the absolute value of each component of a complex number to prevent
// catastrophic cancellation in randomly generated complex numbers
template <typename T, bool IsComplex = NumTraits<T>::IsComplex>
struct abs_helper_impl {
static T run(T x) { return numext::abs(x); }
};
template <typename T>
struct abs_helper_impl<T, true> {
static T run(T x) {
T res = x;
numext::real_ref(res) = numext::abs(numext::real(res));
numext::imag_ref(res) = numext::abs(numext::imag(res));
return res;
}
};
template <typename T>
T abs_helper(T x) {
return abs_helper_impl<T>::run(x);
}
// Uses pcast to cast from one array to another.
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
struct pcast_array;
@@ -724,11 +744,6 @@ void packetmath() {
packetmath_pcast_ops_runner<Scalar, Packet>::run();
packetmath_minus_zero_add_test<Scalar, Packet>::run();
for (int i = 0; i < size; ++i) {
data1[i] = numext::abs(internal::random<Scalar>());
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
nmsub_test<Scalar, Packet>(data1, data2, ref, PacketSize);
@@ -738,14 +753,17 @@ void packetmath() {
// which can lead to very flaky tests. Here we ensure the signs are such that
// they do not cancel.
for (int i = 0; i < PacketSize; ++i) {
data1[i] = numext::abs(internal::random<Scalar>());
data1[i + PacketSize] = numext::abs(internal::random<Scalar>());
data1[i + 2 * PacketSize] = Scalar(0) - numext::abs(internal::random<Scalar>());
data1[i] = abs_helper(internal::random<Scalar>());
data1[i + PacketSize] = abs_helper(internal::random<Scalar>());
data1[i + 2 * PacketSize] = Scalar(0) - abs_helper(internal::random<Scalar>());
}
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
}
// Notice that this definition works for complex types as well.