Improve psincos_double: faster polynomials + accurate range reduction

libeigen/eigen!2389

Closes #3052

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-04-07 21:24:24 -07:00
parent 110530a4d8
commit def45c5e1e

View File

@@ -230,40 +230,31 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_float(const Pack
return psincos_float<TrigFunction::Tan>(x);
}
// Trigonometric argument reduction for double for inputs smaller than 15.
// Reduces trigonometric arguments for double inputs where x < 15. Given an argument x and its corresponding quadrant
// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t.
// Pi/2 split into 3 double-precision parts (triple-double).
// c1 + c2 + c3 = pi/2 to ~159 bits. Computed by Sollya.
// c1 = RD(pi/2), c2 = RD(pi/2 - c1), c3 = RD(pi/2 - c1 - c2).
template <typename Packet>
Packet trig_reduce_small_double(const Packet& x, const Packet& q) {
// Pi/2 split into 2 values
const Packet cst_pio2_a = pset1<Packet>(-1.570796325802803);
const Packet cst_pio2_b = pset1<Packet>(-9.920935184482005e-10);
Packet t;
t = pmadd(cst_pio2_a, q, x);
t = pmadd(cst_pio2_b, q, t);
return t;
Packet cst_pio2_1() {
return pset1<Packet>(-1.5707963267948965579989817342720925807952880859375); // -0x1.921fb54442d18p0
}
template <typename Packet>
Packet cst_pio2_2() {
return pset1<Packet>(-6.12323399573676603586882014729198302312846062338790e-17); // -0x1.1a62633145c07p-54
}
template <typename Packet>
Packet cst_pio2_3() {
return pset1<Packet>(1.4973849048591698329435081771059920083527504761695190e-33); // 0x1.f1976b7ed8fbcp-110
}
// Trigonometric argument reduction for double for inputs smaller than 1e14.
// Reduces trigonometric arguments for double inputs where x < 1e14. Given an argument x and its corresponding quadrant
// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t.
// Trigonometric argument reduction for double, small inputs (|x| < small_th).
// Reduces x to t such that x = q * pi/2 + t, where |t| <= pi/4.
// Uses a triple-double split of pi/2 with FMA for high accuracy.
template <typename Packet>
Packet trig_reduce_medium_double(const Packet& x, const Packet& q_high, const Packet& q_low) {
// Pi/2 split into 4 values
const Packet cst_pio2_a = pset1<Packet>(-1.570796325802803);
const Packet cst_pio2_b = pset1<Packet>(-9.920935184482005e-10);
const Packet cst_pio2_c = pset1<Packet>(-6.123234014771656e-17);
const Packet cst_pio2_d = pset1<Packet>(1.903488962019325e-25);
Packet trig_reduce_small_double(const Packet& x, const Packet& q) {
Packet t;
t = pmadd(cst_pio2_a, q_high, x);
t = pmadd(cst_pio2_a, q_low, t);
t = pmadd(cst_pio2_b, q_high, t);
t = pmadd(cst_pio2_b, q_low, t);
t = pmadd(cst_pio2_c, q_high, t);
t = pmadd(cst_pio2_c, q_low, t);
t = pmadd(cst_pio2_d, padd(q_low, q_high), t);
t = pmadd(cst_pio2_1<Packet>(), q, x);
t = pmadd(cst_pio2_2<Packet>(), q, t);
t = pmadd(cst_pio2_3<Packet>(), q, t);
return t;
}
@@ -284,11 +275,13 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
// If the argument is bigger than this value, use the non-vectorized std version
const double huge_th = 1e14;
const Packet cst_2oPI = pset1<Packet>(0.63661977236758134307553505349006); // 2/PI
// 2/PI as a double-word: hi + lo = 2/pi to ~107 bits. Computed by Sollya.
const Packet cst_2oPI_hi =
pset1<Packet>(0.63661977236758138243288840385503135621547698974609375); // 0x1.45f306dc9c883p-1
const Packet cst_2oPI_lo =
pset1<Packet>(-3.9357353350364971763790381828183628368294820823718866e-17); // -0x1.6b01ec5417056p-55
// Integer Packet constants
const PacketI cst_one = pset1<PacketI>(ScalarI(1));
// Constant for splitting
const Packet cst_split = pset1<Packet>(1 << 24);
Packet x_abs = pabs(x);
@@ -298,76 +291,56 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
// TODO Implement huge angle argument reduction
if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1<Packet>(small_th), x_abs)))) {
Packet q_high = pmul(pfloor(pmul(x_abs, pdiv(cst_2oPI, cst_split))), cst_split);
Packet q_low_noround = psub(pmul(x_abs, cst_2oPI), q_high);
q_int = pcast<Packet, PacketI>(padd(q_low_noround, pset1<Packet>(0.5)));
Packet q_low = pcast<PacketI, Packet>(q_int);
s = trig_reduce_medium_double(x_abs, q_high, q_low);
// Medium path: use double-word product x * (2/pi) for precise quadrant computation.
Packet prod_hi, prod_lo;
twoprod(x_abs, cst_2oPI_hi, prod_hi, prod_lo);
// Correction for 2/pi truncation: add x * lo(2/pi)
prod_lo = pmadd(x_abs, cst_2oPI_lo, prod_lo);
// Round the double-word (prod_hi, prod_lo) to the nearest integer.
Packet q = pround(prod_hi);
// Compute exact fractional part to check if rounding was correct.
Packet frac = padd(psub(prod_hi, q), prod_lo);
// Correct if fractional part crossed +-0.5 boundary.
q = padd(q, pand(pcmp_lt(pset1<Packet>(0.5), frac), pset1<Packet>(1.0)));
q = padd(q, pand(pcmp_lt(frac, pset1<Packet>(-0.5)), pset1<Packet>(-1.0)));
q_int = pcast<Packet, PacketI>(q);
s = trig_reduce_small_double(x_abs, q);
} else {
Packet qval_noround = pmul(x_abs, cst_2oPI);
// Small path: simple reduction with triple-double pi/2 split.
Packet qval_noround = pmul(x_abs, cst_2oPI_hi);
q_int = pcast<Packet, PacketI>(padd(qval_noround, pset1<Packet>(0.5)));
Packet q = pcast<PacketI, Packet>(q_int);
s = trig_reduce_small_double(x_abs, q);
}
// All the upcoming approximating polynomials have even exponents
Packet ss = pmul(s, s);
// Padé approximant of cos(x)
// Assuring < 1 ULP error on the interval [-pi/4, pi/4]
// cos(x) ~= (80737373*x^8 - 13853547000*x^6 + 727718024880*x^4 - 11275015752000*x^2 + 23594700729600)/(147173*x^8 +
// 39328920*x^6 + 5772800880*x^4 + 522334612800*x^2 + 23594700729600)
// MATLAB code to compute those coefficients:
// syms x;
// cosf = @(x) cos(x);
// pade_cosf = pade(cosf(x), x, 0, 'Order', 8)
const Packet cn4 = pset1<Packet>(80737373);
const Packet cn3 = pset1<Packet>(-13853547000);
const Packet cn2 = pset1<Packet>(727718024880);
const Packet cn1 = pset1<Packet>(-11275015752000);
const Packet cn0 = pset1<Packet>(23594700729600); // shared with cd0
const Packet cd3 = pset1<Packet>(147173);
const Packet cd2 = pset1<Packet>(39328920);
const Packet cd1 = pset1<Packet>(5772800880);
const Packet cd0 = pset1<Packet>(522334612800);
Packet sc1_num = pmadd(ss, cn4, cn3);
Packet sc2_num = pmadd(sc1_num, ss, cn2);
Packet sc3_num = pmadd(sc2_num, ss, cn1);
Packet sc4_num = pmadd(sc3_num, ss, cn0);
Packet sc1_denum = pmadd(ss, cd3, cd2);
Packet sc2_denum = pmadd(sc1_denum, ss, cd1);
Packet sc3_denum = pmadd(sc2_denum, ss, cd0);
Packet sc4_denum = pmadd(sc3_denum, ss, cn0);
Packet scos = pdiv(sc4_num, sc4_denum);
// Minimax polynomial approximation of cos(x) on [-pi/4, pi/4].
// cos(x) = 1 + u * P(u), where u = x^2 and P is degree 6 (7 FMAs total).
// Coefficients computed by Sollya fpminimax. Max polynomial error ~1.3e-19.
Packet scos = pset1<Packet>(-1.1368926065317776472832699312119132152576472805094454088248312473297119140625e-11);
scos = pmadd(scos, ss, pset1<Packet>(2.0875905481768720039634091158002593413556269297259859740734100341796875e-09));
scos = pmadd(scos, ss, pset1<Packet>(-2.7557315712466412785356544880299711763882442028261721134185791015625e-07));
scos = pmadd(scos, ss, pset1<Packet>(2.480158729424286522739599714082459058772656135261058807373046875e-05));
scos = pmadd(scos, ss, pset1<Packet>(-1.388888888888178789471350427220386336557567119598388671875e-03));
scos = pmadd(scos, ss, pset1<Packet>(4.166666666666664353702032030923874117434024810791015625e-02));
scos = pmadd(scos, ss, pset1<Packet>(-0.5));
scos = pmadd(scos, ss, pset1<Packet>(1.0));
// Padé approximant of sin(x)
// Assuring < 1 ULP error on the interval [-pi/4, pi/4]
// sin(x) ~= (x*(4585922449*x^8 - 1066023933480*x^6 + 83284044283440*x^4 - 2303682236856000*x^2 +
// 15605159573203200))/(45*(1029037*x^8 + 345207016*x^6 + 61570292784*x^4 + 6603948711360*x^2 + 346781323848960))
// MATLAB code to compute those coefficients:
// syms x;
// sinf = @(x) sin(x);
// pade_sinf = pade(sinf(x), x, 0, 'Order', 8, 'OrderMode', 'relative')
const Packet sn4 = pset1<Packet>(4585922449);
const Packet sn3 = pset1<Packet>(-1066023933480);
const Packet sn2 = pset1<Packet>(83284044283440);
const Packet sn1 = pset1<Packet>(-2303682236856000);
const Packet sn0 = pset1<Packet>(15605159573203200);
const Packet sd3 = pset1<Packet>(1029037);
const Packet sd2 = pset1<Packet>(345207016);
const Packet sd1 = pset1<Packet>(61570292784);
const Packet sd0_inner = pset1<Packet>(6603948711360);
const Packet sd0 = pset1<Packet>(346781323848960);
const Packet cst_45 = pset1<Packet>(45);
Packet ss1_num = pmadd(ss, sn4, sn3);
Packet ss2_num = pmadd(ss1_num, ss, sn2);
Packet ss3_num = pmadd(ss2_num, ss, sn1);
Packet ss4_num = pmadd(ss3_num, ss, sn0);
Packet ss1_denum = pmadd(ss, sd3, sd2);
Packet ss2_denum = pmadd(ss1_denum, ss, sd1);
Packet ss3_denum = pmadd(ss2_denum, ss, sd0_inner);
Packet ss4_denum = pmadd(ss3_denum, ss, sd0);
Packet ssin = pdiv(pmul(s, ss4_num), pmul(cst_45, ss4_denum));
// Minimax polynomial approximation of sin(x) on [-pi/4, pi/4].
// sin(x) = x * (1 + u * R(u)), where u = x^2 and R is degree 5.
// Computed as: x + x * u * R(u) (6 FMAs + 1 mul).
// Coefficients computed by Sollya fpminimax. Max polynomial error ~1.0e-17.
Packet ssin = pset1<Packet>(1.59193066075142890698150587293845624470289834562208852730691432952880859375e-10);
ssin = pmadd(ssin, ss, pset1<Packet>(-2.50511517945670206974594627392927126408039839589037001132965087890625e-08));
ssin = pmadd(ssin, ss, pset1<Packet>(2.755731622544328228235042954619160582296899519860744476318359375e-06));
ssin = pmadd(ssin, ss, pset1<Packet>(-1.9841269837089632013978068858506276228581555187702178955078125e-04));
ssin = pmadd(ssin, ss, pset1<Packet>(8.333333333331312264835588621281203813850879669189453125e-03));
ssin = pmadd(ssin, ss, pset1<Packet>(-0.1666666666666666574148081281236954964697360992431640625));
ssin = pmul(ssin, ss);
ssin = pmadd(ssin, s, s);
Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(q_int, cst_one), pzero(q_int)));