From 7b5a8b6bc55151f1998dee63a46d745c4a635894 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 5 Jan 2022 23:40:31 +0000 Subject: [PATCH] Improve plog: 20% speedup for float + handle denormals --- .../arch/Default/GenericPacketMathFunctions.h | 60 ++++++------------- test/packetmath.cpp | 3 +- 2 files changed, 18 insertions(+), 45 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 6934e2a30..137daa815 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -170,33 +170,14 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet plog_impl_float(const Packet _x) { - Packet x = _x; - const Packet cst_1 = pset1(1.0f); - const Packet cst_neg_half = pset1(-0.5f); - // The smallest non denormalized float number. - const Packet cst_min_norm_pos = pset1frombits( 0x00800000u); const Packet cst_minus_inf = pset1frombits( 0xff800000u); const Packet cst_pos_inf = pset1frombits( 0x7f800000u); - // Polynomial coefficients. const Packet cst_cephes_SQRTHF = pset1(0.707106781186547524f); - const Packet cst_cephes_log_p0 = pset1(7.0376836292E-2f); - const Packet cst_cephes_log_p1 = pset1(-1.1514610310E-1f); - const Packet cst_cephes_log_p2 = pset1(1.1676998740E-1f); - const Packet cst_cephes_log_p3 = pset1(-1.2420140846E-1f); - const Packet cst_cephes_log_p4 = pset1(+1.4249322787E-1f); - const Packet cst_cephes_log_p5 = pset1(-1.6668057665E-1f); - const Packet cst_cephes_log_p6 = pset1(+2.0000714765E-1f); - const Packet cst_cephes_log_p7 = pset1(-2.4999993993E-1f); - const Packet cst_cephes_log_p8 = pset1(+3.3333331174E-1f); - - // Truncate input values to the minimum positive normal. - x = pmax(x, cst_min_norm_pos); - - Packet e; + Packet e, x; // extract significant in the range [0.5,1) and exponent - x = pfrexp(x,e); + x = pfrexp(_x,e); // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) // and shift by -1. The values are then centered around 0, which improves @@ -211,24 +192,22 @@ Packet plog_impl_float(const Packet _x) e = psub(e, pand(cst_1, mask)); x = padd(x, tmp); - Packet x2 = pmul(x, x); - Packet x3 = pmul(x2, x); + // Polynomial coefficients for rational (3,3) r(x) = p(x)/q(x) + // approximating log(1+x) on [sqrt(0.5)-1;sqrt(2)-1]. + const Packet cst_p1 = pset1(1.0000000190281136f); + const Packet cst_p2 = pset1(1.0000000190281063f); + const Packet cst_p3 = pset1(0.18256296349849254f); + const Packet cst_q1 = pset1(1.4999999999999927f); + const Packet cst_q2 = pset1(0.59923249590823520f); + const Packet cst_q3 = pset1(0.049616247954120038f); - // Evaluate the polynomial approximant of degree 8 in three parts, probably - // to improve instruction-level parallelism. - Packet y, y1, y2; - y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); - y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); - y2 = pmadd(cst_cephes_log_p6, x, cst_cephes_log_p7); - y = pmadd(y, x, cst_cephes_log_p2); - y1 = pmadd(y1, x, cst_cephes_log_p5); - y2 = pmadd(y2, x, cst_cephes_log_p8); - y = pmadd(y, x3, y1); - y = pmadd(y, x3, y2); - y = pmul(y, x3); - - y = pmadd(cst_neg_half, x2, y); - x = padd(x, y); + Packet p = pmadd(x, cst_p3, cst_p2); + p = pmadd(x, p, cst_p1); + p = pmul(x, p); + Packet q = pmadd(x, cst_q3, cst_q2); + q = pmadd(x, q, cst_q1); + q = pmadd(x, q, cst_1); + x = pdiv(p, q); // Add the logarithm of the exponent back to the result of the interpolation. if (base2) { @@ -284,8 +263,6 @@ Packet plog_impl_double(const Packet _x) const Packet cst_1 = pset1(1.0); const Packet cst_neg_half = pset1(-0.5); - // The smallest non denormalized double. - const Packet cst_min_norm_pos = pset1frombits( static_cast(0x0010000000000000ull)); const Packet cst_minus_inf = pset1frombits( static_cast(0xfff0000000000000ull)); const Packet cst_pos_inf = pset1frombits( static_cast(0x7ff0000000000000ull)); @@ -307,9 +284,6 @@ Packet plog_impl_double(const Packet _x) const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); - // Truncate input values to the minimum positive normal. - x = pmax(x, cst_min_norm_pos); - Packet e; // extract significant in the range [0.5,1) and exponent x = pfrexp(x,e); diff --git a/test/packetmath.cpp b/test/packetmath.cpp index fcdc2bb67..db1c9adcb 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -890,8 +890,7 @@ void packetmath_real() { data1[0] = std::numeric_limits::denorm_min(); data1[1] = -std::numeric_limits::denorm_min(); h.store(data2, internal::plog(h.load(data1))); - // TODO(rmlarsen): Re-enable. - // VERIFY_IS_EQUAL(std::log(std::numeric_limits::denorm_min()), data2[0]); + VERIFY_IS_APPROX(std::log(std::numeric_limits::denorm_min()), data2[0]); VERIFY((numext::isnan)(data2[1])); } #endif