mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Improve log1p accuracy and speed with direct range reduction
libeigen/eigen!2378 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -60,7 +60,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4d ptan<Packet4d>(cons
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(atan, Packet4d)
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(exp2, Packet4d)
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(expm1, Packet4d)
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(log1p, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(log1p, Packet4d)
|
||||
|
||||
// Notice that for newer processors, it is counterproductive to use Newton
|
||||
// iteration for square root. In particular, Skylake and Zen2 processors
|
||||
|
||||
@@ -204,8 +204,137 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa
|
||||
return plog_impl_double<Packet, /* base2 */ true>(_x);
|
||||
}
|
||||
|
||||
/** \internal \returns log(1 + x) for single precision float.
|
||||
Computes log(1+x) directly with inline range reduction, avoiding
|
||||
the double rounding in the Kahan formula (which calls plog(1+x)
|
||||
as a black box). The rounding error from forming u = fl(1+x) is
|
||||
recovered as dx = x - (u - 1), and folded in as a first-order
|
||||
correction dx/u after the polynomial evaluation.
|
||||
*/
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x) {
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0xff800000u));
|
||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<Eigen::numext::uint32_t>(0x7f800000u));
|
||||
|
||||
// u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1).
|
||||
Packet u = padd(one, x);
|
||||
Packet dx = psub(x, psub(u, one));
|
||||
|
||||
// For |x| tiny enough that u rounds to 1, return x directly.
|
||||
Packet small_mask = pcmp_eq(u, one);
|
||||
// For u = +inf (x very large), return +inf.
|
||||
Packet inf_mask = pcmp_eq(u, cst_pos_inf);
|
||||
|
||||
// Inline the plog range reduction on u = 1 + x.
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
|
||||
Packet e;
|
||||
Packet m = pfrexp(u, e);
|
||||
Packet mask = pcmp_lt(m, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(m, mask);
|
||||
m = psub(m, one);
|
||||
e = psub(e, pand(one, mask));
|
||||
m = padd(m, tmp);
|
||||
|
||||
// Same rational polynomial as plog_float.
|
||||
constexpr float alpha[] = {0.18256296349849254f, 1.0000000190281063f, 1.0000000190281136f};
|
||||
constexpr float beta[] = {0.049616247954120038f, 0.59923249590823520f, 1.4999999999999927f, 1.0f};
|
||||
Packet p = ppolevl<Packet, 2>::run(m, alpha);
|
||||
p = pmul(m, p);
|
||||
Packet q = ppolevl<Packet, 3>::run(m, beta);
|
||||
Packet log_m = pdiv(p, q);
|
||||
|
||||
// result = e * ln2 + log(m) + dx/u.
|
||||
// The dx/u term corrects for the rounding error in u = fl(1+x).
|
||||
const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
|
||||
Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u)));
|
||||
|
||||
// Handle special cases.
|
||||
Packet neg_mask = pcmp_lt(u, pzero(u));
|
||||
Packet zero_mask = pcmp_eq(x, pset1<Packet>(-1.0f));
|
||||
result = pselect(small_mask, x, result);
|
||||
result = pselect(inf_mask, cst_pos_inf, result);
|
||||
result = pselect(zero_mask, cst_minus_inf, result);
|
||||
result = por(neg_mask, result); // NaN for x < -1
|
||||
return result;
|
||||
}
|
||||
|
||||
/** \internal \returns log(1 + x) for double precision float.
|
||||
Same direct approach as the float version.
|
||||
*/
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) {
|
||||
const Packet one = pset1<Packet>(1.0);
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
|
||||
Packet u = padd(one, x);
|
||||
Packet dx = psub(x, psub(u, one));
|
||||
|
||||
Packet small_mask = pcmp_eq(u, one);
|
||||
Packet inf_mask = pcmp_eq(u, cst_pos_inf);
|
||||
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
||||
Packet e;
|
||||
Packet m = pfrexp(u, e);
|
||||
Packet mask = pcmp_lt(m, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(m, mask);
|
||||
m = psub(m, one);
|
||||
e = psub(e, pand(one, mask));
|
||||
m = padd(m, tmp);
|
||||
|
||||
// Same polynomial as plog_double.
|
||||
const Packet cst_neg_half = pset1<Packet>(-0.5);
|
||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
||||
const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
|
||||
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
|
||||
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
|
||||
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
|
||||
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
|
||||
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
|
||||
|
||||
Packet m2 = pmul(m, m);
|
||||
Packet m3 = pmul(m2, m);
|
||||
|
||||
Packet y, y1, y_;
|
||||
y = pmadd(cst_cephes_log_p0, m, cst_cephes_log_p1);
|
||||
y1 = pmadd(cst_cephes_log_p3, m, cst_cephes_log_p4);
|
||||
y = pmadd(y, m, cst_cephes_log_p2);
|
||||
y1 = pmadd(y1, m, cst_cephes_log_p5);
|
||||
y_ = pmadd(y, m3, y1);
|
||||
|
||||
y = pmadd(cst_cephes_log_q0, m, cst_cephes_log_q1);
|
||||
y1 = pmadd(cst_cephes_log_q3, m, cst_cephes_log_q4);
|
||||
y = pmadd(y, m, cst_cephes_log_q2);
|
||||
y1 = pmadd(y1, m, cst_cephes_log_q5);
|
||||
y = pmadd(y, m3, y1);
|
||||
|
||||
y_ = pmul(y_, m3);
|
||||
Packet log_m = pdiv(y_, y);
|
||||
log_m = pmadd(cst_neg_half, m2, log_m);
|
||||
log_m = padd(m, log_m);
|
||||
|
||||
// result = e * ln2 + log(m) + dx/u.
|
||||
const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
|
||||
Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u)));
|
||||
|
||||
Packet neg_mask = pcmp_lt(u, pzero(u));
|
||||
Packet zero_mask = pcmp_eq(x, pset1<Packet>(-1.0));
|
||||
result = pselect(small_mask, x, result);
|
||||
result = pselect(inf_mask, cst_pos_inf, result);
|
||||
result = pselect(zero_mask, cst_minus_inf, result);
|
||||
result = por(neg_mask, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/** \internal \returns log(1 + x) computed using W. Kahan's formula.
|
||||
See: http://www.plunk.org/~hatch/rightway.php
|
||||
This is the generic fallback for types without a specialized implementation.
|
||||
*/
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x) {
|
||||
|
||||
@@ -82,6 +82,26 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_double(const Pa
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p(const Packet& x);
|
||||
|
||||
/** \internal \returns log(1+x) for single precision float.
    Forward declaration only: takes the input packet by const reference and
    returns a packet of the same type. */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(const Packet& x);
|
||||
|
||||
/** \internal \returns log(1+x) for double precision float.
    Forward declaration only: takes the input packet by const reference and
    returns a packet of the same type. */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x);
|
||||
|
||||
/** \internal \returns log(1+x) for single precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_float(const Packet& x) {
|
||||
return generic_log1p_float(x);
|
||||
}
|
||||
|
||||
/** \internal \returns log(1+x) for double precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p_double(const Packet& x) {
|
||||
return generic_log1p_double(x);
|
||||
}
|
||||
|
||||
/** \internal \returns exp(x)-1.
    Forward declaration only: takes the input packet by const reference and
    returns a packet of the same type. */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_expm1(const Packet& x);
|
||||
@@ -264,7 +284,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(log1p, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET)
|
||||
|
||||
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
|
||||
@@ -284,7 +304,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(log1p, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET)
|
||||
|
||||
// Macro to instantiate complex math function specializations (psqrt, plog, pexp)
|
||||
|
||||
Reference in New Issue
Block a user