From a7674b70d3413f3db13bd35d913a190fdc30f57d Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen <4643818-rmlarsen1@users.noreply.gitlab.com> Date: Wed, 12 Nov 2025 22:19:50 +0000 Subject: [PATCH] Improve packet op test coverage for IEEE special values. libeigen/eigen!2075 Co-authored-by: Rasmus Munk Larsen --- Eigen/src/Core/MathFunctions.h | 12 +- Eigen/src/Core/arch/AVX/MathFunctions.h | 2 + Eigen/src/Core/arch/Default/BFloat16.h | 7 +- test/packetmath.cpp | 184 ++++++++++++------------ 4 files changed, 110 insertions(+), 95 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 5e36ce84d..f6269aa1a 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -1357,6 +1357,12 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr T round_down(T a, U b) { return ub * (ua / ub); } +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log2(T x) { + EIGEN_USING_STD(log2); + return log2(x); +} + /** Log base 2 for 32 bits positive integers. * Conveniently returns 0 for x==0. 
*/ constexpr int log2(int x) { @@ -1436,9 +1442,9 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double log(const double& x) { #endif template -EIGEN_DEVICE_FUNC -EIGEN_ALWAYS_INLINE std::enable_if_t::IsSigned || NumTraits::IsComplex, typename NumTraits::Real> -abs(const T& x) { +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + std::enable_if_t::IsSigned || NumTraits::IsComplex, typename NumTraits::Real> + abs(const T& x) { EIGEN_USING_STD(abs); return abs(x); } diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 6c40ff90b..5ee67a599 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -106,6 +106,7 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, preciprocal) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcbrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) @@ -122,6 +123,7 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, preciprocal) F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, psin) F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pcbrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) #endif diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index b93c4bc2e..b69097d45 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -622,6 +622,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) { return bfloat16(static_cast(EIGEN_LOG2E) * ::logf(float(a))); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { return bfloat16(::sqrtf(float(a))); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cbrt(const bfloat16& a) { return 
bfloat16(::cbrtf(float(a))); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) { return bfloat16(::powf(float(a), float(b))); } @@ -794,8 +795,10 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, c } // Specialize multiply-add to match packet operations and reduce conversions to/from float. -template<> -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) { +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd(const Eigen::bfloat16& x, + const Eigen::bfloat16& y, + const Eigen::bfloat16& z) { return Eigen::bfloat16(static_cast(x) * static_cast(y) + static_cast(z)); } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 959abd97b..f645b8ce1 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -805,84 +805,6 @@ Scalar log2(Scalar x) { return Scalar(EIGEN_LOG2E) * std::log(x); } -// Create a functor out of a function so it can be passed (with overloads) -// to another function as an input argument. -#define CREATE_FUNCTOR(Name, Func) \ - struct Name { \ - template \ - T operator()(const T& val) const { \ - return Func(val); \ - } \ - } - -CREATE_FUNCTOR(psqrt_functor, internal::psqrt); -CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt); -CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt); - -// TODO(rmlarsen): Run this test for more functions. -template -void packetmath_test_IEEE_corner_cases(const RefFunctorT& ref_fun, const FunctorT& fun) { - const int PacketSize = internal::unpacket_traits::size; - const Scalar norm_min = (std::numeric_limits::min)(); - const Scalar norm_max = (std::numeric_limits::max)(); - - constexpr int size = PacketSize * 2; - EIGEN_ALIGN_MAX Scalar data1[size]; - EIGEN_ALIGN_MAX Scalar data2[size]; - EIGEN_ALIGN_MAX Scalar ref[size]; - for (int i = 0; i < size; ++i) { - data1[i] = data2[i] = ref[i] = Scalar(0); - } - - // Test for subnormals. 
- if (Cond && std::numeric_limits::has_denorm == std::denorm_present && !EIGEN_ARCH_ARM) { - for (int scale = 1; scale < 5; ++scale) { - // When EIGEN_FAST_MATH is 1 we relax the conditions slightly, and allow the function - // to return the same value for subnormals as the reference would return for zero with - // the same sign as the input. -#if EIGEN_FAST_MATH - data1[0] = Scalar(scale) * std::numeric_limits::denorm_min(); - data1[1] = -data1[0]; - test::packet_helper h; - h.store(data2, fun(h.load(data1))); - for (int i = 0; i < PacketSize; ++i) { - const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0)); - const Scalar ref_val = ref_fun(data1[i]); - VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero || - verifyIsApprox(data2[i], ref_val)); - } -#else - CHECK_CWISE1_IF(Cond, ref_fun, fun); -#endif - } - } - - // Test for smallest normalized floats. - data1[0] = norm_min; - data1[1] = -data1[0]; - CHECK_CWISE1_IF(Cond, ref_fun, fun); - - // Test for largest floats. - data1[0] = norm_max; - data1[1] = -data1[0]; - CHECK_CWISE1_IF(Cond, ref_fun, fun); - - // Test for zeros. - data1[0] = Scalar(0.0); - data1[1] = -data1[0]; - CHECK_CWISE1_IF(Cond, ref_fun, fun); - - // Test for infinities. - data1[0] = NumTraits::infinity(); - data1[1] = -data1[0]; - CHECK_CWISE1_IF(Cond, ref_fun, fun); - - // Test for quiet NaNs. - data1[0] = std::numeric_limits::quiet_NaN(); - data1[1] = -std::numeric_limits::quiet_NaN(); - CHECK_CWISE1_IF(Cond, ref_fun, fun); -} - template void packetmath_real() { typedef internal::packet_traits PacketTraits; @@ -1071,18 +993,12 @@ void packetmath_real() { test::packet_helper h; h.store(data2, internal::pexp(h.load(data1))); VERIFY((numext::isnan)(data2[0])); - // TODO(rmlarsen): Re-enable for bfloat16. 
- if (!internal::is_same::value) { - VERIFY_IS_APPROX(std::exp(small), data2[1]); - } + VERIFY_IS_APPROX(std::exp(small), data2[1]); data1[0] = -small; data1[1] = Scalar(0); h.store(data2, internal::pexp(h.load(data1))); - // TODO(rmlarsen): Re-enable for bfloat16. - if (!internal::is_same::value) { - VERIFY_IS_APPROX(std::exp(-small), data2[0]); - } + VERIFY_IS_APPROX(std::exp(-small), data2[0]); VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]); data1[0] = (std::numeric_limits::min)(); @@ -1186,10 +1102,6 @@ void packetmath_real() { VERIFY((numext::isnan)(data2[1])); } - packetmath_test_IEEE_corner_cases(numext::sqrt, psqrt_functor()); - packetmath_test_IEEE_corner_cases(numext::rsqrt, prsqrt_functor()); - packetmath_test_IEEE_corner_cases(numext::cbrt, pcbrt_functor()); - // TODO(rmlarsen): Re-enable for half and bfloat16. if (PacketTraits::HasCos && !internal::is_same::value && !internal::is_same::value) { @@ -1292,8 +1204,100 @@ Scalar propagate_number_min(const Scalar& a, const Scalar& b) { return (numext::mini)(a, b); } +template +std::enable_if_t run_ieee_cases(const FunctorT&) {} + +template +std::enable_if_t run_ieee_cases(const FunctorT& fun) { + const int PacketSize = internal::unpacket_traits::size; + const Scalar norm_min = (std::numeric_limits::min)(); + const Scalar norm_max = (std::numeric_limits::max)(); + const Scalar inf = (std::numeric_limits::infinity)(); + const Scalar nan = (std::numeric_limits::quiet_NaN)(); + std::vector values{norm_min, Scalar(0), Scalar(1), norm_max, inf, nan}; + + constexpr int size = PacketSize * 2; + EIGEN_ALIGN_MAX Scalar data1[size]; + EIGEN_ALIGN_MAX Scalar data2[size]; + EIGEN_ALIGN_MAX Scalar ref[size]; + for (int i = 0; i < size; ++i) { + data1[i] = data2[i] = ref[i] = Scalar(0); + } + + if (Cond && !SkipDenorms && std::numeric_limits::has_denorm == std::denorm_present) { + values.push_back(std::numeric_limits::denorm_min()); + values.push_back(norm_min / Scalar(2)); + } + + for (Scalar abs_value : values) { + 
data1[0] = abs_value; + data1[1] = -data1[0]; + CHECK_CWISE1_IF(Cond, fun.expected, fun.actual); + } +} + +// Create a tester struct with the actual and the reference function +// as templated member functions. +#define CREATE_TESTER(NAME, ACTUAL, EXPECTED) \ + struct NAME { \ + template \ + T actual(const T& val) const { \ + return ACTUAL(val); \ + } \ + template \ + T expected(const T& val) const { \ + return EXPECTED(val); \ + } \ + } + +CREATE_TESTER(sqrt_fun, internal::psqrt, numext::sqrt); +CREATE_TESTER(rsqrt_fun, internal::prsqrt, numext::rsqrt); +CREATE_TESTER(cbrt_fun, internal::pcbrt, numext::cbrt); +CREATE_TESTER(exp_fun, internal::pexp, numext::exp); +CREATE_TESTER(exp2_fun, internal::pexp2, numext::exp2); +CREATE_TESTER(log_fun, internal::plog, numext::log); +CREATE_TESTER(log2_fun, internal::plog2, numext::log2); +CREATE_TESTER(expm1_fun, internal::pexpm1, numext::expm1); +CREATE_TESTER(log1p_fun, internal::plog1p, numext::log1p); +CREATE_TESTER(sin_fun, internal::psin, numext::sin); +CREATE_TESTER(cos_fun, internal::pcos, numext::cos); +CREATE_TESTER(tan_fun, internal::ptan, numext::tan); +CREATE_TESTER(asin_fun, internal::pasin, numext::asin); +CREATE_TESTER(acos_fun, internal::pacos, numext::acos); +CREATE_TESTER(atan_fun, internal::patan, numext::atan); +CREATE_TESTER(tanh_fun, internal::ptanh, numext::tanh); +CREATE_TESTER(atanh_fun, internal::patanh, numext::atanh); + +template +std::enable_if_t::IsComplex, void> packetmath_ieee_special_values() {} + +template +std::enable_if_t::IsComplex, void> packetmath_ieee_special_values() { + typedef internal::packet_traits PacketTraits; + run_ieee_cases(sqrt_fun()); + // TODO(rmlarsen): See if we can fix rsqrt for denorms without wrecking performance. 
+ run_ieee_cases(rsqrt_fun()); + run_ieee_cases(cbrt_fun()); + run_ieee_cases(exp_fun()); + run_ieee_cases(exp2_fun()); + run_ieee_cases(log_fun()); + run_ieee_cases(log2_fun()); + run_ieee_cases(expm1_fun()); + run_ieee_cases(log1p_fun()); + run_ieee_cases(sin_fun()); + run_ieee_cases(cos_fun()); + run_ieee_cases(tan_fun()); + run_ieee_cases(asin_fun()); + run_ieee_cases(acos_fun()); + run_ieee_cases(atan_fun()); + run_ieee_cases(tanh_fun()); + run_ieee_cases(atanh_fun()); +} + template void packetmath_notcomplex() { + packetmath_ieee_special_values(); + typedef internal::packet_traits PacketTraits; const int PacketSize = internal::unpacket_traits::size;