Fix unary pow error handling and test

This commit is contained in:
Charles Schlosser
2023-06-06 18:46:55 +00:00
committed by Rasmus Munk Larsen
parent 7ac8897431
commit b7151ffaab
2 changed files with 229 additions and 130 deletions

View File

@@ -191,53 +191,97 @@ void unary_ops_test() {
*/
}
template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
struct ref_pow {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
}
};
template <typename Scalar>
void pow_scalar_exponent_test() {
using Int_t = typename internal::make_integer<Scalar>::type;
const Scalar tol = test_precision<Scalar>();
template <typename Base, typename Exponent>
struct ref_pow<Base, Exponent, true> {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, exponent));
}
};
std::vector<Scalar> abs_vals = special_values<Scalar>();
const Index num_vals = (Index)abs_vals.size();
Map<Array<Scalar, Dynamic, 1>> bases(abs_vals.data(), num_vals);
template <typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
struct pow_helper {
static bool is_integer_impl(const Exponent& exp) { return (numext::isfinite)(exp) && exp == numext::floor(exp); }
static bool is_odd_impl(const Exponent& exp) {
Exponent exp_div_2 = exp / Exponent(2);
Exponent floor_exp_div_2 = numext::floor(exp_div_2);
return exp_div_2 != floor_exp_div_2;
}
};
template <typename Exponent>
struct pow_helper<Exponent, true> {
static bool is_integer_impl(const Exponent&) { return true; }
static bool is_odd_impl(const Exponent& exp) { return exp % 2 != 0; }
};
template <typename Exponent>
bool is_integer(const Exponent& exp) {
return pow_helper<Exponent>::is_integer_impl(exp);
}
template <typename Exponent>
bool is_odd(const Exponent& exp) {
return pow_helper<Exponent>::is_odd_impl(exp);
}
template <typename Base, typename Exponent>
void float_pow_test_impl() {
const Base tol = test_precision<Base>();
std::vector<Base> abs_base_vals = special_values<Base>();
std::vector<Exponent> abs_exponent_vals = special_values<Exponent>();
for (int i = 0; i < 100; i++) {
abs_base_vals.push_back(internal::random<Base>(Base(0), Base(10)));
abs_exponent_vals.push_back(internal::random<Exponent>(Exponent(0), Exponent(10)));
}
const Index num_repeats = internal::packet_traits<Base>::size + 1;
ArrayX<Base> bases(num_repeats), eigenPow(num_repeats);
bool all_pass = true;
for (Scalar abs_exponent : abs_vals) {
for (Scalar exponent : {-abs_exponent, abs_exponent}) {
// test integer exponent code path
bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) &&
(numext::abs(exponent) < static_cast<Scalar>(NumTraits<Int_t>::highest()));
if (exponent_is_integer) {
Int_t exponent_as_int = static_cast<Int_t>(exponent);
Array<Scalar, Dynamic, 1> eigenPow = bases.pow(exponent_as_int);
for (Index j = 0; j < num_vals; j++) {
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
}
}
} else {
// test floating point exponent code path
Array<Scalar, Dynamic, 1> eigenPow = bases.pow(exponent);
for (Index j = 0; j < num_vals; j++) {
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
for (Base abs_base : abs_base_vals)
for (Base base : {negative_or_zero(abs_base), abs_base}) {
bases.setConstant(base);
for (Exponent abs_exponent : abs_exponent_vals) {
for (Exponent exponent : {negative_or_zero(abs_exponent), abs_exponent}) {
eigenPow = bases.pow(exponent);
for (Index j = 0; j < num_repeats; j++) {
Base e = ref_pow<Base, Exponent>::run(bases(j), exponent);
if (is_integer(exponent)) {
// std::pow may return an incorrect result for a very large integral exponent
// if base is negative and the exponent is odd, then the result must be negative
// if std::pow returns otherwise, flip the sign
bool exp_is_odd = is_odd(exponent);
bool base_is_neg = !(numext::isnan)(base) && (bool)numext::signbit(base);
bool result_is_neg = exp_is_odd && base_is_neg;
bool ref_is_neg = !(numext::isnan)(e) && (bool)numext::signbit(e);
bool flip_sign = result_is_neg != ref_is_neg;
if (flip_sign) e = -e;
}
Base a = eigenPow(j);
#ifdef EIGEN_COMP_MSVC
// Work around MSVC return value on underflow.
// if std::pow returns 0 and Eigen returns a denormalized value, then skip the test
int fpclass = std::fpclassify(a);
if (e == Base(0) && fpclass == FP_SUBNORMAL) continue;
#endif
bool both_nan = (numext::isnan)(a) && (numext::isnan)(e);
bool exact_or_approx = (a == e) || internal::isApprox(a, e, tol);
bool same_sign = (bool)numext::signbit(e) == (bool)numext::signbit(a);
bool success = both_nan || (exact_or_approx && same_sign);
all_pass &= success;
if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
}
}
}
}
}
}
VERIFY(all_pass);
}
@@ -259,24 +303,9 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) {
}
}
template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
struct ref_pow {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
}
};
template <typename Base, typename Exponent>
struct ref_pow<Base, Exponent, true> {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, exponent));
}
};
template <typename Base, typename Exponent>
void test_exponent(Exponent exponent) {
EIGEN_STATIC_ASSERT(NumTraits<Base>::IsInteger,THIS TEST IS ONLY INTENDED FOR BASE INTEGER TYPES)
const Base max_abs_bases = static_cast<Base>(10000);
// avoid integer overflow in Base type
Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
@@ -300,10 +329,10 @@ void test_exponent(Exponent exponent) {
for (Base a : y) {
Base e = ref_pow<Base, Exponent>::run(base, exponent);
bool pass = (a == e);
if (!NumTraits<Base>::IsInteger) {
pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) ||
((numext::isnan)(a) && (numext::isnan)(e)));
}
//if (!NumTraits<Base>::IsInteger) {
// pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) ||
// ((numext::isnan)(a) && (numext::isnan)(e)));
//}
all_pass &= pass;
if (!pass) {
std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
@@ -314,7 +343,7 @@ void test_exponent(Exponent exponent) {
}
template <typename Base, typename Exponent>
void unary_pow_test() {
void int_pow_test_impl() {
Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
Exponent min_exponent = negative_or_zero(max_exponent);
@@ -323,21 +352,26 @@ void unary_pow_test() {
}
}
void float_pow_test() {
float_pow_test_impl<float, float>();
float_pow_test_impl<double, double>();
}
void mixed_pow_test() {
// The following cases will test promoting a smaller exponent type
// to a wider base type.
unary_pow_test<double, int>();
unary_pow_test<double, float>();
unary_pow_test<float, half>();
unary_pow_test<double, half>();
unary_pow_test<float, bfloat16>();
unary_pow_test<double, bfloat16>();
float_pow_test_impl<double, int>();
float_pow_test_impl<double, float>();
float_pow_test_impl<float, half>();
float_pow_test_impl<double, half>();
float_pow_test_impl<float, bfloat16>();
float_pow_test_impl<double, bfloat16>();
// Although in the following cases the exponent cannot be represented exactly
// in the base type, we do not perform a conversion, but implement
// the operation using repeated squaring.
unary_pow_test<float, int>();
unary_pow_test<double, long long>();
float_pow_test_impl<float, int>();
float_pow_test_impl<double, long long>();
// The following cases will test promoting a wider exponent type
// to a narrower base type. This should compile but would generate a
@@ -346,20 +380,20 @@ void mixed_pow_test() {
}
void int_pow_test() {
unary_pow_test<int, int>();
unary_pow_test<unsigned int, unsigned int>();
unary_pow_test<long long, long long>();
unary_pow_test<unsigned long long, unsigned long long>();
int_pow_test_impl<int, int>();
int_pow_test_impl<unsigned int, unsigned int>();
int_pow_test_impl<long long, long long>();
int_pow_test_impl<unsigned long long, unsigned long long>();
// Although in the following cases the exponent cannot be represented exactly
// in the base type, we do not perform a conversion, but implement the
// operation using repeated squaring.
unary_pow_test<long long, int>();
unary_pow_test<int, unsigned int>();
unary_pow_test<unsigned int, int>();
unary_pow_test<long long, unsigned long long>();
unary_pow_test<unsigned long long, long long>();
unary_pow_test<long long, int>();
int_pow_test_impl<long long, int>();
int_pow_test_impl<int, unsigned int>();
int_pow_test_impl<unsigned int, int>();
int_pow_test_impl<long long, unsigned long long>();
int_pow_test_impl<unsigned long long, long long>();
int_pow_test_impl<long long, int>();
}
namespace Eigen {
@@ -849,7 +883,6 @@ template<typename ArrayType> void array_real(const ArrayType& m)
// Test pow and atan2 on special IEEE values.
unary_ops_test<Scalar>();
binary_ops_test<Scalar>();
pow_scalar_exponent_test<Scalar>();
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));
@@ -1223,6 +1256,7 @@ EIGEN_DECLARE_TEST(array_cwise)
}
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_5( float_pow_test() );
CALL_SUBTEST_6( int_pow_test() );
CALL_SUBTEST_7( mixed_pow_test() );
CALL_SUBTEST_8( signbit_tests() );