diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 3d8a8cf42..d4aa3fbe8 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -141,6 +141,140 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_float(const Pac return plog_impl_float(_x); } +// ----------------------------------------------------------------------- +// Double logarithm: shared polynomial + two range-reduction backends +// ----------------------------------------------------------------------- + +// Cephes rational-polynomial approximation of log(1+f) for +// f in [sqrt(0.5)-1, sqrt(2)-1]. +// Evaluates x - 0.5*x^2 + x^3 * P(x)/Q(x) where P and Q are degree-5. +// See: http://www.netlib.org/cephes/ +template +EIGEN_STRONG_INLINE Packet plog_mantissa_double(const Packet x) { + const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); + const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); + const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); + const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); + const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); + const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); + // Q0 = 1.0; pmadd(1, x, q1) simplifies to padd(x, q1). + const Packet cst_cephes_log_q1 = pset1(1.12873587189167450590E1); + const Packet cst_cephes_log_q2 = pset1(4.52279145837532221105E1); + const Packet cst_cephes_log_q3 = pset1(8.29875266912776603211E1); + const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); + const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); + + Packet x2 = pmul(x, x); + Packet x3 = pmul(x2, x); + + // Evaluate P and Q simultaneously for better ILP. + Packet y, y1, y_; + y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); + y = pmadd(y, x, cst_cephes_log_p2); + y1 = pmadd(y1, x, cst_cephes_log_p5); + y_ = pmadd(y, x3, y1); + + y = padd(x, cst_cephes_log_q1); + y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4); + y = pmadd(y, x, cst_cephes_log_q2); + y1 = pmadd(y1, x, cst_cephes_log_q5); + y = pmadd(y, x3, y1); + + y_ = pmul(y_, x3); + y = pdiv(y_, y); + y = pnmadd(pset1(0.5), x2, y); + return padd(x, y); +} + +// Detect whether unpacket_traits::integer_packet is defined. +template +struct packet_has_integer_packet : std::false_type {}; +template +struct packet_has_integer_packet::integer_packet>> : std::true_type {}; + +// Dispatch struct for double-precision range reduction. +// Primary template: pfrexp-based fallback (used when integer_packet is absent). +template +struct plog_range_reduce_double { + EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) { + const Packet one = pset1(1.0); + const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); + // pfrexp: f in [0.5, 1), e = unbiased exponent as double. + f = pfrexp(v, e); + // Shift [0.5,1) -> [sqrt(0.5)-1, sqrt(2)-1] with exponent correction: + // if f < sqrt(0.5): f = f + f - 1, e -= 1 (giving f in [0, sqrt(2)-1)) + // else: f = f - 1 (giving f in [sqrt(0.5)-1, 0)) + Packet mask = pcmp_lt(f, cst_cephes_SQRTHF); + Packet tmp = pand(f, mask); + f = psub(f, one); + e = psub(e, pand(one, mask)); + f = padd(f, tmp); + } +}; + +// Specialisation: fast integer-bit-manipulation path (musl-inspired). +// Requires unpacket_traits::integer_packet to be a 64-bit integer packet. +template +struct plog_range_reduce_double { + EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) { + typedef typename unpacket_traits::integer_packet PacketI; + // 2^-1022: smallest positive normal double. + const PacketI cst_min_normal = pset1(static_cast(0x0010000000000000LL)); + // Lower 52-bit mask (IEEE mantissa field). + const PacketI cst_mant_mask = pset1(static_cast(0x000FFFFFFFFFFFFFLL)); + // Offset = 1.0_bits - sqrt(0.5)_bits. Adding this to the integer + // representation shifts the exponent field so that the [sqrt(0.5), sqrt(2)) + // half-octave boundary falls on an exact biased-exponent boundary, letting + // us extract e with a single right shift. The constant is: + // 0x3FF0000000000000 - 0x3FE6A09E667F3BCD = 0x00095F619980C433 + const PacketI cst_sqrt_half_offset = + pset1(static_cast(0x3FF0000000000000LL - 0x3FE6A09E667F3BCDLL)); + // IEEE double exponent bias (1023). + const PacketI cst_exp_bias = pset1(static_cast(1023)); + // sqrt(0.5) IEEE bits — used to reconstruct f from biased mantissa. + const PacketI cst_half_mant = pset1(static_cast(0x3FE6A09E667F3BCDLL)); + + // Reinterpret v as a 64-bit integer vector. + PacketI vi = preinterpret(v); + + // Normalise denormals: multiply by 2^52 and correct the exponent by -52. + PacketI is_denormal = pcmp_lt(vi, cst_min_normal); + // 2^52 via bit pattern: biased exponent = 52 + 1023 = 0x433, mantissa = 0. + Packet v_norm = pmul(v, pset1frombits(static_cast(int64_t(52 + 0x3ff) << 52))); + vi = pselect(is_denormal, preinterpret(v_norm), vi); + PacketI denorm_adj = pand(is_denormal, pset1(static_cast(52))); + + // Bias the integer representation so the exponent field directly encodes + // the half-octave index. + PacketI vi_biased = padd(vi, cst_sqrt_half_offset); + // Extract unbiased exponent: shift out mantissa bits, subtract IEEE bias + // and denormal adjustment. + PacketI e_int = psub(psub(plogical_shift_right<52>(vi_biased), cst_exp_bias), denorm_adj); + // Convert integer exponent to floating-point. + e = pcast(e_int); + + // Reconstruct mantissa in [sqrt(0.5), sqrt(2)) via integer arithmetic. + // The integer addition of the masked mantissa bits and the sqrt(0.5) bit + // pattern carries into the exponent field, yielding a value in that range. + // Then subtract 1 to centre on 0: f in [sqrt(0.5)-1, sqrt(2)-1]. + f = psub(preinterpret(padd(pand(vi_biased, cst_mant_mask), cst_half_mant)), pset1(1.0)); + } +}; + +// Core range reduction and polynomial for double logarithm. +// Input: v > 0 (zero / negative / inf / nan are handled by the caller). +// Output: log_mantissa ≈ log(mantissa of v in [sqrt(0.5), sqrt(2))), +// e = unbiased exponent of v as a double. +// Selects the fast integer path when integer_packet is available, otherwise +// falls back to pfrexp. +template +EIGEN_STRONG_INLINE void plog_core_double(const Packet v, Packet& log_mantissa, Packet& e) { + Packet f; + plog_range_reduce_double::value>::run(v, f, e); + log_mantissa = plog_mantissa_double(f); +} + /* Returns the base e (2.718...) or base 2 logarithm of x. * The argument is separated into its exponent and fractional parts. * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)], @@ -152,87 +286,29 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_float(const Pac */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_double(const Packet _x) { - Packet x = _x; - - const Packet cst_1 = pset1(1.0); - const Packet cst_neg_half = pset1(-0.5); const Packet cst_minus_inf = pset1frombits(static_cast(0xfff0000000000000ull)); const Packet cst_pos_inf = pset1frombits(static_cast(0x7ff0000000000000ull)); - // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x) - // 1/sqrt(2) <= x < sqrt(2) - const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); - const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); - const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); - const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); - const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); - const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); - const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); + Packet log_mantissa, e; + plog_core_double(_x, log_mantissa, e); - const Packet cst_cephes_log_q0 = pset1(1.0); - const Packet cst_cephes_log_q1 = pset1(1.12873587189167450590E1); - const Packet cst_cephes_log_q2 = pset1(4.52279145837532221105E1); - const Packet cst_cephes_log_q3 = pset1(8.29875266912776603211E1); - const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); - const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); - - Packet e; - // extract significant in the range [0.5,1) and exponent - x = pfrexp(x, e); - - // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) - // and shift by -1. The values are then centered around 0, which improves - // the stability of the polynomial evaluation. - // if( x < SQRTHF ) { - // e -= 1; - // x = x + x - 1.0; - // } else { x = x - 1.0; } - Packet mask = pcmp_lt(x, cst_cephes_SQRTHF); - Packet tmp = pand(x, mask); - x = psub(x, cst_1); - e = psub(e, pand(cst_1, mask)); - x = padd(x, tmp); - - Packet x2 = pmul(x, x); - Packet x3 = pmul(x2, x); - - // Evaluate the polynomial in factored form for better instruction-level parallelism. - // y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) ); - Packet y, y1, y_; - y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); - y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); - y = pmadd(y, x, cst_cephes_log_p2); - y1 = pmadd(y1, x, cst_cephes_log_p5); - y_ = pmadd(y, x3, y1); - - y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1); - y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4); - y = pmadd(y, x, cst_cephes_log_q2); - y1 = pmadd(y1, x, cst_cephes_log_q5); - y = pmadd(y, x3, y1); - - y_ = pmul(y_, x3); - y = pdiv(y_, y); - - y = pmadd(cst_neg_half, x2, y); - x = padd(x, y); - - // Add the logarithm of the exponent back to the result of the interpolation. + // Combine: log(x) = e * ln2 + log(mantissa), or log2(x) = log(mantissa)*log2e + e. + Packet x; if (base2) { const Packet cst_log2e = pset1(static_cast(EIGEN_LOG2E)); - x = pmadd(x, cst_log2e, e); + x = pmadd(log_mantissa, cst_log2e, e); } else { const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); - x = pmadd(e, cst_ln2, x); + x = pmadd(e, cst_ln2, log_mantissa); } Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x)); Packet iszero_mask = pcmp_eq(_x, pzero(_x)); Packet pos_inf_mask = pcmp_eq(_x, cst_pos_inf); - // Filter out invalid inputs, i.e.: - // - negative arg will be NAN - // - 0 will be -INF - // - +INF will be +INF + // Filter out invalid inputs: + // - negative arg → NAN + // - 0 → -INF + // - +INF → +INF return pselect(iszero_mask, cst_minus_inf, por(pselect(pos_inf_mask, cst_pos_inf, x), invalid_mask)); } @@ -286,8 +362,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(c return result; } -/** \internal \returns log(1 + x) for double precision float. - Same direct approach as the float version. +/** \internal \returns log(1 + x) for double precision. + Computes log(1+x) using plog_core_double for the core range reduction and + polynomial evaluation. The rounding error from forming u = fl(1+x) is + recovered as dx = x - (u - 1) and folded in as a first-order correction + dx/u after the polynomial evaluation. */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) { @@ -295,67 +374,31 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double( const Packet cst_minus_inf = pset1frombits(static_cast(0xfff0000000000000ull)); const Packet cst_pos_inf = pset1frombits(static_cast(0x7ff0000000000000ull)); + // u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1). Packet u = padd(one, x); Packet dx = psub(x, psub(u, one)); + // For |x| tiny enough that u rounds to 1, return x directly. Packet small_mask = pcmp_eq(u, one); + // For u = +inf (x very large), return +inf. Packet inf_mask = pcmp_eq(u, cst_pos_inf); - const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); - Packet e; - Packet m = pfrexp(u, e); - Packet mask = pcmp_lt(m, cst_cephes_SQRTHF); - Packet tmp = pand(m, mask); - m = psub(m, one); - e = psub(e, pand(one, mask)); - m = padd(m, tmp); + // Core range reduction and polynomial on u. + Packet log_u, e; + plog_core_double(u, log_u, e); - // Same polynomial as plog_double. - const Packet cst_neg_half = pset1(-0.5); - const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); - const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); - const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); - const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); - const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); - const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); - const Packet cst_cephes_log_q0 = pset1(1.0); - const Packet cst_cephes_log_q1 = pset1(1.12873587189167450590E1); - const Packet cst_cephes_log_q2 = pset1(4.52279145837532221105E1); - const Packet cst_cephes_log_q3 = pset1(8.29875266912776603211E1); - const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); - const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); - - Packet m2 = pmul(m, m); - Packet m3 = pmul(m2, m); - - Packet y, y1, y_; - y = pmadd(cst_cephes_log_p0, m, cst_cephes_log_p1); - y1 = pmadd(cst_cephes_log_p3, m, cst_cephes_log_p4); - y = pmadd(y, m, cst_cephes_log_p2); - y1 = pmadd(y1, m, cst_cephes_log_p5); - y_ = pmadd(y, m3, y1); - - y = pmadd(cst_cephes_log_q0, m, cst_cephes_log_q1); - y1 = pmadd(cst_cephes_log_q3, m, cst_cephes_log_q4); - y = pmadd(y, m, cst_cephes_log_q2); - y1 = pmadd(y1, m, cst_cephes_log_q5); - y = pmadd(y, m3, y1); - - y_ = pmul(y_, m3); - Packet log_m = pdiv(y_, y); - log_m = pmadd(cst_neg_half, m2, log_m); - log_m = padd(m, log_m); - - // result = e * ln2 + log(m) + dx/u. + // result = e * ln2 + log(u) + dx/u. + // The dx/u term corrects for the rounding error in u = fl(1+x). const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); - Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u))); + Packet result = pmadd(e, cst_ln2, padd(log_u, pdiv(dx, u))); + // Handle special cases. Packet neg_mask = pcmp_lt(u, pzero(u)); Packet zero_mask = pcmp_eq(x, pset1(-1.0)); result = pselect(small_mask, x, result); result = pselect(inf_mask, cst_pos_inf, result); result = pselect(zero_mask, cst_minus_inf, result); - result = por(neg_mask, result); + result = por(neg_mask, result); // NaN for x < -1 return result; }