mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Improve packet op test coverage for IEEE special values.
libeigen/eigen!2075 Co-authored-by: Rasmus Munk Larsen <rmlarsen@google.com>
This commit is contained in:
@@ -1357,6 +1357,12 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr T round_down(T a, U b) {
|
||||
return ub * (ua / ub);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log2(T x) {
|
||||
EIGEN_USING_STD(log2);
|
||||
return log2(x);
|
||||
}
|
||||
|
||||
/** Log base 2 for 32-bit positive integers.
|
||||
* Conveniently returns 0 for x==0. */
|
||||
constexpr int log2(int x) {
|
||||
@@ -1436,9 +1442,9 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double log(const double& x) {
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
EIGEN_USING_STD(abs);
|
||||
return abs(x);
|
||||
}
|
||||
|
||||
@@ -106,6 +106,7 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, preciprocal)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcbrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
|
||||
@@ -122,6 +123,7 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, preciprocal)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcbrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
|
||||
#endif
|
||||
|
||||
|
||||
@@ -622,6 +622,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
|
||||
return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
|
||||
}
|
||||
// Square root of a bfloat16: the argument is converted to float, ::sqrtf is applied,
// and the float result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
  const float widened = static_cast<float>(a);
  return bfloat16(::sqrtf(widened));
}
|
||||
// Cube root of a bfloat16: the argument is converted to float, ::cbrtf is applied,
// and the float result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cbrt(const bfloat16& a) {
  const float widened = static_cast<float>(a);
  return bfloat16(::cbrtf(widened));
}
|
||||
// Raises a to the power b. Both operands are converted to float, ::powf is applied,
// and the float result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::powf(float(a), float(b)));
}
|
||||
@@ -794,8 +795,10 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, c
|
||||
}
|
||||
|
||||
// Specialize multiply-add to match packet operations and reduce conversions to/from float.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x,
|
||||
const Eigen::bfloat16& y,
|
||||
const Eigen::bfloat16& z) {
|
||||
return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
|
||||
}
|
||||
|
||||
|
||||
@@ -805,84 +805,6 @@ Scalar log2(Scalar x) {
|
||||
return Scalar(EIGEN_LOG2E) * std::log(x);
|
||||
}
|
||||
|
||||
// Create a functor out of a function so it can be passed (with overloads)
// to another function as an input argument.
//
// Name is the identifier of the generated struct; Func is invoked as Func(val) from a
// templated operator() that returns the same type T as its argument.
// The expansion has no trailing semicolon: the invocation supplies it.
#define CREATE_FUNCTOR(Name, Func)     \
  struct Name {                        \
    template <typename T>              \
    T operator()(const T& val) const { \
      return Func(val);                \
    }                                  \
  }
|
||||
|
||||
// Functor types wrapping internal::psqrt, internal::prsqrt and internal::pcbrt.
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt);
|
||||
|
||||
// TODO(rmlarsen): Run this test for more functions.
|
||||
template <bool Cond, typename Scalar, typename Packet, typename RefFunctorT, typename FunctorT>
|
||||
void packetmath_test_IEEE_corner_cases(const RefFunctorT& ref_fun, const FunctorT& fun) {
|
||||
const int PacketSize = internal::unpacket_traits<Packet>::size;
|
||||
const Scalar norm_min = (std::numeric_limits<Scalar>::min)();
|
||||
const Scalar norm_max = (std::numeric_limits<Scalar>::max)();
|
||||
|
||||
constexpr int size = PacketSize * 2;
|
||||
EIGEN_ALIGN_MAX Scalar data1[size];
|
||||
EIGEN_ALIGN_MAX Scalar data2[size];
|
||||
EIGEN_ALIGN_MAX Scalar ref[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data1[i] = data2[i] = ref[i] = Scalar(0);
|
||||
}
|
||||
|
||||
// Test for subnormals.
|
||||
if (Cond && std::numeric_limits<Scalar>::has_denorm == std::denorm_present && !EIGEN_ARCH_ARM) {
|
||||
for (int scale = 1; scale < 5; ++scale) {
|
||||
// When EIGEN_FAST_MATH is 1 we relax the conditions slightly, and allow the function
|
||||
// to return the same value for subnormals as the reference would return for zero with
|
||||
// the same sign as the input.
|
||||
#if EIGEN_FAST_MATH
|
||||
data1[0] = Scalar(scale) * std::numeric_limits<Scalar>::denorm_min();
|
||||
data1[1] = -data1[0];
|
||||
test::packet_helper<Cond, Packet> h;
|
||||
h.store(data2, fun(h.load(data1)));
|
||||
for (int i = 0; i < PacketSize; ++i) {
|
||||
const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0));
|
||||
const Scalar ref_val = ref_fun(data1[i]);
|
||||
VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero ||
|
||||
verifyIsApprox(data2[i], ref_val));
|
||||
}
|
||||
#else
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Test for smallest normalized floats.
|
||||
data1[0] = norm_min;
|
||||
data1[1] = -data1[0];
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
|
||||
// Test for largest floats.
|
||||
data1[0] = norm_max;
|
||||
data1[1] = -data1[0];
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
|
||||
// Test for zeros.
|
||||
data1[0] = Scalar(0.0);
|
||||
data1[1] = -data1[0];
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
|
||||
// Test for infinities.
|
||||
data1[0] = NumTraits<Scalar>::infinity();
|
||||
data1[1] = -data1[0];
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
|
||||
// Test for quiet NaNs.
|
||||
data1[0] = std::numeric_limits<Scalar>::quiet_NaN();
|
||||
data1[1] = -std::numeric_limits<Scalar>::quiet_NaN();
|
||||
CHECK_CWISE1_IF(Cond, ref_fun, fun);
|
||||
}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_real() {
|
||||
typedef internal::packet_traits<Scalar> PacketTraits;
|
||||
@@ -1071,18 +993,12 @@ void packetmath_real() {
|
||||
test::packet_helper<PacketTraits::HasExp, Packet> h;
|
||||
h.store(data2, internal::pexp(h.load(data1)));
|
||||
VERIFY((numext::isnan)(data2[0]));
|
||||
// TODO(rmlarsen): Re-enable for bfloat16.
|
||||
if (!internal::is_same<Scalar, bfloat16>::value) {
|
||||
VERIFY_IS_APPROX(std::exp(small), data2[1]);
|
||||
}
|
||||
VERIFY_IS_APPROX(std::exp(small), data2[1]);
|
||||
|
||||
data1[0] = -small;
|
||||
data1[1] = Scalar(0);
|
||||
h.store(data2, internal::pexp(h.load(data1)));
|
||||
// TODO(rmlarsen): Re-enable for bfloat16.
|
||||
if (!internal::is_same<Scalar, bfloat16>::value) {
|
||||
VERIFY_IS_APPROX(std::exp(-small), data2[0]);
|
||||
}
|
||||
VERIFY_IS_APPROX(std::exp(-small), data2[0]);
|
||||
VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]);
|
||||
|
||||
data1[0] = (std::numeric_limits<Scalar>::min)();
|
||||
@@ -1186,10 +1102,6 @@ void packetmath_real() {
|
||||
VERIFY((numext::isnan)(data2[1]));
|
||||
}
|
||||
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, psqrt_functor());
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, prsqrt_functor());
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasCbrt, Scalar, Packet>(numext::cbrt<Scalar>, pcbrt_functor());
|
||||
|
||||
// TODO(rmlarsen): Re-enable for half and bfloat16.
|
||||
if (PacketTraits::HasCos && !internal::is_same<Scalar, half>::value &&
|
||||
!internal::is_same<Scalar, bfloat16>::value) {
|
||||
@@ -1292,8 +1204,100 @@ Scalar propagate_number_min(const Scalar& a, const Scalar& b) {
|
||||
return (numext::mini)(a, b);
|
||||
}
|
||||
|
||||
// Overload selected when Cond is false: accepts the tester object and runs no checks.
template <bool Cond, typename Scalar, typename Packet, bool SkipDenorms = false, typename FunctorT>
std::enable_if_t<!Cond, void> run_ieee_cases(const FunctorT&) {}
|
||||
|
||||
template <bool Cond, typename Scalar, typename Packet, bool SkipDenorms = false, typename FunctorT>
|
||||
std::enable_if_t<Cond, void> run_ieee_cases(const FunctorT& fun) {
|
||||
const int PacketSize = internal::unpacket_traits<Packet>::size;
|
||||
const Scalar norm_min = (std::numeric_limits<Scalar>::min)();
|
||||
const Scalar norm_max = (std::numeric_limits<Scalar>::max)();
|
||||
const Scalar inf = (std::numeric_limits<Scalar>::infinity)();
|
||||
const Scalar nan = (std::numeric_limits<Scalar>::quiet_NaN)();
|
||||
std::vector<Scalar> values{norm_min, Scalar(0), Scalar(1), norm_max, inf, nan};
|
||||
|
||||
constexpr int size = PacketSize * 2;
|
||||
EIGEN_ALIGN_MAX Scalar data1[size];
|
||||
EIGEN_ALIGN_MAX Scalar data2[size];
|
||||
EIGEN_ALIGN_MAX Scalar ref[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data1[i] = data2[i] = ref[i] = Scalar(0);
|
||||
}
|
||||
|
||||
if (Cond && !SkipDenorms && std::numeric_limits<Scalar>::has_denorm == std::denorm_present) {
|
||||
values.push_back(std::numeric_limits<Scalar>::denorm_min());
|
||||
values.push_back(norm_min / Scalar(2));
|
||||
}
|
||||
|
||||
for (Scalar abs_value : values) {
|
||||
data1[0] = abs_value;
|
||||
data1[1] = -data1[0];
|
||||
CHECK_CWISE1_IF(Cond, fun.expected, fun.actual);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a tester struct with the actual and the reference function
// as templated member functions.
//
// NAME is the identifier of the generated struct. actual(val) forwards to ACTUAL(val) and
// expected(val) forwards to EXPECTED(val); both return the same type T as their argument.
// The expansion has no trailing semicolon: the invocation supplies it.
#define CREATE_TESTER(NAME, ACTUAL, EXPECTED) \
  struct NAME {                               \
    template <typename T>                     \
    T actual(const T& val) const {            \
      return ACTUAL(val);                     \
    }                                         \
    template <typename T>                     \
    T expected(const T& val) const {          \
      return EXPECTED(val);                   \
    }                                         \
  }
|
||||
|
||||
// Tester types pairing each internal::p* packet op (actual) with the numext:: function of
// the same name (expected).
CREATE_TESTER(sqrt_fun, internal::psqrt, numext::sqrt);
CREATE_TESTER(rsqrt_fun, internal::prsqrt, numext::rsqrt);
CREATE_TESTER(cbrt_fun, internal::pcbrt, numext::cbrt);
CREATE_TESTER(exp_fun, internal::pexp, numext::exp);
CREATE_TESTER(exp2_fun, internal::pexp2, numext::exp2);
CREATE_TESTER(log_fun, internal::plog, numext::log);
CREATE_TESTER(log2_fun, internal::plog2, numext::log2);
CREATE_TESTER(expm1_fun, internal::pexpm1, numext::expm1);
CREATE_TESTER(log1p_fun, internal::plog1p, numext::log1p);
CREATE_TESTER(sin_fun, internal::psin, numext::sin);
CREATE_TESTER(cos_fun, internal::pcos, numext::cos);
CREATE_TESTER(tan_fun, internal::ptan, numext::tan);
CREATE_TESTER(asin_fun, internal::pasin, numext::asin);
CREATE_TESTER(acos_fun, internal::pacos, numext::acos);
CREATE_TESTER(atan_fun, internal::patan, numext::atan);
CREATE_TESTER(tanh_fun, internal::ptanh, numext::tanh);
CREATE_TESTER(atanh_fun, internal::patanh, numext::atanh);
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
std::enable_if_t<NumTraits<Scalar>::IsComplex, void> packetmath_ieee_special_values() {}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
std::enable_if_t<!NumTraits<Scalar>::IsComplex, void> packetmath_ieee_special_values() {
|
||||
typedef internal::packet_traits<Scalar> PacketTraits;
|
||||
run_ieee_cases<PacketTraits::HasSqrt, Scalar, Packet>(sqrt_fun());
|
||||
// TODO(rmlarsen): See if we can fix rsqrt for denorms without wreaking performance.
|
||||
run_ieee_cases<PacketTraits::HasRsqrt, Scalar, Packet, true>(rsqrt_fun());
|
||||
run_ieee_cases<PacketTraits::HasCbrt, Scalar, Packet>(cbrt_fun());
|
||||
run_ieee_cases<PacketTraits::HasExp, Scalar, Packet>(exp_fun());
|
||||
run_ieee_cases<PacketTraits::HasExp, Scalar, Packet>(exp2_fun());
|
||||
run_ieee_cases<PacketTraits::HasLog, Scalar, Packet>(log_fun());
|
||||
run_ieee_cases<PacketTraits::HasLog, Scalar, Packet>(log2_fun());
|
||||
run_ieee_cases<PacketTraits::HasExpm1, Scalar, Packet>(expm1_fun());
|
||||
run_ieee_cases<PacketTraits::HasLog1p, Scalar, Packet>(log1p_fun());
|
||||
run_ieee_cases<PacketTraits::HasSin, Scalar, Packet>(sin_fun());
|
||||
run_ieee_cases<PacketTraits::HasCos, Scalar, Packet>(cos_fun());
|
||||
run_ieee_cases<PacketTraits::HasTan, Scalar, Packet>(tan_fun());
|
||||
run_ieee_cases<PacketTraits::HasASin, Scalar, Packet>(asin_fun());
|
||||
run_ieee_cases<PacketTraits::HasACos, Scalar, Packet>(acos_fun());
|
||||
run_ieee_cases<PacketTraits::HasATan, Scalar, Packet>(atan_fun());
|
||||
run_ieee_cases<PacketTraits::HasTanh, Scalar, Packet>(tanh_fun());
|
||||
run_ieee_cases<PacketTraits::HasATanh, Scalar, Packet>(atanh_fun());
|
||||
}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_notcomplex() {
|
||||
packetmath_ieee_special_values<Scalar, Packet>();
|
||||
|
||||
typedef internal::packet_traits<Scalar> PacketTraits;
|
||||
const int PacketSize = internal::unpacket_traits<Packet>::size;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user