Improve log1p accuracy and speed with direct range reduction

libeigen/eigen!2378

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-04-02 11:29:25 -07:00
parent d31a73437f
commit 61a8662876
3 changed files with 152 additions and 3 deletions

View File

@@ -60,7 +60,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4d ptan<Packet4d>(cons
EIGEN_GENERIC_PACKET_FUNCTION(atan, Packet4d)
EIGEN_GENERIC_PACKET_FUNCTION(exp2, Packet4d)
EIGEN_GENERIC_PACKET_FUNCTION(expm1, Packet4d)
EIGEN_GENERIC_PACKET_FUNCTION(log1p, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log1p, Packet4d)
// Notice that for newer processors, it is counterproductive to use Newton
// iteration for square root. In particular, Skylake and Zen2 processors

View File

@@ -204,8 +204,137 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa
return plog_impl_double<Packet, /* base2 */ true>(_x);
}
/** \internal \returns log(1 + x) for single precision float.
    Operates on u = fl(1 + x) with its own inline range reduction rather than
    passing 1 + x to plog as a black box (as the Kahan formula does), which
    avoids a second rounding. The bits lost when forming u are recovered as
    dx = x - (u - 1) and added back as the first-order term dx / u after the
    rational approximation has been evaluated.

    Special cases (later entries override earlier ones):
      - u == 1   : x is returned unchanged.
      - u == +inf: +inf is returned.
      - x == -1  : -inf is returned.
      - u < 0    : all bits are set in the result, i.e. NaN for x < -1.
*/
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x) {
  // Scalar constants broadcast to every lane.
  const Packet cst_one = pset1<Packet>(1.0f);
  const Packet cst_sqrt_half = pset1<Packet>(0.707106781186547524f);
  const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
  const Packet cst_neg_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0xff800000u));
  const Packet cst_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0x7f800000u));

  // Rounded sum and the rounding error it dropped.
  const Packet one_plus_x = padd(cst_one, x);
  const Packet round_err = psub(x, psub(one_plus_x, cst_one));

  // Split 1 + x into fraction and exponent, then recenter the fraction around
  // zero: lanes whose fraction is below sqrt(1/2) get the fraction doubled
  // (frac + frac - 1) and the exponent decremented; all others get frac - 1.
  // NOTE(review): relies on pfrexp returning a fraction in [0.5, 1) - confirm.
  Packet exponent;
  Packet frac = pfrexp(one_plus_x, exponent);
  const Packet below_sqrt_half = pcmp_lt(frac, cst_sqrt_half);
  const Packet frac_if_below = pand(frac, below_sqrt_half);
  frac = padd(psub(frac, cst_one), frac_if_below);
  exponent = psub(exponent, pand(cst_one, below_sqrt_half));

  // log of the reduced argument as frac * P(frac) / Q(frac).
  // The coefficients are meant to stay identical to the ones used by plog_float.
  constexpr float numer_coeffs[] = {0.18256296349849254f, 1.0000000190281063f, 1.0000000190281136f};
  constexpr float denom_coeffs[] = {0.049616247954120038f, 0.59923249590823520f, 1.4999999999999927f, 1.0f};
  const Packet numer = pmul(frac, ppolevl<Packet, 2>::run(frac, numer_coeffs));
  const Packet denom = ppolevl<Packet, 3>::run(frac, denom_coeffs);
  const Packet log_frac = pdiv(numer, denom);

  // exponent * ln2 + (log(frac part) + dx / u); the last term compensates for
  // the rounding committed when 1 + x was formed.
  const Packet correction = pdiv(round_err, one_plus_x);
  const Packet general = pmadd(exponent, cst_ln2, padd(log_frac, correction));

  // Lane masks for the special cases.
  const Packet is_tiny = pcmp_eq(one_plus_x, cst_one);
  const Packet is_pos_inf = pcmp_eq(one_plus_x, cst_inf);
  const Packet is_minus_one = pcmp_eq(x, pset1<Packet>(-1.0f));
  const Packet is_below_minus_one = pcmp_lt(one_plus_x, pzero(one_plus_x));

  // Innermost select is applied first, outermost last.
  const Packet selected =
      pselect(is_minus_one, cst_neg_inf, pselect(is_pos_inf, cst_inf, pselect(is_tiny, x, general)));
  // OR-ing with the all-ones comparison mask turns the x < -1 lanes into NaN.
  return por(is_below_minus_one, selected);
}
/** \internal \returns log(1 + x) for double precision.
    Double-precision counterpart of generic_log1p_float: u = fl(1 + x) is range
    reduced inline, the logarithm of the reduced argument m is evaluated as
    m - m^2 / 2 + m^3 * P(m) / Q(m), and the rounding error dx = x - (u - 1) of
    the initial addition is added back as dx / u.

    Special cases (later entries override earlier ones):
      - u == 1   : x is returned unchanged.
      - u == +inf: +inf is returned.
      - x == -1  : -inf is returned.
      - u < 0    : all bits are set in the result, i.e. NaN for x < -1.
*/
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) {
  // Scalar constants broadcast to every lane.
  const Packet cst_one = pset1<Packet>(1.0);
  const Packet cst_neg_half = pset1<Packet>(-0.5);
  const Packet cst_sqrt_half = pset1<Packet>(0.70710678118654752440E0);
  const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
  const Packet cst_neg_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
  const Packet cst_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));

  // Rational approximation coefficients; meant to stay identical to the ones
  // used by plog_double.
  const Packet cst_p0 = pset1<Packet>(1.01875663804580931796E-4);
  const Packet cst_p1 = pset1<Packet>(4.97494994976747001425E-1);
  const Packet cst_p2 = pset1<Packet>(4.70579119878881725854E0);
  const Packet cst_p3 = pset1<Packet>(1.44989225341610930846E1);
  const Packet cst_p4 = pset1<Packet>(1.79368678507819816313E1);
  const Packet cst_p5 = pset1<Packet>(7.70838733755885391666E0);
  const Packet cst_q0 = pset1<Packet>(1.0);
  const Packet cst_q1 = pset1<Packet>(1.12873587189167450590E1);
  const Packet cst_q2 = pset1<Packet>(4.52279145837532221105E1);
  const Packet cst_q3 = pset1<Packet>(8.29875266912776603211E1);
  const Packet cst_q4 = pset1<Packet>(7.11544750618563894466E1);
  const Packet cst_q5 = pset1<Packet>(2.31251620126765340583E1);

  // Rounded sum and the rounding error it dropped.
  const Packet one_plus_x = padd(cst_one, x);
  const Packet round_err = psub(x, psub(one_plus_x, cst_one));

  // Split 1 + x into fraction and exponent, then recenter the fraction around
  // zero: lanes whose fraction is below sqrt(1/2) get the fraction doubled
  // (frac + frac - 1) and the exponent decremented; all others get frac - 1.
  // NOTE(review): relies on pfrexp returning a fraction in [0.5, 1) - confirm.
  Packet exponent;
  Packet frac = pfrexp(one_plus_x, exponent);
  const Packet below_sqrt_half = pcmp_lt(frac, cst_sqrt_half);
  const Packet frac_if_below = pand(frac, below_sqrt_half);
  frac = padd(psub(frac, cst_one), frac_if_below);
  exponent = psub(exponent, pand(cst_one, below_sqrt_half));

  const Packet frac2 = pmul(frac, frac);
  const Packet frac3 = pmul(frac2, frac);

  // Each degree-5 polynomial is evaluated as two independent Horner chains
  // (upper and lower three coefficients) that are merged through frac^3.
  const Packet p_upper = pmadd(pmadd(cst_p0, frac, cst_p1), frac, cst_p2);
  const Packet p_lower = pmadd(pmadd(cst_p3, frac, cst_p4), frac, cst_p5);
  const Packet q_upper = pmadd(pmadd(cst_q0, frac, cst_q1), frac, cst_q2);
  const Packet q_lower = pmadd(pmadd(cst_q3, frac, cst_q4), frac, cst_q5);
  const Packet numer = pmul(pmadd(p_upper, frac3, p_lower), frac3);
  const Packet denom = pmadd(q_upper, frac3, q_lower);

  // log of the reduced argument: frac - frac^2 / 2 + frac^3 * P / Q.
  Packet log_frac = pdiv(numer, denom);
  log_frac = pmadd(cst_neg_half, frac2, log_frac);
  log_frac = padd(frac, log_frac);

  // exponent * ln2 + (log(frac part) + dx / u); the last term compensates for
  // the rounding committed when 1 + x was formed.
  const Packet correction = pdiv(round_err, one_plus_x);
  const Packet general = pmadd(exponent, cst_ln2, padd(log_frac, correction));

  // Lane masks for the special cases.
  const Packet is_tiny = pcmp_eq(one_plus_x, cst_one);
  const Packet is_pos_inf = pcmp_eq(one_plus_x, cst_inf);
  const Packet is_minus_one = pcmp_eq(x, pset1<Packet>(-1.0));
  const Packet is_below_minus_one = pcmp_lt(one_plus_x, pzero(one_plus_x));

  // Innermost select is applied first, outermost last.
  const Packet selected =
      pselect(is_minus_one, cst_neg_inf, pselect(is_pos_inf, cst_inf, pselect(is_tiny, x, general)));
  // OR-ing with the all-ones comparison mask turns the x < -1 lanes into NaN.
  return por(is_below_minus_one, selected);
}
/** \internal \returns log(1 + x) computed using W. Kahan's formula.
See: http://www.plunk.org/~hatch/rightway.php
This is the generic fallback for types without a specialized implementation.
*/
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x) {

View File

@@ -82,6 +82,26 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x);
/** \internal \returns log(1+x) for single-precision floating-point packets.
    Forward declaration only; the definition is provided separately.
    \param x packet of input values. */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x);
/** \internal \returns log(1+x) for double-precision floating-point packets.
    Forward declaration only; the definition is provided separately.
    \param x packet of input values. */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x);
/** \internal \returns log(1+x) for single-precision floating-point packets.
    Thin forwarding wrapper around generic_log1p_float.
    NOTE(review): presumably this is the name that
    EIGEN_FLOAT_PACKET_FUNCTION(log1p, PACKET) resolves to - confirm against
    the macro definition. */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_float(const Packet& x) {
  return generic_log1p_float<Packet>(x);
}
/** \internal \returns log(1+x) for double-precision floating-point packets.
    Thin forwarding wrapper around generic_log1p_double.
    NOTE(review): presumably this is the name that
    EIGEN_DOUBLE_PACKET_FUNCTION(log1p, PACKET) resolves to - confirm against
    the macro definition. */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_double(const Packet& x) {
  return generic_log1p_double<Packet>(x);
}
/** \internal \returns exp(x)-1 */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_expm1(const Packet& x);
@@ -264,7 +284,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(log1p, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET)
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
@@ -284,7 +304,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log1p, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET)
// Macro to instantiate complex math function specializations (psqrt, plog, pexp)