mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Add plog_core_double with fallback for AVX without AVX2
libeigen/eigen!2407 Co-authored-by: Rasmus Munk Larsen <rlarsen@nvidia.com>
This commit is contained in:
@@ -141,6 +141,140 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_float(const Pac
|
||||
return plog_impl_float<Packet, /* base2 */ true>(_x);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Double logarithm: shared polynomial + two range-reduction backends
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Cephes rational-polynomial approximation of log(1+f) for
|
||||
// f in [sqrt(0.5)-1, sqrt(2)-1].
|
||||
// Evaluates x - 0.5*x^2 + x^3 * P(x)/Q(x) where P and Q are degree-5.
|
||||
// See: http://www.netlib.org/cephes/
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE Packet plog_mantissa_double(const Packet x) {
|
||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
||||
// Q0 = 1.0; pmadd(1, x, q1) simplifies to padd(x, q1).
|
||||
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
|
||||
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
|
||||
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
|
||||
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
|
||||
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
|
||||
|
||||
Packet x2 = pmul(x, x);
|
||||
Packet x3 = pmul(x2, x);
|
||||
|
||||
// Evaluate P and Q simultaneously for better ILP.
|
||||
Packet y, y1, y_;
|
||||
y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
|
||||
y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
|
||||
y = pmadd(y, x, cst_cephes_log_p2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_p5);
|
||||
y_ = pmadd(y, x3, y1);
|
||||
|
||||
y = padd(x, cst_cephes_log_q1);
|
||||
y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
|
||||
y = pmadd(y, x, cst_cephes_log_q2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_q5);
|
||||
y = pmadd(y, x3, y1);
|
||||
|
||||
y_ = pmul(y_, x3);
|
||||
y = pdiv(y_, y);
|
||||
y = pnmadd(pset1<Packet>(0.5), x2, y);
|
||||
return padd(x, y);
|
||||
}
|
||||
|
||||
// Detect whether unpacket_traits<Packet>::integer_packet is defined.
|
||||
template <typename Packet, typename = void>
|
||||
struct packet_has_integer_packet : std::false_type {};
|
||||
template <typename Packet>
|
||||
struct packet_has_integer_packet<Packet, void_t<typename unpacket_traits<Packet>::integer_packet>> : std::true_type {};
|
||||
|
||||
// Dispatch struct for double-precision range reduction.
|
||||
// Primary template: pfrexp-based fallback (used when integer_packet is absent).
|
||||
template <typename Packet, bool UseIntegerPacket>
|
||||
struct plog_range_reduce_double {
|
||||
EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) {
|
||||
const Packet one = pset1<Packet>(1.0);
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
||||
// pfrexp: f in [0.5, 1), e = unbiased exponent as double.
|
||||
f = pfrexp(v, e);
|
||||
// Shift [0.5,1) -> [sqrt(0.5)-1, sqrt(2)-1] with exponent correction:
|
||||
// if f < sqrt(0.5): f = f + f - 1, e -= 1 (giving f in [0, sqrt(2)-1))
|
||||
// else: f = f - 1 (giving f in [sqrt(0.5)-1, 0))
|
||||
Packet mask = pcmp_lt(f, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(f, mask);
|
||||
f = psub(f, one);
|
||||
e = psub(e, pand(one, mask));
|
||||
f = padd(f, tmp);
|
||||
}
|
||||
};
|
||||
|
||||
// Specialisation: fast integer-bit-manipulation path (musl-inspired).
|
||||
// Requires unpacket_traits<Packet>::integer_packet to be a 64-bit integer packet.
|
||||
template <typename Packet>
|
||||
struct plog_range_reduce_double<Packet, true> {
|
||||
EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
// 2^-1022: smallest positive normal double.
|
||||
const PacketI cst_min_normal = pset1<PacketI>(static_cast<int64_t>(0x0010000000000000LL));
|
||||
// Lower 52-bit mask (IEEE mantissa field).
|
||||
const PacketI cst_mant_mask = pset1<PacketI>(static_cast<int64_t>(0x000FFFFFFFFFFFFFLL));
|
||||
// Offset = 1.0_bits - sqrt(0.5)_bits. Adding this to the integer
|
||||
// representation shifts the exponent field so that the [sqrt(0.5), sqrt(2))
|
||||
// half-octave boundary falls on an exact biased-exponent boundary, letting
|
||||
// us extract e with a single right shift. The constant is:
|
||||
// 0x3FF0000000000000 - 0x3FE6A09E667F3BCD = 0x00095F619980C433
|
||||
const PacketI cst_sqrt_half_offset =
|
||||
pset1<PacketI>(static_cast<int64_t>(0x3FF0000000000000LL - 0x3FE6A09E667F3BCDLL));
|
||||
// IEEE double exponent bias (1023).
|
||||
const PacketI cst_exp_bias = pset1<PacketI>(static_cast<int64_t>(1023));
|
||||
// sqrt(0.5) IEEE bits — used to reconstruct f from biased mantissa.
|
||||
const PacketI cst_half_mant = pset1<PacketI>(static_cast<int64_t>(0x3FE6A09E667F3BCDLL));
|
||||
|
||||
// Reinterpret v as a 64-bit integer vector.
|
||||
PacketI vi = preinterpret<PacketI>(v);
|
||||
|
||||
// Normalise denormals: multiply by 2^52 and correct the exponent by -52.
|
||||
PacketI is_denormal = pcmp_lt(vi, cst_min_normal);
|
||||
// 2^52 via bit pattern: biased exponent = 52 + 1023 = 0x433, mantissa = 0.
|
||||
Packet v_norm = pmul(v, pset1frombits<Packet>(static_cast<uint64_t>(int64_t(52 + 0x3ff) << 52)));
|
||||
vi = pselect(is_denormal, preinterpret<PacketI>(v_norm), vi);
|
||||
PacketI denorm_adj = pand(is_denormal, pset1<PacketI>(static_cast<int64_t>(52)));
|
||||
|
||||
// Bias the integer representation so the exponent field directly encodes
|
||||
// the half-octave index.
|
||||
PacketI vi_biased = padd(vi, cst_sqrt_half_offset);
|
||||
// Extract unbiased exponent: shift out mantissa bits, subtract IEEE bias
|
||||
// and denormal adjustment.
|
||||
PacketI e_int = psub(psub(plogical_shift_right<52>(vi_biased), cst_exp_bias), denorm_adj);
|
||||
// Convert integer exponent to floating-point.
|
||||
e = pcast<PacketI, Packet>(e_int);
|
||||
|
||||
// Reconstruct mantissa in [sqrt(0.5), sqrt(2)) via integer arithmetic.
|
||||
// The integer addition of the masked mantissa bits and the sqrt(0.5) bit
|
||||
// pattern carries into the exponent field, yielding a value in that range.
|
||||
// Then subtract 1 to centre on 0: f in [sqrt(0.5)-1, sqrt(2)-1].
|
||||
f = psub(preinterpret<Packet>(padd(pand(vi_biased, cst_mant_mask), cst_half_mant)), pset1<Packet>(1.0));
|
||||
}
|
||||
};
|
||||
|
||||
// Core range reduction and polynomial for double logarithm.
|
||||
// Input: v > 0 (zero / negative / inf / nan are handled by the caller).
|
||||
// Output: log_mantissa ≈ log(mantissa of v in [sqrt(0.5), sqrt(2))),
|
||||
// e = unbiased exponent of v as a double.
|
||||
// Selects the fast integer path when integer_packet is available, otherwise
|
||||
// falls back to pfrexp.
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE void plog_core_double(const Packet v, Packet& log_mantissa, Packet& e) {
|
||||
Packet f;
|
||||
plog_range_reduce_double<Packet, packet_has_integer_packet<Packet>::value>::run(v, f, e);
|
||||
log_mantissa = plog_mantissa_double(f);
|
||||
}
|
||||
|
||||
/* Returns the base e (2.718...) or base 2 logarithm of x.
|
||||
* The argument is separated into its exponent and fractional parts.
|
||||
* The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
|
||||
@@ -152,87 +286,29 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_float(const Pac
|
||||
*/
|
||||
template <typename Packet, bool base2>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_double(const Packet _x) {
|
||||
Packet x = _x;
|
||||
|
||||
const Packet cst_1 = pset1<Packet>(1.0);
|
||||
const Packet cst_neg_half = pset1<Packet>(-0.5);
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
|
||||
// Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
|
||||
// 1/sqrt(2) <= x < sqrt(2)
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
||||
Packet log_mantissa, e;
|
||||
plog_core_double(_x, log_mantissa, e);
|
||||
|
||||
const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
|
||||
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
|
||||
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
|
||||
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
|
||||
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
|
||||
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
|
||||
|
||||
Packet e;
|
||||
// extract significant in the range [0.5,1) and exponent
|
||||
x = pfrexp(x, e);
|
||||
|
||||
// Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
|
||||
// and shift by -1. The values are then centered around 0, which improves
|
||||
// the stability of the polynomial evaluation.
|
||||
// if( x < SQRTHF ) {
|
||||
// e -= 1;
|
||||
// x = x + x - 1.0;
|
||||
// } else { x = x - 1.0; }
|
||||
Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(x, mask);
|
||||
x = psub(x, cst_1);
|
||||
e = psub(e, pand(cst_1, mask));
|
||||
x = padd(x, tmp);
|
||||
|
||||
Packet x2 = pmul(x, x);
|
||||
Packet x3 = pmul(x2, x);
|
||||
|
||||
// Evaluate the polynomial in factored form for better instruction-level parallelism.
|
||||
// y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
|
||||
Packet y, y1, y_;
|
||||
y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
|
||||
y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
|
||||
y = pmadd(y, x, cst_cephes_log_p2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_p5);
|
||||
y_ = pmadd(y, x3, y1);
|
||||
|
||||
y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
|
||||
y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
|
||||
y = pmadd(y, x, cst_cephes_log_q2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_q5);
|
||||
y = pmadd(y, x3, y1);
|
||||
|
||||
y_ = pmul(y_, x3);
|
||||
y = pdiv(y_, y);
|
||||
|
||||
y = pmadd(cst_neg_half, x2, y);
|
||||
x = padd(x, y);
|
||||
|
||||
// Add the logarithm of the exponent back to the result of the interpolation.
|
||||
// Combine: log(x) = e * ln2 + log(mantissa), or log2(x) = log(mantissa)*log2e + e.
|
||||
Packet x;
|
||||
if (base2) {
|
||||
const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
|
||||
x = pmadd(x, cst_log2e, e);
|
||||
x = pmadd(log_mantissa, cst_log2e, e);
|
||||
} else {
|
||||
const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
|
||||
x = pmadd(e, cst_ln2, x);
|
||||
x = pmadd(e, cst_ln2, log_mantissa);
|
||||
}
|
||||
|
||||
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
|
||||
Packet iszero_mask = pcmp_eq(_x, pzero(_x));
|
||||
Packet pos_inf_mask = pcmp_eq(_x, cst_pos_inf);
|
||||
// Filter out invalid inputs, i.e.:
|
||||
// - negative arg will be NAN
|
||||
// - 0 will be -INF
|
||||
// - +INF will be +INF
|
||||
// Filter out invalid inputs:
|
||||
// - negative arg → NAN
|
||||
// - 0 → -INF
|
||||
// - +INF → +INF
|
||||
return pselect(iszero_mask, cst_minus_inf, por(pselect(pos_inf_mask, cst_pos_inf, x), invalid_mask));
|
||||
}
|
||||
|
||||
@@ -286,8 +362,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(c
|
||||
return result;
|
||||
}
|
||||
|
||||
/** \internal \returns log(1 + x) for double precision float.
|
||||
Same direct approach as the float version.
|
||||
/** \internal \returns log(1 + x) for double precision.
|
||||
Computes log(1+x) using plog_core_double for the core range reduction and
|
||||
polynomial evaluation. The rounding error from forming u = fl(1+x) is
|
||||
recovered as dx = x - (u - 1) and folded in as a first-order correction
|
||||
dx/u after the polynomial evaluation.
|
||||
*/
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) {
|
||||
@@ -295,67 +374,31 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
|
||||
// u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1).
|
||||
Packet u = padd(one, x);
|
||||
Packet dx = psub(x, psub(u, one));
|
||||
|
||||
// For |x| tiny enough that u rounds to 1, return x directly.
|
||||
Packet small_mask = pcmp_eq(u, one);
|
||||
// For u = +inf (x very large), return +inf.
|
||||
Packet inf_mask = pcmp_eq(u, cst_pos_inf);
|
||||
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
||||
Packet e;
|
||||
Packet m = pfrexp(u, e);
|
||||
Packet mask = pcmp_lt(m, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(m, mask);
|
||||
m = psub(m, one);
|
||||
e = psub(e, pand(one, mask));
|
||||
m = padd(m, tmp);
|
||||
// Core range reduction and polynomial on u.
|
||||
Packet log_u, e;
|
||||
plog_core_double(u, log_u, e);
|
||||
|
||||
// Same polynomial as plog_double.
|
||||
const Packet cst_neg_half = pset1<Packet>(-0.5);
|
||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
||||
const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
|
||||
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
|
||||
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
|
||||
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
|
||||
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
|
||||
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
|
||||
|
||||
Packet m2 = pmul(m, m);
|
||||
Packet m3 = pmul(m2, m);
|
||||
|
||||
Packet y, y1, y_;
|
||||
y = pmadd(cst_cephes_log_p0, m, cst_cephes_log_p1);
|
||||
y1 = pmadd(cst_cephes_log_p3, m, cst_cephes_log_p4);
|
||||
y = pmadd(y, m, cst_cephes_log_p2);
|
||||
y1 = pmadd(y1, m, cst_cephes_log_p5);
|
||||
y_ = pmadd(y, m3, y1);
|
||||
|
||||
y = pmadd(cst_cephes_log_q0, m, cst_cephes_log_q1);
|
||||
y1 = pmadd(cst_cephes_log_q3, m, cst_cephes_log_q4);
|
||||
y = pmadd(y, m, cst_cephes_log_q2);
|
||||
y1 = pmadd(y1, m, cst_cephes_log_q5);
|
||||
y = pmadd(y, m3, y1);
|
||||
|
||||
y_ = pmul(y_, m3);
|
||||
Packet log_m = pdiv(y_, y);
|
||||
log_m = pmadd(cst_neg_half, m2, log_m);
|
||||
log_m = padd(m, log_m);
|
||||
|
||||
// result = e * ln2 + log(m) + dx/u.
|
||||
// result = e * ln2 + log(u) + dx/u.
|
||||
// The dx/u term corrects for the rounding error in u = fl(1+x).
|
||||
const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
|
||||
Packet result = pmadd(e, cst_ln2, padd(log_m, pdiv(dx, u)));
|
||||
Packet result = pmadd(e, cst_ln2, padd(log_u, pdiv(dx, u)));
|
||||
|
||||
// Handle special cases.
|
||||
Packet neg_mask = pcmp_lt(u, pzero(u));
|
||||
Packet zero_mask = pcmp_eq(x, pset1<Packet>(-1.0));
|
||||
result = pselect(small_mask, x, result);
|
||||
result = pselect(inf_mask, cst_pos_inf, result);
|
||||
result = pselect(zero_mask, cst_minus_inf, result);
|
||||
result = por(neg_mask, result);
|
||||
result = por(neg_mask, result); // NaN for x < -1
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user