New accurate algorithm for pow(x,y). This version is accurate to 1.4 ulps for float, while still being 10x faster than std::pow for AVX512. A future change will introduce a specialization for double.

This commit is contained in:
Rasmus Munk Larsen
2021-02-17 02:50:32 +00:00
parent 7ff0b7a980
commit be0574e215
2 changed files with 402 additions and 144 deletions

View File

@@ -14,6 +14,7 @@
template<typename Scalar>
void pow_test() {
const Scalar zero = Scalar(0);
const Scalar eps = std::numeric_limits<Scalar>::epsilon();
const Scalar one = Scalar(1);
const Scalar two = Scalar(2);
const Scalar three = Scalar(3);
@@ -21,20 +22,25 @@ void pow_test() {
const Scalar sqrt2 = Scalar(std::sqrt(2));
const Scalar inf = std::numeric_limits<Scalar>::infinity();
const Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
const Scalar denorm_min = std::numeric_limits<Scalar>::denorm_min();
const Scalar min = (std::numeric_limits<Scalar>::min)();
const Scalar max = (std::numeric_limits<Scalar>::max)();
const Scalar max_exp = (static_cast<Scalar>(std::numeric_limits<Scalar>::max_exponent) * Scalar(EIGEN_LN2)) / eps;
const static Scalar abs_vals[] = {zero,
denorm_min,
min,
eps,
sqrt_half,
one,
sqrt2,
two,
three,
min,
max_exp,
max,
inf,
nan};
const int abs_cases = 10;
const int abs_cases = 13;
const int num_cases = 2*abs_cases * 2*abs_cases;
// Repeat the same value to make sure we hit the vectorized path.
const int num_repeats = 32;
@@ -64,10 +70,7 @@ void pow_test() {
bool all_pass = true;
for (int i = 0; i < 1; ++i) {
for (int j = 0; j < num_cases; ++j) {
// TODO(rmlarsen): Skip tests that trigger a known bug in pldexp for now.
if (std::abs(x(i,j)) == max || std::abs(x(i,j)) == min) continue;
Scalar e = numext::pow(x(i,j), y(i,j));
Scalar e = static_cast<Scalar>(std::pow(x(i,j), y(i,j)));
Scalar a = actual(i, j);
bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e));
all_pass &= !fail;
@@ -79,7 +82,6 @@ void pow_test() {
VERIFY(all_pass);
}
template<typename ArrayType> void array(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;