Speed up plog_float by 1.6x with improved accuracy

libeigen/eigen!2382

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-04-03 13:45:01 -07:00
parent ebae0c7c10
commit a91913e961

View File

@@ -30,62 +30,104 @@ namespace internal {
// Exponential and Logarithmic Functions
//----------------------------------------------------------------------
// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C = log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
// be easily approximated by a polynomial centered on m=1 for stability.
// TODO(gonnet): Further reduce the interval allowing for lower-degree
// polynomial interpolants -> ... -> profit!
// Shared kernel of the single-precision logarithm: range reduction followed
// by polynomial evaluation.
//
// The input v is expected to be a positive float (denormals are accepted).
// It is split as v = 2^e * (1 + f), where f lies in
// [sqrt(0.5) - 1, sqrt(2) - 1], and log(1 + f) is approximated by
// f + f^2 * P(f) with P a degree-7 minimax polynomial.
//
// Outputs:
//   log_mantissa - approximation of log(1 + f), i.e. the log of the reduced
//                  mantissa.
//   e            - the exponent of the decomposition, converted to float.
// The caller is responsible for recombining them (e * ln2 + log_mantissa for
// the natural log, log_mantissa * log2(e) + e for log2) and for handling any
// special inputs; this routine does no filtering of its own.
//
// The reduction works directly on the integer bit pattern (musl-inspired)
// rather than going through the heavier pfrexp_generic, which the original
// author reports saves about 12 operations. The polynomial comes from
// Sollya's fpminimax and is reported to give faithfully rounded results
// (at most 1 ULP of error for log).
template <typename Packet>
EIGEN_STRONG_INLINE void plog_core_float(const Packet v, Packet& log_mantissa, Packet& e) {
  using IntPacket = typename unpacket_traits<Packet>::integer_packet;

  // Step 1: denormal handling. Lanes whose bit pattern is below the smallest
  // normal float (0x00800000) are rescaled by 2^23 so they have a regular
  // exponent field; the 23 is subtracted from the exponent again below.
  const IntPacket raw_bits = preinterpret<IntPacket>(v);
  const IntPacket below_min_normal = pcmp_lt(raw_bits, pset1<IntPacket>(0x00800000));
  const Packet rescaled = pmul(v, pset1<Packet>(8388608.0f));  // v * 2^23
  const IntPacket bits = pselect(below_min_normal, preinterpret<IntPacket>(rescaled), raw_bits);
  const IntPacket exponent_fixup = pand(below_min_normal, pset1<IntPacket>(23));

  // Step 2: exponent extraction. Adding 0x3f800000 - 0x3f3504f3 = 0x004afb0d
  // (0x3f3504f3 ~ sqrt(0.5)) to the bit pattern makes inputs near 1.0 land on
  // exponent 0 and pushes mantissas below sqrt(0.5) into the previous
  // exponent, so a plain shift yields the exponent for a mantissa in
  // [sqrt(0.5), sqrt(2)).
  // NOTE(review): for bit patterns >= 0x7fb504f3 (NaN payloads) this sum
  // exceeds INT32_MAX; presumably integer-packet addition wraps - confirm.
  const IntPacket offset_bits = padd(bits, pset1<IntPacket>(0x004afb0d));
  const IntPacket biased_exponent = plogical_shift_right<23>(offset_bits);
  const IntPacket unbiased_exponent = psub(biased_exponent, pset1<IntPacket>(0x7f));  // remove bias 127
  e = pcast<IntPacket, Packet>(psub(unbiased_exponent, exponent_fixup));

  // Step 3: mantissa reconstruction. The low 23 bits of the offset pattern
  // are added back onto the bit pattern of sqrt(0.5); the carry into the
  // exponent field produces a float in [sqrt(0.5), 1) or [1, sqrt(2)).
  // Subtracting 1 centres the argument on zero.
  const IntPacket fraction_bits = pand(offset_bits, pset1<IntPacket>(0x007fffff));
  const IntPacket mantissa_bits = padd(fraction_bits, pset1<IntPacket>(0x3f3504f3));
  const Packet f = psub(preinterpret<Packet>(mantissa_bits), pset1<Packet>(1.0f));

  // Step 4: polynomial. P approximates g(f) = (log(1+f) - f) / f^2 on
  // [sqrt(0.5)-1, sqrt(2)-1]; generated with Sollya:
  //   fpminimax(g, 7, [|single...|], [lo;hi])
  // Stated approximation error: max |log(1+f) - (f + f^2*P(f))| < 2.04e-8.
  // ppolevl expects the highest-degree coefficient first.
  constexpr float kLogPoly[] = {
      8.8758550584316254e-02f,   // f^7
      -1.4199858903884888e-01f,  // f^6
      1.4824025332927704e-01f,   // f^5
      -1.6583317518234253e-01f,  // f^4
      1.9972395896911621e-01f,   // f^3
      -2.5001299381256104e-01f,  // f^2
      3.3333668112754822e-01f,   // f^1
      -4.9999997019767761e-01f,  // f^0
  };
  const Packet f_squared = pmul(f, f);
  const Packet poly = ppolevl<Packet, 7>::run(f, kLogPoly);

  // log(1+f) ~= f + f^2 * P(f), formed with a single multiply-add.
  log_mantissa = pmadd(poly, f_squared, f);
}
// Natural or base-2 logarithm for float packets.
//
// Computes log(x) as e*C + log(m), where x = 2^e * m with m in [sqrt(1/2), sqrt(2))
// and C = ln(2) for natural log, C = 1 for log2. The decomposition and the
// approximation of log(m) are delegated to plog_core_float; this function only
// recombines the two parts and patches up the special inputs.
//
// Template parameters:
//   Packet - float packet type.
//   base2  - true to return log2(_x), false to return ln(_x).
//
// Special cases (applied after the main computation):
//   - lanes flagged by pcmp_lt_or_nan(_x, 0) (negative arg) -> NAN
//   - 0                                                     -> -INF
//   - +INF                                                  -> +INF
template <typename Packet, bool base2>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_float(const Packet _x) {
  // Range reduction + polynomial: log_mantissa ~= log(m), e = exponent as float.
  Packet log_mantissa, e;
  plog_core_float(_x, log_mantissa, e);

  // Add the logarithm of the exponent back to the result.
  Packet x;
  if (base2) {
    // log2(x) = log(m) * log2(e) + e
    const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
    x = pmadd(log_mantissa, cst_log2e, e);
  } else {
    // ln(x) = e * ln(2) + log(m)
    const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
    x = pmadd(e, cst_ln2, log_mantissa);
  }

  // Filter out invalid inputs:
  //   - negative arg → NAN
  //   - 0 → -INF
  //   - +INF → +INF
  const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0xff800000u));
  const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0x7f800000u));
  Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
  Packet iszero_mask = pcmp_eq(_x, pzero(_x));
  Packet pos_inf_mask = pcmp_eq(_x, cst_pos_inf);
  // invalid_mask is OR-ed into the value so flagged lanes become NaN
  // (presumably the mask is all-ones bits in those lanes - confirm); the zero
  // test is applied last so that it takes precedence.
  return pselect(iszero_mask, cst_minus_inf, por(pselect(pos_inf_mask, cst_pos_inf, x), invalid_mask));
}
@@ -205,10 +247,9 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa
}
/** \internal \returns log(1 + x) for single precision float.
Computes log(1+x) using plog_core_float for the core range reduction
and polynomial evaluation, avoiding the double rounding in the Kahan
formula (which calls plog(1+x) as a black box). The rounding error from
forming u = fl(1+x) is recovered as dx = x - (u - 1), and folded in as a
first-order correction dx/u after the polynomial evaluation.
*/
template <typename Packet>
@@ -226,28 +267,14 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(c
// For u = +inf (x very large), return +inf.
Packet inf_mask = pcmp_eq(u, cst_pos_inf);
// Inline the plog range reduction on u = 1 + x.
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
Packet e;
Packet m = pfrexp(u, e);
Packet mask = pcmp_lt(m, cst_cephes_SQRTHF);
Packet tmp = pand(m, mask);
m = psub(m, one);
e = psub(e, pand(one, mask));
m = padd(m, tmp);
// Core range reduction and polynomial on u.
Packet log_u, e;
plog_core_float(u, log_u, e);
// Same rational polynomial as plog_float.
constexpr float alpha[] = {0.18256296349849254f, 1.0000000190281063f, 1.0000000190281136f};
constexpr float beta[] = {0.049616247954120038f, 0.59923249590823520f, 1.4999999999999927f, 1.0f};
Packet p = ppolevl<Packet, 2>::run(m, alpha);
p = pmul(m, p);
Packet q = ppolevl<Packet, 3>::run(m, beta);
Packet log_m = pdiv(p, q);
// result = e * ln2 + log(m) + dx/u.
// result = e * ln2 + log(u) + dx/u.
// The dx/u term corrects for the rounding error in u = fl(1+x).
const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u)));
Packet result = pmadd(e, cst_ln2, padd(log_u, pdiv(dx, u)));
// Handle special cases.
Packet neg_mask = pcmp_lt(u, pzero(u));