Vectorize atanh & add a missing definition and unit test for atan.

This commit is contained in:
Rasmus Munk Larsen
2023-02-21 03:14:05 +00:00
parent 049a144798
commit ce62177b5b
17 changed files with 179 additions and 25 deletions

View File

@@ -93,27 +93,91 @@ void binary_op_test(std::string name, Fn fun, RefFn ref) {
VERIFY(all_pass);
}
#define BINARY_FUNCTOR_TEST_ARGS(fun) #fun, \
[](const auto& x, const auto& y) { return (Eigen::fun)(x, y); }, \
[](const auto& x, const auto& y) { return (std::fun)(x, y); }
template <typename Scalar>
void binary_ops_test() {
binary_op_test<Scalar>("pow",
[](const auto& x, const auto& y) { return Eigen::pow(x, y); },
[](const auto& x, const auto& y) { return std::pow(x, y); });
binary_op_test<Scalar>("atan2",
[](const auto& x, const auto& y) { return Eigen::atan2(x, y); },
[](const auto& x, const auto& y) {
auto t = std::atan2(x, y);
#if EIGEN_COMP_MSVC
// Work around MSVC return value on underflow.
// |atan(y/x)| is bounded above by |y/x|, so on underflow return y/x according to POSIX spec.
// MSVC otherwise returns denorm_min.
if (EIGEN_PREDICT_FALSE(std::abs(t) == std::numeric_limits<decltype(t)>::denorm_min())) {
return x/y;
}
binary_op_test<Scalar>(BINARY_FUNCTOR_TEST_ARGS(pow));
#ifndef EIGEN_COMP_MSVC
binary_op_test<Scalar>(BINARY_FUNCTOR_TEST_ARGS(atan2));
#else
binary_op_test<Scalar>(
"atan2", [](const auto& x, const auto& y) { return Eigen::atan2(x, y); },
[](Scalar x, Scalar y) {
auto t = Scalar(std::atan2(x, y));
// Work around MSVC return value on underflow.
// |atan(y/x)| is bounded above by |y/x|, so on underflow return y/x according to POSIX spec.
// MSVC otherwise returns denorm_min.
if (EIGEN_PREDICT_FALSE(std::abs(t) == std::numeric_limits<decltype(t)>::denorm_min())) {
return x / y;
}
return t;
});
#endif
return t;
});
}
template <typename Scalar, typename Fn, typename RefFn>
void unary_op_test(std::string name, Fn fun, RefFn ref) {
const Scalar tol = test_precision<Scalar>();
auto values = special_values<Scalar>();
Map<Array<Scalar, Dynamic, 1>> x(values.data(), values.size());
Array<Scalar, Dynamic, Dynamic> actual = fun(x);
bool all_pass = true;
for (Index i = 0; i < x.size(); ++i) {
Scalar e = static_cast<Scalar>(ref(x(i)));
Scalar a = actual(i);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << name << "(" << x(i) << ") = " << a << " != " << e << std::endl;
}
}
VERIFY(all_pass);
}
#define UNARY_FUNCTOR_TEST_ARGS(fun) #fun, \
[](const auto& x) { return (Eigen::fun)(x); }, \
[](const auto& x) { return (std::fun)(x); }
template <typename Scalar>
void unary_ops_test() {
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sqrt));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(exp));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(log));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sin));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(cos));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(tan));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(asin));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(acos));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(atan));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sinh));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(cosh));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(tanh));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(asinh));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(acosh));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(atanh));
/* FIXME: Enable when the behavior of rsqrt on denormals for half and double is fixed.
unary_op_test<Scalar>("rsqrt",
[](const auto& x) { return Eigen::rsqrt(x); },
[](Scalar x) {
if (x >= 0 && x < (std::numeric_limits<Scalar>::min)()) {
// rsqrt return +inf for positive subnormals.
return NumTraits<Scalar>::infinity();
} else {
return Scalar(std::sqrt(Scalar(1)/x));
}
});
*/
}
template <typename Scalar>
void pow_scalar_exponent_test() {
using Int_t = typename internal::make_integer<Scalar>::type;
@@ -630,6 +694,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1)));
VERIFY_IS_APPROX(m1.cosh().acosh(), acosh(cosh(m1)));
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
VERIFY_IS_APPROX(m1.logistic(), logistic(m1));
VERIFY_IS_APPROX(m1.arg(), arg(m1));
@@ -722,6 +787,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
// Test pow and atan2 on special IEEE values.
unary_ops_test<Scalar>();
binary_ops_test<Scalar>();
pow_scalar_exponent_test<Scalar>();