From 61a866287673fbcad7cec37e649b20352afd052f Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen <4643818-rmlarsen1@users.noreply.gitlab.com> Date: Thu, 2 Apr 2026 11:29:25 -0700 Subject: [PATCH] Improve log1p accuracy and speed with direct range reduction libeigen/eigen!2378 Co-authored-by: Rasmus Munk Larsen --- Eigen/src/Core/arch/AVX/MathFunctions.h | 2 +- .../arch/Default/GenericPacketMathFunctions.h | 129 ++++++++++++++++++ .../Default/GenericPacketMathFunctionsFwd.h | 24 +++- 3 files changed, 152 insertions(+), 3 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 43fe36da5..357f3142e 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -60,7 +60,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4d ptan(cons EIGEN_GENERIC_PACKET_FUNCTION(atan, Packet4d) EIGEN_GENERIC_PACKET_FUNCTION(exp2, Packet4d) EIGEN_GENERIC_PACKET_FUNCTION(expm1, Packet4d) -EIGEN_GENERIC_PACKET_FUNCTION(log1p, Packet4d) +EIGEN_DOUBLE_PACKET_FUNCTION(log1p, Packet4d) // Notice that for newer processors, it is counterproductive to use Newton // iteration for square root. In particular, Skylake and Zen2 processors diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index b79971f39..f30c87edf 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -204,8 +204,137 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa return plog_impl_double(_x); } +/** \internal \returns log(1 + x) for single precision float. + Computes log(1+x) directly with inline range reduction, avoiding + the double rounding in the Kahan formula (which calls plog(1+x) + as a black box). The rounding error from forming u = fl(1+x) is + recovered as dx = x - (u - 1), and folded in as a first-order + correction dx/u after the polynomial evaluation. + */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x) { + const Packet one = pset1(1.0f); + const Packet cst_minus_inf = pset1frombits(static_cast(0xff800000u)); + const Packet cst_pos_inf = pset1frombits(static_cast(0x7f800000u)); + + // u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1). + Packet u = padd(one, x); + Packet dx = psub(x, psub(u, one)); + + // For |x| tiny enough that u rounds to 1, return x directly. + Packet small_mask = pcmp_eq(u, one); + // For u = +inf (x very large), return +inf. + Packet inf_mask = pcmp_eq(u, cst_pos_inf); + + // Inline the plog range reduction on u = 1 + x. + const Packet cst_cephes_SQRTHF = pset1(0.707106781186547524f); + Packet e; + Packet m = pfrexp(u, e); + Packet mask = pcmp_lt(m, cst_cephes_SQRTHF); + Packet tmp = pand(m, mask); + m = psub(m, one); + e = psub(e, pand(one, mask)); + m = padd(m, tmp); + + // Same rational polynomial as plog_float. + constexpr float alpha[] = {0.18256296349849254f, 1.0000000190281063f, 1.0000000190281136f}; + constexpr float beta[] = {0.049616247954120038f, 0.59923249590823520f, 1.4999999999999927f, 1.0f}; + Packet p = ppolevl::run(m, alpha); + p = pmul(m, p); + Packet q = ppolevl::run(m, beta); + Packet log_m = pdiv(p, q); + + // result = e * ln2 + log(m) + dx/u. + // The dx/u term corrects for the rounding error in u = fl(1+x). + const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); + Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u))); + + // Handle special cases. + Packet neg_mask = pcmp_lt(u, pzero(u)); + Packet zero_mask = pcmp_eq(x, pset1(-1.0f)); + result = pselect(small_mask, x, result); + result = pselect(inf_mask, cst_pos_inf, result); + result = pselect(zero_mask, cst_minus_inf, result); + result = por(neg_mask, result); // NaN for x < -1 + return result; +} + +/** \internal \returns log(1 + x) for double precision float. + Same direct approach as the float version. + */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) { + const Packet one = pset1(1.0); + const Packet cst_minus_inf = pset1frombits(static_cast(0xfff0000000000000ull)); + const Packet cst_pos_inf = pset1frombits(static_cast(0x7ff0000000000000ull)); + + Packet u = padd(one, x); + Packet dx = psub(x, psub(u, one)); + + Packet small_mask = pcmp_eq(u, one); + Packet inf_mask = pcmp_eq(u, cst_pos_inf); + + const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); + Packet e; + Packet m = pfrexp(u, e); + Packet mask = pcmp_lt(m, cst_cephes_SQRTHF); + Packet tmp = pand(m, mask); + m = psub(m, one); + e = psub(e, pand(one, mask)); + m = padd(m, tmp); + + // Same polynomial as plog_double. + const Packet cst_neg_half = pset1(-0.5); + const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); + const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); + const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); + const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); + const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); + const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); + const Packet cst_cephes_log_q0 = pset1(1.0); + const Packet cst_cephes_log_q1 = pset1(1.12873587189167450590E1); + const Packet cst_cephes_log_q2 = pset1(4.52279145837532221105E1); + const Packet cst_cephes_log_q3 = pset1(8.29875266912776603211E1); + const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); + const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); + + Packet m2 = pmul(m, m); + Packet m3 = pmul(m2, m); + + Packet y, y1, y_; + y = pmadd(cst_cephes_log_p0, m, cst_cephes_log_p1); + y1 = pmadd(cst_cephes_log_p3, m, cst_cephes_log_p4); + y = pmadd(y, m, cst_cephes_log_p2); + y1 = pmadd(y1, m, cst_cephes_log_p5); + y_ = pmadd(y, m3, y1); + + y = pmadd(cst_cephes_log_q0, m, cst_cephes_log_q1); + y1 = pmadd(cst_cephes_log_q3, m, cst_cephes_log_q4); + y = pmadd(y, m, cst_cephes_log_q2); + y1 = pmadd(y1, m, cst_cephes_log_q5); + y = pmadd(y, m3, y1); + + y_ = pmul(y_, m3); + Packet log_m = pdiv(y_, y); + log_m = pmadd(cst_neg_half, m2, log_m); + log_m = padd(m, log_m); + + // result = e * ln2 + log(m) + dx/u. + const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); + Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u))); + + Packet neg_mask = pcmp_lt(u, pzero(u)); + Packet zero_mask = pcmp_eq(x, pset1(-1.0)); + result = pselect(small_mask, x, result); + result = pselect(inf_mask, cst_pos_inf, result); + result = pselect(zero_mask, cst_minus_inf, result); + result = por(neg_mask, result); + return result; +} + /** \internal \returns log(1 + x) computed using W. Kahan's formula. See: http://www.plunk.org/~hatch/rightway.php + This is the generic fallback for types without a specialized implementation. */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x) { diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index aa3098a2b..4534c9d7e 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -82,6 +82,26 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x); +/** \internal \returns log(1+x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x); + +/** \internal \returns log(1+x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x); + +/** \internal \returns log(1+x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_float(const Packet& x) { + return generic_log1p_float(x); +} + +/** \internal \returns log(1+x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_double(const Packet& x) { + return generic_log1p_double(x); +} + /** \internal \returns exp(x)-1 */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_expm1(const Packet& x); @@ -264,7 +284,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a); EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \ - EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \ + EIGEN_FLOAT_PACKET_FUNCTION(log1p, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET) #define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \ @@ -284,7 +304,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a); EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \ - EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \ + EIGEN_DOUBLE_PACKET_FUNCTION(log1p, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET) // Macro to instantiate complex math function specializations (psqrt, plog, pexp)