Improve packet op test coverage for IEEE special values.

libeigen/eigen!2075

Co-authored-by: Rasmus Munk Larsen <rmlarsen@google.com>
This commit is contained in:
Rasmus Munk Larsen
2025-11-12 22:19:50 +00:00
parent 72bfca3d82
commit a7674b70d3
4 changed files with 110 additions and 95 deletions

View File

@@ -805,84 +805,6 @@ Scalar log2(Scalar x) {
return Scalar(EIGEN_LOG2E) * std::log(x);
}
// Create a functor out of a function so it can be passed (with overloads)
// to another function as an input argument.
#define CREATE_FUNCTOR(Name, Func) \
struct Name { \
template <typename T> \
T operator()(const T& val) const { \
return Func(val); \
} \
}
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt);
// TODO(rmlarsen): Run this test for more functions.
template <bool Cond, typename Scalar, typename Packet, typename RefFunctorT, typename FunctorT>
void packetmath_test_IEEE_corner_cases(const RefFunctorT& ref_fun, const FunctorT& fun) {
const int PacketSize = internal::unpacket_traits<Packet>::size;
const Scalar norm_min = (std::numeric_limits<Scalar>::min)();
const Scalar norm_max = (std::numeric_limits<Scalar>::max)();
constexpr int size = PacketSize * 2;
EIGEN_ALIGN_MAX Scalar data1[size];
EIGEN_ALIGN_MAX Scalar data2[size];
EIGEN_ALIGN_MAX Scalar ref[size];
for (int i = 0; i < size; ++i) {
data1[i] = data2[i] = ref[i] = Scalar(0);
}
// Test for subnormals.
if (Cond && std::numeric_limits<Scalar>::has_denorm == std::denorm_present && !EIGEN_ARCH_ARM) {
for (int scale = 1; scale < 5; ++scale) {
// When EIGEN_FAST_MATH is 1 we relax the conditions slightly, and allow the function
// to return the same value for subnormals as the reference would return for zero with
// the same sign as the input.
#if EIGEN_FAST_MATH
data1[0] = Scalar(scale) * std::numeric_limits<Scalar>::denorm_min();
data1[1] = -data1[0];
test::packet_helper<Cond, Packet> h;
h.store(data2, fun(h.load(data1)));
for (int i = 0; i < PacketSize; ++i) {
const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0));
const Scalar ref_val = ref_fun(data1[i]);
VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero ||
verifyIsApprox(data2[i], ref_val));
}
#else
CHECK_CWISE1_IF(Cond, ref_fun, fun);
#endif
}
}
// Test for smallest normalized floats.
data1[0] = norm_min;
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, ref_fun, fun);
// Test for largest floats.
data1[0] = norm_max;
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, ref_fun, fun);
// Test for zeros.
data1[0] = Scalar(0.0);
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, ref_fun, fun);
// Test for infinities.
data1[0] = NumTraits<Scalar>::infinity();
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, ref_fun, fun);
// Test for quiet NaNs.
data1[0] = std::numeric_limits<Scalar>::quiet_NaN();
data1[1] = -std::numeric_limits<Scalar>::quiet_NaN();
CHECK_CWISE1_IF(Cond, ref_fun, fun);
}
template <typename Scalar, typename Packet>
void packetmath_real() {
typedef internal::packet_traits<Scalar> PacketTraits;
@@ -1071,18 +993,12 @@ void packetmath_real() {
test::packet_helper<PacketTraits::HasExp, Packet> h;
h.store(data2, internal::pexp(h.load(data1)));
VERIFY((numext::isnan)(data2[0]));
// TODO(rmlarsen): Re-enable for bfloat16.
if (!internal::is_same<Scalar, bfloat16>::value) {
VERIFY_IS_APPROX(std::exp(small), data2[1]);
}
VERIFY_IS_APPROX(std::exp(small), data2[1]);
data1[0] = -small;
data1[1] = Scalar(0);
h.store(data2, internal::pexp(h.load(data1)));
// TODO(rmlarsen): Re-enable for bfloat16.
if (!internal::is_same<Scalar, bfloat16>::value) {
VERIFY_IS_APPROX(std::exp(-small), data2[0]);
}
VERIFY_IS_APPROX(std::exp(-small), data2[0]);
VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]);
data1[0] = (std::numeric_limits<Scalar>::min)();
@@ -1186,10 +1102,6 @@ void packetmath_real() {
VERIFY((numext::isnan)(data2[1]));
}
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, psqrt_functor());
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, prsqrt_functor());
packetmath_test_IEEE_corner_cases<PacketTraits::HasCbrt, Scalar, Packet>(numext::cbrt<Scalar>, pcbrt_functor());
// TODO(rmlarsen): Re-enable for half and bfloat16.
if (PacketTraits::HasCos && !internal::is_same<Scalar, half>::value &&
!internal::is_same<Scalar, bfloat16>::value) {
@@ -1292,8 +1204,100 @@ Scalar propagate_number_min(const Scalar& a, const Scalar& b) {
return (numext::mini)(a, b);
}
template <bool Cond, typename Scalar, typename Packet, bool SkipDenorms = false, typename FunctorT>
std::enable_if_t<!Cond, void> run_ieee_cases(const FunctorT&) {}
template <bool Cond, typename Scalar, typename Packet, bool SkipDenorms = false, typename FunctorT>
std::enable_if_t<Cond, void> run_ieee_cases(const FunctorT& fun) {
const int PacketSize = internal::unpacket_traits<Packet>::size;
const Scalar norm_min = (std::numeric_limits<Scalar>::min)();
const Scalar norm_max = (std::numeric_limits<Scalar>::max)();
const Scalar inf = (std::numeric_limits<Scalar>::infinity)();
const Scalar nan = (std::numeric_limits<Scalar>::quiet_NaN)();
std::vector<Scalar> values{norm_min, Scalar(0), Scalar(1), norm_max, inf, nan};
constexpr int size = PacketSize * 2;
EIGEN_ALIGN_MAX Scalar data1[size];
EIGEN_ALIGN_MAX Scalar data2[size];
EIGEN_ALIGN_MAX Scalar ref[size];
for (int i = 0; i < size; ++i) {
data1[i] = data2[i] = ref[i] = Scalar(0);
}
if (Cond && !SkipDenorms && std::numeric_limits<Scalar>::has_denorm == std::denorm_present) {
values.push_back(std::numeric_limits<Scalar>::denorm_min());
values.push_back(norm_min / Scalar(2));
}
for (Scalar abs_value : values) {
data1[0] = abs_value;
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, fun.expected, fun.actual);
}
}
// Create a tester struct with the actual and the reference function
// as templated member functions.
#define CREATE_TESTER(NAME, ACTUAL, EXPECTED) \
struct NAME { \
template <typename T> \
T actual(const T& val) const { \
return ACTUAL(val); \
} \
template <typename T> \
T expected(const T& val) const { \
return EXPECTED(val); \
} \
}
CREATE_TESTER(sqrt_fun, internal::psqrt, numext::sqrt);
CREATE_TESTER(rsqrt_fun, internal::prsqrt, numext::rsqrt);
CREATE_TESTER(cbrt_fun, internal::pcbrt, numext::cbrt);
CREATE_TESTER(exp_fun, internal::pexp, numext::exp);
CREATE_TESTER(exp2_fun, internal::pexp2, numext::exp2);
CREATE_TESTER(log_fun, internal::plog, numext::log);
CREATE_TESTER(log2_fun, internal::plog2, numext::log2);
CREATE_TESTER(expm1_fun, internal::pexpm1, numext::expm1);
CREATE_TESTER(log1p_fun, internal::plog1p, numext::log1p);
CREATE_TESTER(sin_fun, internal::psin, numext::sin);
CREATE_TESTER(cos_fun, internal::pcos, numext::cos);
CREATE_TESTER(tan_fun, internal::ptan, numext::tan);
CREATE_TESTER(asin_fun, internal::pasin, numext::asin);
CREATE_TESTER(acos_fun, internal::pacos, numext::acos);
CREATE_TESTER(atan_fun, internal::patan, numext::atan);
CREATE_TESTER(tanh_fun, internal::ptanh, numext::tanh);
CREATE_TESTER(atanh_fun, internal::patanh, numext::atanh);
template <typename Scalar, typename Packet>
std::enable_if_t<NumTraits<Scalar>::IsComplex, void> packetmath_ieee_special_values() {}
template <typename Scalar, typename Packet>
std::enable_if_t<!NumTraits<Scalar>::IsComplex, void> packetmath_ieee_special_values() {
typedef internal::packet_traits<Scalar> PacketTraits;
run_ieee_cases<PacketTraits::HasSqrt, Scalar, Packet>(sqrt_fun());
// TODO(rmlarsen): See if we can fix rsqrt for denorms without wreaking performance.
run_ieee_cases<PacketTraits::HasRsqrt, Scalar, Packet, true>(rsqrt_fun());
run_ieee_cases<PacketTraits::HasCbrt, Scalar, Packet>(cbrt_fun());
run_ieee_cases<PacketTraits::HasExp, Scalar, Packet>(exp_fun());
run_ieee_cases<PacketTraits::HasExp, Scalar, Packet>(exp2_fun());
run_ieee_cases<PacketTraits::HasLog, Scalar, Packet>(log_fun());
run_ieee_cases<PacketTraits::HasLog, Scalar, Packet>(log2_fun());
run_ieee_cases<PacketTraits::HasExpm1, Scalar, Packet>(expm1_fun());
run_ieee_cases<PacketTraits::HasLog1p, Scalar, Packet>(log1p_fun());
run_ieee_cases<PacketTraits::HasSin, Scalar, Packet>(sin_fun());
run_ieee_cases<PacketTraits::HasCos, Scalar, Packet>(cos_fun());
run_ieee_cases<PacketTraits::HasTan, Scalar, Packet>(tan_fun());
run_ieee_cases<PacketTraits::HasASin, Scalar, Packet>(asin_fun());
run_ieee_cases<PacketTraits::HasACos, Scalar, Packet>(acos_fun());
run_ieee_cases<PacketTraits::HasATan, Scalar, Packet>(atan_fun());
run_ieee_cases<PacketTraits::HasTanh, Scalar, Packet>(tanh_fun());
run_ieee_cases<PacketTraits::HasATanh, Scalar, Packet>(atanh_fun());
}
template <typename Scalar, typename Packet>
void packetmath_notcomplex() {
packetmath_ieee_special_values<Scalar, Packet>();
typedef internal::packet_traits<Scalar> PacketTraits;
const int PacketSize = internal::unpacket_traits<Packet>::size;