mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Compare commits
1 Commits
master
...
revert-b1d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
111c4d23a9 |
@@ -141,158 +141,69 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2_float(const Pac
|
|||||||
return plog_impl_float<Packet, /* base2 */ true>(_x);
|
return plog_impl_float<Packet, /* base2 */ true>(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// Core range reduction and polynomial evaluation for double logarithm.
|
||||||
// Double logarithm: shared polynomial + two range-reduction backends
|
//
|
||||||
// -----------------------------------------------------------------------
|
// Same structure as plog_core_float but for double precision.
|
||||||
|
// Given a positive double v (may be denormal), decomposes it as
|
||||||
// Cephes rational-polynomial approximation of log(1+f) for
|
// v = 2^e * (1+f) with f in [sqrt(0.5)-1, sqrt(2)-1], then evaluates
|
||||||
// f in [sqrt(0.5)-1, sqrt(2)-1].
|
// log(1+f) ≈ f - 0.5*f^2 + f^3 * P(f)/Q(f) using the Cephes [5/5]
|
||||||
// Evaluates x - 0.5*x^2 + x^3 * P(x)/Q(x) where P and Q are degree-5.
|
// rational approximation.
|
||||||
// See: http://www.netlib.org/cephes/
|
|
||||||
template <typename Packet>
|
|
||||||
EIGEN_STRONG_INLINE Packet plog_mantissa_double(const Packet x) {
|
|
||||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
|
||||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
|
||||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
|
||||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
|
||||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
|
||||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
|
||||||
// Q0 = 1.0; pmadd(1, x, q1) simplifies to padd(x, q1).
|
|
||||||
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
|
|
||||||
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
|
|
||||||
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
|
|
||||||
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
|
|
||||||
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
|
|
||||||
|
|
||||||
Packet x2 = pmul(x, x);
|
|
||||||
Packet x3 = pmul(x2, x);
|
|
||||||
|
|
||||||
// Evaluate P and Q simultaneously for better ILP.
|
|
||||||
Packet y, y1, y_;
|
|
||||||
y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
|
|
||||||
y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
|
|
||||||
y = pmadd(y, x, cst_cephes_log_p2);
|
|
||||||
y1 = pmadd(y1, x, cst_cephes_log_p5);
|
|
||||||
y_ = pmadd(y, x3, y1);
|
|
||||||
|
|
||||||
y = padd(x, cst_cephes_log_q1);
|
|
||||||
y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
|
|
||||||
y = pmadd(y, x, cst_cephes_log_q2);
|
|
||||||
y1 = pmadd(y1, x, cst_cephes_log_q5);
|
|
||||||
y = pmadd(y, x3, y1);
|
|
||||||
|
|
||||||
y_ = pmul(y_, x3);
|
|
||||||
y = pdiv(y_, y);
|
|
||||||
y = pnmadd(pset1<Packet>(0.5), x2, y);
|
|
||||||
return padd(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect whether unpacket_traits<Packet>::integer_packet is defined.
|
|
||||||
template <typename Packet, typename = void>
|
|
||||||
struct packet_has_integer_packet : std::false_type {};
|
|
||||||
template <typename Packet>
|
|
||||||
struct packet_has_integer_packet<Packet, void_t<typename unpacket_traits<Packet>::integer_packet>> : std::true_type {};
|
|
||||||
|
|
||||||
// Dispatch struct for double-precision range reduction.
|
|
||||||
// Primary template: pfrexp-based fallback (used when integer_packet is absent).
|
|
||||||
template <typename Packet, bool UseIntegerPacket>
|
|
||||||
struct plog_range_reduce_double {
|
|
||||||
EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) {
|
|
||||||
const Packet one = pset1<Packet>(1.0);
|
|
||||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
|
||||||
// pfrexp: f in [0.5, 1), e = unbiased exponent as double.
|
|
||||||
f = pfrexp(v, e);
|
|
||||||
// Shift [0.5,1) -> [sqrt(0.5)-1, sqrt(2)-1] with exponent correction:
|
|
||||||
// if f < sqrt(0.5): f = f + f - 1, e -= 1 (giving f in [0, sqrt(2)-1))
|
|
||||||
// else: f = f - 1 (giving f in [sqrt(0.5)-1, 0))
|
|
||||||
Packet mask = pcmp_lt(f, cst_cephes_SQRTHF);
|
|
||||||
Packet tmp = pand(f, mask);
|
|
||||||
f = psub(f, one);
|
|
||||||
e = psub(e, pand(one, mask));
|
|
||||||
f = padd(f, tmp);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Specialisation: fast integer-bit-manipulation path (musl-inspired).
|
|
||||||
// Requires unpacket_traits<Packet>::integer_packet to be a 64-bit integer packet.
|
|
||||||
template <typename Packet>
|
|
||||||
struct plog_range_reduce_double<Packet, true> {
|
|
||||||
EIGEN_STRONG_INLINE static void run(const Packet v, Packet& f, Packet& e) {
|
|
||||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
|
||||||
// 2^-1022: smallest positive normal double.
|
|
||||||
const PacketI cst_min_normal = pset1<PacketI>(static_cast<int64_t>(0x0010000000000000LL));
|
|
||||||
// Lower 52-bit mask (IEEE mantissa field).
|
|
||||||
const PacketI cst_mant_mask = pset1<PacketI>(static_cast<int64_t>(0x000FFFFFFFFFFFFFLL));
|
|
||||||
// Offset = 1.0_bits - sqrt(0.5)_bits. Adding this to the integer
|
|
||||||
// representation shifts the exponent field so that the [sqrt(0.5), sqrt(2))
|
|
||||||
// half-octave boundary falls on an exact biased-exponent boundary, letting
|
|
||||||
// us extract e with a single right shift. The constant is:
|
|
||||||
// 0x3FF0000000000000 - 0x3FE6A09E667F3BCD = 0x00095F619980C433
|
|
||||||
const PacketI cst_sqrt_half_offset =
|
|
||||||
pset1<PacketI>(static_cast<int64_t>(0x3FF0000000000000LL - 0x3FE6A09E667F3BCDLL));
|
|
||||||
// IEEE double exponent bias (1023).
|
|
||||||
const PacketI cst_exp_bias = pset1<PacketI>(static_cast<int64_t>(1023));
|
|
||||||
// sqrt(0.5) IEEE bits — used to reconstruct f from biased mantissa.
|
|
||||||
const PacketI cst_half_mant = pset1<PacketI>(static_cast<int64_t>(0x3FE6A09E667F3BCDLL));
|
|
||||||
|
|
||||||
// Reinterpret v as a 64-bit integer vector.
|
|
||||||
PacketI vi = preinterpret<PacketI>(v);
|
|
||||||
|
|
||||||
// Normalise denormals: multiply by 2^52 and correct the exponent by -52.
|
|
||||||
PacketI is_denormal = pcmp_lt(vi, cst_min_normal);
|
|
||||||
// 2^52 via bit pattern: biased exponent = 52 + 1023 = 0x433, mantissa = 0.
|
|
||||||
Packet v_norm = pmul(v, pset1frombits<Packet>(static_cast<uint64_t>(int64_t(52 + 0x3ff) << 52)));
|
|
||||||
vi = pselect(is_denormal, preinterpret<PacketI>(v_norm), vi);
|
|
||||||
PacketI denorm_adj = pand(is_denormal, pset1<PacketI>(static_cast<int64_t>(52)));
|
|
||||||
|
|
||||||
// Bias the integer representation so the exponent field directly encodes
|
|
||||||
// the half-octave index.
|
|
||||||
PacketI vi_biased = padd(vi, cst_sqrt_half_offset);
|
|
||||||
// Extract unbiased exponent: shift out mantissa bits, subtract IEEE bias
|
|
||||||
// and denormal adjustment.
|
|
||||||
PacketI e_int = psub(psub(plogical_shift_right<52>(vi_biased), cst_exp_bias), denorm_adj);
|
|
||||||
// Convert integer exponent to floating-point.
|
|
||||||
e = pcast<PacketI, Packet>(e_int);
|
|
||||||
|
|
||||||
// Reconstruct mantissa in [sqrt(0.5), sqrt(2)) via integer arithmetic.
|
|
||||||
// The integer addition of the masked mantissa bits and the sqrt(0.5) bit
|
|
||||||
// pattern carries into the exponent field, yielding a value in that range.
|
|
||||||
// Then subtract 1 to centre on 0: f in [sqrt(0.5)-1, sqrt(2)-1].
|
|
||||||
f = psub(preinterpret<Packet>(padd(pand(vi_biased, cst_mant_mask), cst_half_mant)), pset1<Packet>(1.0));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Core range reduction and polynomial for double logarithm.
|
|
||||||
// Input: v > 0 (zero / negative / inf / nan are handled by the caller).
|
|
||||||
// Output: log_mantissa ≈ log(mantissa of v in [sqrt(0.5), sqrt(2))),
|
|
||||||
// e = unbiased exponent of v as a double.
|
|
||||||
// Selects the fast integer path when integer_packet is available, otherwise
|
|
||||||
// falls back to pfrexp.
|
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_STRONG_INLINE void plog_core_double(const Packet v, Packet& log_mantissa, Packet& e) {
|
EIGEN_STRONG_INLINE void plog_core_double(const Packet v, Packet& log_mantissa, Packet& e) {
|
||||||
Packet f;
|
typedef typename unpacket_traits<Packet>::integer_packet PacketL;
|
||||||
plog_range_reduce_double<Packet, packet_has_integer_packet<Packet>::value>::run(v, f, e);
|
|
||||||
log_mantissa = plog_mantissa_double(f);
|
const PacketL cst_min_normal = pset1<PacketL>(int64_t(0x0010000000000000LL));
|
||||||
|
const PacketL cst_mant_mask = pset1<PacketL>(int64_t(0x000fffffffffffffLL));
|
||||||
|
const PacketL cst_sqrt_half_offset = pset1<PacketL>(int64_t(0x00095f619980c433LL));
|
||||||
|
const PacketL cst_exp_bias = pset1<PacketL>(int64_t(0x3ff)); // 1023
|
||||||
|
const PacketL cst_half_mant = pset1<PacketL>(int64_t(0x3fe6a09e667f3bcdLL)); // sqrt(0.5)
|
||||||
|
|
||||||
|
// Normalize denormals by multiplying by 2^52.
|
||||||
|
PacketL vi = preinterpret<PacketL>(v);
|
||||||
|
PacketL is_denormal = pcmp_lt(vi, cst_min_normal);
|
||||||
|
Packet v_normalized = pmul(v, pset1<Packet>(4503599627370496.0)); // 2^52
|
||||||
|
vi = pselect(is_denormal, preinterpret<PacketL>(v_normalized), vi);
|
||||||
|
PacketL denorm_adj = pand(is_denormal, pset1<PacketL>(int64_t(52)));
|
||||||
|
|
||||||
|
// Combined range reduction via integer bias (same trick as float version).
|
||||||
|
PacketL vi_biased = padd(vi, cst_sqrt_half_offset);
|
||||||
|
PacketL e_int = psub(psub(plogical_shift_right<52>(vi_biased), cst_exp_bias), denorm_adj);
|
||||||
|
e = pcast<PacketL, Packet>(e_int);
|
||||||
|
Packet f = psub(preinterpret<Packet>(padd(pand(vi_biased, cst_mant_mask), cst_half_mant)), pset1<Packet>(1.0));
|
||||||
|
|
||||||
|
// Rational approximation log(1+f) = f - 0.5*f^2 + f^3 * P(f)/Q(f)
|
||||||
|
// from Cephes, [5/5] rational on [sqrt(0.5)-1, sqrt(2)-1].
|
||||||
|
Packet f2 = pmul(f, f);
|
||||||
|
Packet f3 = pmul(f2, f);
|
||||||
|
|
||||||
|
// Evaluate P and Q in factored form for instruction-level parallelism.
|
||||||
|
Packet y, y1, y_;
|
||||||
|
y = pmadd(pset1<Packet>(1.01875663804580931796E-4), f, pset1<Packet>(4.97494994976747001425E-1));
|
||||||
|
y1 = pmadd(pset1<Packet>(1.44989225341610930846E1), f, pset1<Packet>(1.79368678507819816313E1));
|
||||||
|
y = pmadd(y, f, pset1<Packet>(4.70579119878881725854E0));
|
||||||
|
y1 = pmadd(y1, f, pset1<Packet>(7.70838733755885391666E0));
|
||||||
|
y_ = pmadd(y, f3, y1);
|
||||||
|
|
||||||
|
y = pmadd(pset1<Packet>(1.0), f, pset1<Packet>(1.12873587189167450590E1));
|
||||||
|
y1 = pmadd(pset1<Packet>(8.29875266912776603211E1), f, pset1<Packet>(7.11544750618563894466E1));
|
||||||
|
y = pmadd(y, f, pset1<Packet>(4.52279145837532221105E1));
|
||||||
|
y1 = pmadd(y1, f, pset1<Packet>(2.31251620126765340583E1));
|
||||||
|
y = pmadd(y, f3, y1);
|
||||||
|
|
||||||
|
y_ = pmul(y_, f3);
|
||||||
|
y = pdiv(y_, y);
|
||||||
|
|
||||||
|
y = pmadd(pset1<Packet>(-0.5), f2, y);
|
||||||
|
log_mantissa = padd(f, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Returns the base e (2.718...) or base 2 logarithm of x.
|
// Natural or base-2 logarithm for double packets.
|
||||||
* The argument is separated into its exponent and fractional parts.
|
|
||||||
* The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
|
|
||||||
* is approximated by
|
|
||||||
*
|
|
||||||
* log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
|
|
||||||
*
|
|
||||||
* for more detail see: http://www.netlib.org/cephes/
|
|
||||||
*/
|
|
||||||
template <typename Packet, bool base2>
|
template <typename Packet, bool base2>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_double(const Packet _x) {
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_double(const Packet _x) {
|
||||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
|
||||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
|
||||||
|
|
||||||
Packet log_mantissa, e;
|
Packet log_mantissa, e;
|
||||||
plog_core_double(_x, log_mantissa, e);
|
plog_core_double(_x, log_mantissa, e);
|
||||||
|
|
||||||
// Combine: log(x) = e * ln2 + log(mantissa), or log2(x) = log(mantissa)*log2e + e.
|
// Add the logarithm of the exponent back to the result.
|
||||||
Packet x;
|
Packet x;
|
||||||
if (base2) {
|
if (base2) {
|
||||||
const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
|
const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
|
||||||
@@ -302,13 +213,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_impl_double(cons
|
|||||||
x = pmadd(e, cst_ln2, log_mantissa);
|
x = pmadd(e, cst_ln2, log_mantissa);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
||||||
|
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||||
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
|
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
|
||||||
Packet iszero_mask = pcmp_eq(_x, pzero(_x));
|
Packet iszero_mask = pcmp_eq(_x, pzero(_x));
|
||||||
Packet pos_inf_mask = pcmp_eq(_x, cst_pos_inf);
|
Packet pos_inf_mask = pcmp_eq(_x, cst_pos_inf);
|
||||||
// Filter out invalid inputs:
|
|
||||||
// - negative arg → NAN
|
|
||||||
// - 0 → -INF
|
|
||||||
// - +INF → +INF
|
|
||||||
return pselect(iszero_mask, cst_minus_inf, por(pselect(pos_inf_mask, cst_pos_inf, x), invalid_mask));
|
return pselect(iszero_mask, cst_minus_inf, por(pselect(pos_inf_mask, cst_pos_inf, x), invalid_mask));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,11 +271,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_float(c
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal \returns log(1 + x) for double precision.
|
/** \internal \returns log(1 + x) for double precision float.
|
||||||
Computes log(1+x) using plog_core_double for the core range reduction and
|
Computes log(1+x) using plog_core_double for the core range reduction
|
||||||
polynomial evaluation. The rounding error from forming u = fl(1+x) is
|
and polynomial evaluation. The rounding error from forming u = fl(1+x)
|
||||||
recovered as dx = x - (u - 1) and folded in as a first-order correction
|
is recovered as dx = x - (u - 1), and folded in as a first-order
|
||||||
dx/u after the polynomial evaluation.
|
correction dx/u after the polynomial evaluation.
|
||||||
*/
|
*/
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) {
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(const Packet& x) {
|
||||||
@@ -374,7 +283,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(
|
|||||||
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
const Packet cst_minus_inf = pset1frombits<Packet>(static_cast<uint64_t>(0xfff0000000000000ull));
|
||||||
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
const Packet cst_pos_inf = pset1frombits<Packet>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||||
|
|
||||||
// u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1).
|
// u = 1 + x, with rounding. Recover the lost low bits: dx = x - (u - 1).
|
||||||
Packet u = padd(one, x);
|
Packet u = padd(one, x);
|
||||||
Packet dx = psub(x, psub(u, one));
|
Packet dx = psub(x, psub(u, one));
|
||||||
|
|
||||||
@@ -398,7 +307,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_log1p_double(
|
|||||||
result = pselect(small_mask, x, result);
|
result = pselect(small_mask, x, result);
|
||||||
result = pselect(inf_mask, cst_pos_inf, result);
|
result = pselect(inf_mask, cst_pos_inf, result);
|
||||||
result = pselect(zero_mask, cst_minus_inf, result);
|
result = pselect(zero_mask, cst_minus_inf, result);
|
||||||
result = por(neg_mask, result); // NaN for x < -1
|
result = por(neg_mask, result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -233,15 +233,17 @@ static std::vector<FuncEntry<Scalar>> build_func_table() {
|
|||||||
// Range iteration helpers
|
// Range iteration helpers
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
// Advances x toward +inf by at least 1 ULP. When step_eps > 0, additionally
|
// Advances a non-negative value toward +inf by at least 1 ULP. When step_eps > 0,
|
||||||
// jumps by a relative factor of (1 + step_eps) to sample the range sparsely.
|
// additionally jumps by max(|x|, min_normal) * step_eps. For normals this is
|
||||||
|
// equivalent to x * (1 + eps). For denormals where x * eps < smallest_denormal,
|
||||||
|
// the min_normal floor ensures we still skip through the denormal region at a
|
||||||
|
// rate matching the smallest normals rather than stalling at 1 ULP per step.
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
static inline Scalar advance_by_step(Scalar x, double step_eps) {
|
static inline Scalar advance_positive(Scalar x, double step_eps) {
|
||||||
Scalar next = std::nextafter(x, std::numeric_limits<Scalar>::infinity());
|
Scalar next = std::nextafter(x, std::numeric_limits<Scalar>::infinity());
|
||||||
if (step_eps > 0.0 && std::isfinite(next)) {
|
if (step_eps > 0.0 && std::isfinite(next)) {
|
||||||
// Try to jump further by a relative amount.
|
Scalar base = std::max(next, std::numeric_limits<Scalar>::min());
|
||||||
Scalar jumped = next > 0 ? next * static_cast<Scalar>(1.0 + step_eps) : next / static_cast<Scalar>(1.0 + step_eps);
|
Scalar jumped = next + base * static_cast<Scalar>(step_eps);
|
||||||
// Use the jump only if it actually advances further (handles denormal stalling).
|
|
||||||
if (jumped > next) next = jumped;
|
if (jumped > next) next = jumped;
|
||||||
}
|
}
|
||||||
return next;
|
return next;
|
||||||
@@ -281,26 +283,60 @@ static double linear_to_scalar(int64_t lin, double /*tag*/) {
|
|||||||
// Dynamic work queue: threads atomically claim chunks for load balancing
|
// Dynamic work queue: threads atomically claim chunks for load balancing
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
|
// Work queue that distributes chunks in positive absolute-value linear space.
|
||||||
|
// Iteration goes outward from 0: the worker tests both +|x| and -|x| for
|
||||||
|
// each sampled magnitude, so the multiplicative step (1 + eps) always works
|
||||||
|
// cleanly — no special handling for negative values needed.
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
struct WorkQueue {
|
struct WorkQueue {
|
||||||
int64_t range_hi_lin;
|
int64_t range_hi_lin;
|
||||||
int64_t chunk_size;
|
int64_t chunk_size;
|
||||||
double step_eps;
|
double step_eps;
|
||||||
std::atomic<int64_t> next_lin;
|
std::atomic<int64_t> next_lin;
|
||||||
|
Scalar orig_lo; // original range for sign filtering
|
||||||
|
Scalar orig_hi;
|
||||||
|
bool test_pos; // whether any positive values are in [lo, hi]
|
||||||
|
bool test_neg; // whether any negative values are in [lo, hi]
|
||||||
|
|
||||||
void init(Scalar lo, Scalar hi, int64_t csz, double step) {
|
void init(Scalar lo, Scalar hi, int num_threads, double step) {
|
||||||
range_hi_lin = scalar_to_linear(hi);
|
orig_lo = lo;
|
||||||
chunk_size = csz;
|
orig_hi = hi;
|
||||||
|
test_pos = (hi >= Scalar(0));
|
||||||
|
test_neg = (lo < Scalar(0));
|
||||||
|
|
||||||
|
// Compute absolute-value iteration range.
|
||||||
|
Scalar abs_lo, abs_hi;
|
||||||
|
if (lo <= Scalar(0) && hi >= Scalar(0)) {
|
||||||
|
abs_lo = Scalar(0);
|
||||||
|
abs_hi = std::max(std::abs(lo), hi);
|
||||||
|
} else {
|
||||||
|
abs_lo = std::min(std::abs(lo), std::abs(hi));
|
||||||
|
abs_hi = std::max(std::abs(lo), std::abs(hi));
|
||||||
|
}
|
||||||
|
|
||||||
|
range_hi_lin = scalar_to_linear(abs_hi);
|
||||||
step_eps = step;
|
step_eps = step;
|
||||||
next_lin.store(scalar_to_linear(lo), std::memory_order_relaxed);
|
next_lin.store(scalar_to_linear(abs_lo), std::memory_order_relaxed);
|
||||||
|
|
||||||
|
uint64_t total_abs = count_scalars_in_range(abs_lo, abs_hi);
|
||||||
|
chunk_size = std::max(int64_t(1), static_cast<int64_t>(total_abs / (num_threads * 16)));
|
||||||
|
if (step > 0.0) {
|
||||||
|
// Ensure chunks are large enough that advance_positive's min_normal floor
|
||||||
|
// can actually skip the denormal region. The denormal region contains
|
||||||
|
// count_scalars_in_range(0, min_normal) ULPs; any chunk must span at
|
||||||
|
// least that many so the min_normal-based jump lands past chunk_hi.
|
||||||
|
int64_t denorm_span = static_cast<int64_t>(count_scalars_in_range(Scalar(0), std::numeric_limits<Scalar>::min()));
|
||||||
|
chunk_size = std::max(chunk_size, denorm_span);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claim the next chunk. Returns false when no work remains.
|
// Claim the next chunk of absolute values. Returns false when no work remains.
|
||||||
bool claim(Scalar& chunk_lo, Scalar& chunk_hi) {
|
bool claim(Scalar& chunk_lo, Scalar& chunk_hi) {
|
||||||
int64_t lo_lin = next_lin.fetch_add(chunk_size, std::memory_order_relaxed);
|
int64_t lo_lin = next_lin.fetch_add(chunk_size, std::memory_order_relaxed);
|
||||||
if (lo_lin > range_hi_lin) return false;
|
if (lo_lin > range_hi_lin || lo_lin < 0) return false;
|
||||||
int64_t hi_lin = lo_lin + chunk_size - 1;
|
// Compute hi_lin carefully to avoid int64_t overflow.
|
||||||
if (hi_lin > range_hi_lin) hi_lin = range_hi_lin;
|
int64_t remaining = range_hi_lin - lo_lin;
|
||||||
|
int64_t hi_lin = (remaining < chunk_size - 1) ? range_hi_lin : lo_lin + chunk_size - 1;
|
||||||
chunk_lo = linear_to_scalar(lo_lin, Scalar(0));
|
chunk_lo = linear_to_scalar(lo_lin, Scalar(0));
|
||||||
chunk_hi = linear_to_scalar(hi_lin, Scalar(0));
|
chunk_hi = linear_to_scalar(hi_lin, Scalar(0));
|
||||||
return true;
|
return true;
|
||||||
@@ -322,8 +358,12 @@ static void worker(const FuncEntry<Scalar>& func, WorkQueue<Scalar>& queue, int
|
|||||||
#ifdef EIGEN_HAS_MPFR
|
#ifdef EIGEN_HAS_MPFR
|
||||||
mpfr_t mp_in, mp_out;
|
mpfr_t mp_in, mp_out;
|
||||||
if (use_mpfr) {
|
if (use_mpfr) {
|
||||||
mpfr_init2(mp_in, 128);
|
// Use 2x the mantissa bits of Scalar for the reference: 48 for float (24-bit
|
||||||
mpfr_init2(mp_out, 128);
|
// mantissa), 106 for double (53-bit mantissa). This is sufficient for correctly-
|
||||||
|
// rounded results while keeping MPFR evaluation fast.
|
||||||
|
constexpr int kMpfrBits = std::is_same<Scalar, float>::value ? 48 : 106;
|
||||||
|
mpfr_init2(mp_in, kMpfrBits);
|
||||||
|
mpfr_init2(mp_out, kMpfrBits);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
(void)use_mpfr;
|
(void)use_mpfr;
|
||||||
@@ -348,32 +388,42 @@ static void worker(const FuncEntry<Scalar>& func, WorkQueue<Scalar>& queue, int
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto flush_batch = [&](int& idx) {
|
||||||
|
if (idx == 0) return;
|
||||||
|
for (int i = idx; i < batch_size; i++) input[i] = input[idx - 1];
|
||||||
|
func.eigen_eval(eigen_out, input);
|
||||||
|
process_batch(idx, input, eigen_out);
|
||||||
|
idx = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto push_value = [&](Scalar v, int& idx) {
|
||||||
|
input[idx++] = v;
|
||||||
|
if (idx == batch_size) flush_batch(idx);
|
||||||
|
};
|
||||||
|
|
||||||
Scalar chunk_lo, chunk_hi;
|
Scalar chunk_lo, chunk_hi;
|
||||||
while (queue.claim(chunk_lo, chunk_hi)) {
|
while (queue.claim(chunk_lo, chunk_hi)) {
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
Scalar x = chunk_lo;
|
Scalar abs_x = chunk_lo;
|
||||||
for (;;) {
|
for (;;) {
|
||||||
input[idx] = x;
|
// Test +|x| if positive values are in range.
|
||||||
idx++;
|
if (queue.test_pos && abs_x >= queue.orig_lo && abs_x <= queue.orig_hi) {
|
||||||
|
push_value(abs_x, idx);
|
||||||
if (idx == batch_size) {
|
}
|
||||||
func.eigen_eval(eigen_out, input);
|
// Test -|x| if negative values are in range (skip -0 to avoid testing 0 twice).
|
||||||
process_batch(batch_size, input, eigen_out);
|
if (queue.test_neg && abs_x != Scalar(0)) {
|
||||||
idx = 0;
|
Scalar neg_x = -abs_x;
|
||||||
|
if (neg_x >= queue.orig_lo && neg_x <= queue.orig_hi) {
|
||||||
|
push_value(neg_x, idx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (x >= chunk_hi) break;
|
if (abs_x >= chunk_hi) break;
|
||||||
Scalar next = advance_by_step(x, queue.step_eps);
|
Scalar next = advance_positive(abs_x, queue.step_eps);
|
||||||
x = (next > chunk_hi) ? chunk_hi : next;
|
abs_x = (next > chunk_hi) ? chunk_hi : next;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process remaining partial batch. Pad unused slots with the last valid
|
flush_batch(idx);
|
||||||
// input so the full-size vectorized eval doesn't read uninitialized memory.
|
|
||||||
if (idx > 0) {
|
|
||||||
for (int i = idx; i < batch_size; i++) input[i] = input[idx - 1];
|
|
||||||
func.eigen_eval(eigen_out, input);
|
|
||||||
process_batch(idx, input, eigen_out);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef EIGEN_HAS_MPFR
|
#ifdef EIGEN_HAS_MPFR
|
||||||
@@ -439,11 +489,12 @@ static int run_test(const Options& opts) {
|
|||||||
std::printf("Function: %s (%s)\n", opts.func_name.c_str(), kTypeName);
|
std::printf("Function: %s (%s)\n", opts.func_name.c_str(), kTypeName);
|
||||||
std::printf("Range: [%.*g, %.*g]\n", kDigits, double(lo), kDigits, double(hi));
|
std::printf("Range: [%.*g, %.*g]\n", kDigits, double(lo), kDigits, double(hi));
|
||||||
if (opts.step_eps > 0.0) {
|
if (opts.step_eps > 0.0) {
|
||||||
std::printf("Sampling step: (1 + %g) * nextafter(x)\n", opts.step_eps);
|
std::printf("Sampling step: |x| * (1 + %g)\n", opts.step_eps);
|
||||||
} else {
|
} else {
|
||||||
std::printf("Representable values in range: %lu\n", static_cast<unsigned long>(total_scalars));
|
std::printf("Representable values in range: %lu\n", static_cast<unsigned long>(total_scalars));
|
||||||
}
|
}
|
||||||
std::printf("Reference: %s\n", opts.use_mpfr ? "MPFR (128-bit)" : "std C++ math");
|
std::printf("Reference: %s\n",
|
||||||
|
opts.use_mpfr ? (opts.use_double ? "MPFR (106-bit)" : "MPFR (48-bit)") : "std C++ math");
|
||||||
std::printf("Threads: %d\n", num_threads);
|
std::printf("Threads: %d\n", num_threads);
|
||||||
std::printf("Batch size: %d\n", opts.batch_size);
|
std::printf("Batch size: %d\n", opts.batch_size);
|
||||||
std::printf("\n");
|
std::printf("\n");
|
||||||
@@ -459,13 +510,8 @@ static int run_test(const Options& opts) {
|
|||||||
results.back()->init(opts.hist_width);
|
results.back()->init(opts.hist_width);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use dynamic work distribution: threads claim small chunks from a shared
|
|
||||||
// queue. This ensures even load balancing regardless of how per-value
|
|
||||||
// work varies across the range (e.g. log on negatives is trivial).
|
|
||||||
// Choose chunk_size so we get ~16 chunks per thread for good balancing.
|
|
||||||
int64_t chunk_size = std::max(int64_t(1), static_cast<int64_t>(total_scalars / (num_threads * 16)));
|
|
||||||
WorkQueue<Scalar> queue;
|
WorkQueue<Scalar> queue;
|
||||||
queue.init(lo, hi, chunk_size, opts.step_eps);
|
queue.init(lo, hi, num_threads, opts.step_eps);
|
||||||
|
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
auto start_time = std::chrono::steady_clock::now();
|
auto start_time = std::chrono::steady_clock::now();
|
||||||
|
|||||||
Reference in New Issue
Block a user