diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathComplex.h b/Eigen/src/Core/arch/Default/GenericPacketMathComplex.h new file mode 100644 index 000000000..6363d9ec1 --- /dev/null +++ b/Eigen/src/Core/arch/Default/GenericPacketMathComplex.h @@ -0,0 +1,282 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2009-2019 Gael Guennebaud +// Copyright (C) 2018-2025 Rasmus Munk Larsen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H +#define EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +//---------------------------------------------------------------------- +// Complex Arithmetic and Functions +//---------------------------------------------------------------------- + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) { + typedef typename unpacket_traits::as_real RealPacket; + typedef typename unpacket_traits::type RealScalar; + // In the following we annotate the code for the case where the inputs + // are a pair length-2 SIMD vectors representing a single pair of complex + // numbers x = a + i*b, y = c + i*d. + const RealPacket one = pset1(RealScalar(1)); + const RealPacket y_flip = pcplxflip(y).v; + // We need to avoid dividing by Inf/Inf, so use a mask to carefully + // apply the scale. + const RealPacket mask = pcmp_lt(pabs(y.v), pabs(y_flip)); // |c| < |d| + const RealPacket y_scaled = pselect(mask, pdiv(y.v, y_flip), one); + RealPacket denom = pmul(y.v, y_scaled); + denom = padd(denom, pcplxflip(Packet(denom)).v); // c * c' + d * d' + Packet num = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c + return Packet(pdiv(num.v, denom)); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pmul_complex(const Packet& x, const Packet& y) { + // In the following we annotate the code for the case where the inputs + // are a pair length-2 SIMD vectors representing a single pair of complex + // numbers x = a + i*b, y = c + i*d. + Packet x_re = pdupreal(x); // a, a + Packet x_im = pdupimag(x); // b, b + Packet tmp_re = Packet(pmul(x_re.v, y.v)); // a*c, a*d + Packet tmp_im = Packet(pmul(x_im.v, y.v)); // b*c, b*d + tmp_im = pcplxflip(pconj(tmp_im)); // -b*d, d*c + return padd(tmp_im, tmp_re); // a*c - b*d, a*d + b*c +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Real part + RealPacket x_flip = pcplxflip(x).v; // b, a + Packet x_norm = phypot_complex(x); // sqrt(a^2 + b^2), sqrt(a^2 + b^2) + RealPacket xlogr = plog(x_norm.v); // log(sqrt(a^2 + b^2)), log(sqrt(a^2 + b^2)) + + // Imag part + RealPacket ximg = patan2(x.v, x_flip); // atan2(a, b), atan2(b, a) + + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + RealPacket x_abs = pabs(x.v); + RealPacket is_x_pos_inf = pcmp_eq(x_abs, cst_pos_inf); + RealPacket is_y_pos_inf = pcplxflip(Packet(is_x_pos_inf)).v; + RealPacket is_any_inf = por(is_x_pos_inf, is_y_pos_inf); + RealPacket xreal = pselect(is_any_inf, cst_pos_inf, xlogr); + + return Packet(pselect(peven_mask(xreal), xreal, ximg)); // log(sqrt(a^2 + b^2)), atan2(b, a) +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) { + typedef typename unpacket_traits::as_real RealPacket; + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + const RealPacket even_mask = peven_mask(a.v); + const RealPacket odd_mask = pcplxflip(Packet(even_mask)).v; + + // Let a = x + iy. + // exp(a) = exp(x) * cis(y), plus some special edge-case handling. + + // exp(x): + RealPacket x = pand(a.v, even_mask); + x = por(x, pcplxflip(Packet(x)).v); + RealPacket expx = pexp(x); // exp(x); + + // cis(y): + RealPacket y = pand(odd_mask, a.v); + y = por(y, pcplxflip(Packet(y)).v); + RealPacket cisy = psincos_selector(y); + cisy = pcplxflip(Packet(cisy)).v; // cos(y) + i * sin(y) + + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + const RealPacket cst_neg_inf = pset1(-NumTraits::infinity()); + + // If x is -inf, we know that cossin(y) is bounded, + // so the result is (0, +/-0), where the sign of the imaginary part comes + // from the sign of cossin(y). + RealPacket cisy_sign = por(pandnot(cisy, pabs(cisy)), pset1(RealScalar(1))); + cisy = pselect(pcmp_eq(x, cst_neg_inf), cisy_sign, cisy); + + // If x is inf, and cos(y) has unknown sign (y is inf or NaN), the result + // is (+/-inf, NaN), where the signs are undetermined (take the sign of y). + RealPacket y_sign = por(pandnot(y, pabs(y)), pset1(RealScalar(1))); + cisy = pselect(pand(pcmp_eq(x, cst_pos_inf), pisnan(cisy)), pand(y_sign, even_mask), cisy); + + // If exp(x) is +inf and y is finite, replace cisy with copysign(1, cisy) to + // prevent inf * 0 = NaN. The vectorized sincos may compute exact zero + // for near-zero values like cos(pi/2), and inf * +-1 = +-inf is correct. + // The y=0 case is handled separately below. + RealPacket cisy_sign_one = por(pand(cisy, pset1(RealScalar(-0.0))), pset1(RealScalar(1))); + RealPacket expx_inf_y_finite = pand(pcmp_eq(expx, cst_pos_inf), pcmp_lt(pabs(y), cst_pos_inf)); + cisy = pselect(expx_inf_y_finite, cisy_sign_one, cisy); + + Packet result = Packet(pmul(expx, cisy)); + + // If y is +/- 0, the input is real, so take the real result for consistency. + result = pselect(Packet(pcmp_eq(y, pzero(y))), Packet(por(pand(expx, even_mask), pand(y, odd_mask))), result); + + return result; +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Computes the principal sqrt of the complex numbers in the input. + // + // For example, for packets containing 2 complex numbers stored in interleaved format + // a = [a0, a1] = [x0, y0, x1, y1], + // where x0 = real(a0), y0 = imag(a0) etc., this function returns + // b = [b0, b1] = [u0, v0, u1, v1], + // such that b0^2 = a0, b1^2 = a1. + // + // To derive the formula for the complex square roots, let's consider the equation for + // a single complex square root of the number x + i*y. We want to find real numbers + // u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = 0.5 * (y / u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = 0.5 * (y / v) + // + // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as + // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , + + // In the following, without lack of generality, we have annotated the code, assuming + // that the input is a packet of 2 complex numbers. + // + // Step 1. Compute l = [l0, l0, l1, l1], where + // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), + // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. + + RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] + RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|] + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero. + l = pselect(a_min_zero_mask, a_max, l); + + // Step 2. Compute [rho0, *, rho1, *], where + // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) + // We don't care about the imaginary parts computed here. They will be overwritten later. + const RealPacket cst_half = pset1(RealScalar(0.5)); + Packet rho; + rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); + + // Step 3. Compute [rho0, eta0, rho1, eta1], where + // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. + // set eta = 0 of input is 0 + i0. + RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); + RealPacket real_mask = peven_mask(a.v); + Packet positive_real_result; + // Compute result for inputs with positive real part. + positive_real_result.v = pselect(real_mask, rho.v, eta); + + // Step 4. Compute solution for inputs with negative real part: + // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] + const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), RealScalar(-0.0))).v; + RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); + Packet negative_real_result; + // Notice that rho is positive, so taking its absolute value is a noop. + negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); + + // Step 5. Select solution branch based on the sign of the real parts. + Packet negative_real_mask; + negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); + negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); + Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); + + // Step 6. Handle special cases for infinities: + // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN + // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN + // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y + // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + Packet is_inf; + is_inf.v = pcmp_eq(a_abs, cst_pos_inf); + Packet is_real_inf; + is_real_inf.v = pand(is_inf.v, real_mask); + is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); + // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. + Packet real_inf_result; + real_inf_result.v = pmul(a_abs, pset1(Scalar(RealScalar(1.0), RealScalar(0.0))).v); + real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); + // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. + Packet is_imag_inf; + is_imag_inf.v = pandnot(is_inf.v, real_mask); + is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); + Packet imag_inf_result; + imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); + // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan + Packet result_is_nan = pisnan(result); + result = por(result_is_nan, result); + + return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); +} + +// \internal \returns the norm of a complex number z = x + i*y, defined as sqrt(x^2 + y^2). +// Implemented using the hypot(a,b) algorithm from https://doi.org/10.48550/arXiv.1904.09481 +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet phypot_complex(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + const RealPacket cst_zero_rp = pset1(static_cast(0.0)); + const RealPacket cst_minus_one_rp = pset1(static_cast(-1.0)); + const RealPacket cst_two_rp = pset1(static_cast(2.0)); + const RealPacket evenmask = peven_mask(a.v); + + RealPacket a_abs = pabs(a.v); + RealPacket a_flip = pcplxflip(Packet(a_abs)).v; // |b|, |a| + RealPacket a_all = pselect(evenmask, a_abs, a_flip); // |a|, |a| + RealPacket b_all = pselect(evenmask, a_flip, a_abs); // |b|, |b| + + RealPacket a2 = pmul(a.v, a.v); // |a^2, b^2| + RealPacket a2_flip = pcplxflip(Packet(a2)).v; // |b^2, a^2| + RealPacket h = psqrt(padd(a2, a2_flip)); // |sqrt(a^2 + b^2), sqrt(a^2 + b^2)| + RealPacket h_sq = pmul(h, h); // |a^2 + b^2, a^2 + b^2| + RealPacket a_sq = pselect(evenmask, a2, a2_flip); // |a^2, a^2| + RealPacket m_h_sq = pmul(h_sq, cst_minus_one_rp); + RealPacket m_a_sq = pmul(a_sq, cst_minus_one_rp); + RealPacket x = psub(psub(pmadd(h, h, m_h_sq), pmadd(b_all, b_all, psub(a_sq, h_sq))), pmadd(a_all, a_all, m_a_sq)); + h = psub(h, pdiv(x, pmul(cst_two_rp, h))); // |h - x/(2*h), h - x/(2*h)| + + // handle zero-case + RealPacket iszero = pcmp_eq(por(a_abs, a_flip), cst_zero_rp); + + h = pandnot(h, iszero); // |sqrt(a^2+b^2), sqrt(a^2+b^2)| + return Packet(h); // |sqrt(a^2+b^2), sqrt(a^2+b^2)| +} + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 7fdf961d2..2ab1847fd 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -26,146 +26,6 @@ namespace Eigen { namespace internal { -//---------------------------------------------------------------------- -// Cubic Root Functions -//---------------------------------------------------------------------- - -// This function implements a single step of Halley's iteration for -// computing x = y^(1/3): -// x_{k+1} = x_k - (x_k^3 - y) x_k / (2x_k^3 + y) -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_halley_iteration_step(const Packet& x_k, - const Packet& y) { - typedef typename unpacket_traits::type Scalar; - Packet x_k_cb = pmul(x_k, pmul(x_k, x_k)); - Packet denom = pmadd(pset1(Scalar(2)), x_k_cb, y); - Packet num = psub(x_k_cb, y); - Packet r = pdiv(num, denom); - return pnmadd(x_k, r, x_k); -} - -// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the -// interval [0.125,1]. -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_decompose(const Packet& x, Packet& e_div3) { - typedef typename unpacket_traits::type Scalar; - // Extract the significant s in the range [0.5,1) and exponent e, such that - // x = 2^e * s. - Packet e, s; - s = pfrexp(x, e); - - // Split the exponent into a part divisible by 3 and the remainder. - // e = 3*e_div3 + e_mod3. - constexpr Scalar kOneThird = Scalar(1) / 3; - e_div3 = pceil(pmul(e, pset1(kOneThird))); - Packet e_mod3 = pnmadd(pset1(Scalar(3)), e_div3, e); - - // Replace s by y = (s * 2^e_mod3). - return pldexp_fast(s, e_mod3); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_special_cases_and_sign(const Packet& x, - const Packet& abs_root) { - typedef typename unpacket_traits::type Scalar; - - // Set sign. - const Packet sign_mask = pset1(Scalar(-0.0)); - const Packet x_sign = pand(sign_mask, x); - Packet root = por(x_sign, abs_root); - - // Pass non-finite and zero values of x straight through. - const Packet is_not_finite = por(pisinf(x), pisnan(x)); - const Packet is_zero = pcmp_eq(pzero(x), x); - const Packet use_x = por(is_not_finite, is_zero); - return pselect(use_x, x, root); -} - -// Generic implementation of cbrt(x) for float. -// -// The algorithm computes the cubic root of the input by first -// decomposing it into a exponent and significant -// x = s * 2^e. -// -// We can then write the cube root as -// -// x^(1/3) = 2^(e/3) * s^(1/3) -// = 2^((3*e_div3 + e_mod3)/3) * s^(1/3) -// = 2^(e_div3) * 2^(e_mod3/3) * s^(1/3) -// = 2^(e_div3) * (s * 2^e_mod3)^(1/3) -// -// where e_div3 = ceil(e/3) and e_mod3 = e - 3*e_div3. -// -// The cube root of the second term y = (s * 2^e_mod3)^(1/3) is coarsely -// approximated using a cubic polynomial and subsequently refined using a -// single step of Halley's iteration, and finally the two terms are combined -// using pldexp_fast. -// -// Note: Many alternatives exist for implementing cbrt. See, for example, -// the excellent discussion in Kahan's note: -// https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf -// This particular implementation was found to be very fast and accurate -// among several alternatives tried, but is probably not "optimal" on all -// platforms. -// -// This is accurate to 2 ULP. -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be float"); - - // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the - // interval [0.125,1]. - Packet e_div3; - const Packet y = cbrt_decompose(pabs(x), e_div3); - - // Compute initial approximation accurate to 5.22e-3. - // The polynomial was computed using Rminimax. - constexpr float alpha[] = {5.9220016002655029296875e-01f, -1.3859539031982421875e+00f, 1.4581282138824462890625e+00f, - 3.408401906490325927734375e-01f}; - Packet r = ppolevl::run(y, alpha); - - // Take one step of Halley's iteration. - r = cbrt_halley_iteration_step(r, y); - - // Finally multiply by 2^(e_div3) - r = pldexp_fast(r, e_div3); - - return cbrt_special_cases_and_sign(x, r); -} - -// Generic implementation of cbrt(x) for double. -// -// The algorithm is identical to the one for float except that a different initial -// approximation is used for y^(1/3) and two Halley iteration steps are performed. -// -// This is accurate to 1 ULP. -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be double"); - - // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the - // interval [0.125,1]. - Packet e_div3; - const Packet y = cbrt_decompose(pabs(x), e_div3); - - // Compute initial approximation accurate to 0.016. - // The polynomial was computed using Rminimax. - constexpr double alpha[] = {-4.69470621553356115551736138513660989701747894287109375e-01, - 1.072314636518546304699839311069808900356292724609375e+00, - 3.81249427609571867048288140722434036433696746826171875e-01}; - Packet r = ppolevl::run(y, alpha); - - // Take two steps of Halley's iteration. - r = cbrt_halley_iteration_step(r, y); - r = cbrt_halley_iteration_step(r, y); - - // Finally multiply by 2^(e_div3). - r = pldexp_fast(r, e_div3); - return cbrt_special_cases_and_sign(x, r); -} - //---------------------------------------------------------------------- // Exponential and Logarithmic Functions //---------------------------------------------------------------------- @@ -543,1073 +403,17 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_exp2(const Pa return pmul(exp2_hi, exp2_lo); } -//---------------------------------------------------------------------- -// Trigonometric Functions -//---------------------------------------------------------------------- +} // end namespace internal +} // end namespace Eigen -// Enum for selecting which function to compute. SinCos is intended to compute -// pairs of Sin and Cos of the even entries in the packet, e.g. -// SinCos([a, *, b, *]) = [sin(a), cos(a), sin(b), cos(b)]. -enum class TrigFunction : uint8_t { Sin, Cos, Tan, SinCos }; +// Include the split-out sections. Order matters: Pow depends on exp/log and FrexpLdexp, +// Trig depends on exp (for ptanh_float), Complex depends on Trig (for psincos_selector). +#include "GenericPacketMathPow.h" +#include "GenericPacketMathTrig.h" +#include "GenericPacketMathComplex.h" -// The following code is inspired by the following stack-overflow answer: -// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751 -// It has been largely optimized: -// - By-pass calls to frexp. -// - Aligned loads of required 96 bits of 2/pi. This is accomplished by -// (1) balancing the mantissa and exponent to the required bits of 2/pi are -// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi. -// - Avoid a branch in rounding and extraction of the remaining fractional part. -// Overall, I measured a speed up higher than x2 on x86-64. -inline float trig_reduce_huge(float xf, Eigen::numext::int32_t* quadrant) { - using Eigen::numext::int32_t; - using Eigen::numext::int64_t; - using Eigen::numext::uint32_t; - using Eigen::numext::uint64_t; - - const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62 - const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point format - - // 192 bits of 2/pi for Payne-Hanek reduction - // Bits are introduced by packet of 8 to enable aligned reads. - static const uint32_t two_over_pi[] = { - 0x00000028, 0x000028be, 0x0028be60, 0x28be60db, 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a, 0x91054a7f, - 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4, 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770, 0x4d377036, 0x377036d8, - 0x7036d8a5, 0x36d8a566, 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410, 0x10e41000, 0xe4100000}; - - uint32_t xi = numext::bit_cast(xf); - // Below, -118 = -126 + 8. - // -126 is to get the exponent, - // +8 is to enable alignment of 2/pi's bits on 8 bits. - // This is possible because the fractional part of x as only 24 meaningful bits. - uint32_t e = (xi >> 23) - 118; - // Extract the mantissa and shift it to align it wrt the exponent - xi = ((xi & 0x007fffffu) | 0x00800000u) << (e & 0x7); - - uint32_t i = e >> 3; - uint32_t twoopi_1 = two_over_pi[i - 1]; - uint32_t twoopi_2 = two_over_pi[i + 3]; - uint32_t twoopi_3 = two_over_pi[i + 7]; - - // Compute x * 2/pi in 2.62-bit fixed-point format. - uint64_t p; - p = uint64_t(xi) * twoopi_3; - p = uint64_t(xi) * twoopi_2 + (p >> 32); - p = (uint64_t(xi * twoopi_1) << 32) + p; - - // Round to nearest: add 0.5 and extract integral part. - uint64_t q = (p + zero_dot_five) >> 62; - *quadrant = int(q); - // Now it remains to compute "r = x - q*pi/2" with high accuracy, - // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as: - // r = (p-q)*pi/2, - // where the product can be be carried out with sufficient accuracy using double precision. - p -= q << 62; - return float(double(int64_t(p)) * pio2_62); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -#if EIGEN_COMP_GNUC_STRICT - __attribute__((optimize("-fno-unsafe-math-optimizations"))) -#endif - Packet - psincos_float(const Packet& _x) { - typedef typename unpacket_traits::integer_packet PacketI; - - const Packet cst_2oPI = pset1(0.636619746685028076171875f); // 2/PI - const Packet cst_rounding_magic = pset1(12582912); // 2^23 for rounding - const PacketI csti_1 = pset1(1); - const Packet cst_sign_mask = pset1frombits(static_cast(0x80000000u)); - - Packet x = pabs(_x); - - // Scale x by 2/Pi to find x's octant. - Packet y = pmul(x, cst_2oPI); - - // Rounding trick to find nearest integer: - Packet y_round = padd(y, cst_rounding_magic); - EIGEN_OPTIMIZATION_BARRIER(y_round) - PacketI y_int = preinterpret(y_round); // last 23 digits represent integer (if abs(x)<2^24) - y = psub(y_round, cst_rounding_magic); // nearest integer to x * (2/pi) - -// Subtract y * Pi/2 to reduce x to the interval -Pi/4 <= x <= +Pi/4 -// using "Extended precision modular arithmetic" -#if defined(EIGEN_VECTORIZE_FMA) - // This version requires true FMA for high accuracy. - // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08): - constexpr float huge_th = (Func == TrigFunction::Sin) ? 117435.992f : 71476.0625f; - x = pmadd(y, pset1(-1.57079601287841796875f), x); - x = pmadd(y, pset1(-3.1391647326017846353352069854736328125e-07f), x); - x = pmadd(y, pset1(-5.390302529957764765544681040410068817436695098876953125e-15f), x); -#else - // Without true FMA, the previous set of coefficients maintain 1ULP accuracy - // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7. - // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs. - - // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively. - // and 2 ULP up to: - constexpr float huge_th = (Func == TrigFunction::Sin) ? 25966.f : 18838.f; - x = pmadd(y, pset1(-1.5703125), x); // = 0xbfc90000 - EIGEN_OPTIMIZATION_BARRIER(x) - x = pmadd(y, pset1(-0.000483989715576171875), x); // = 0xb9fdc000 - EIGEN_OPTIMIZATION_BARRIER(x) - x = pmadd(y, pset1(1.62865035235881805419921875e-07), x); // = 0x342ee000 - x = pmadd(y, pset1(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee - -// For the record, the following set of coefficients maintain 2ULP up -// to a slightly larger range: -// const float huge_th = ComputeSine ? 51981.f : 39086.125f; -// but it slightly fails to maintain 1ULP for two values of sin below pi. -// x = pmadd(y, pset1(-3.140625/2.), x); -// x = pmadd(y, pset1(-0.00048351287841796875), x); -// x = pmadd(y, pset1(-3.13855707645416259765625e-07), x); -// x = pmadd(y, pset1(-6.0771006282767103812147979624569416046142578125e-11), x); - -// For the record, with only 3 iterations it is possible to maintain -// 1 ULP up to 3PI (maybe more) and 2ULP up to 255. -// The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee -#endif - - if (predux_any(pcmp_le(pset1(huge_th), pabs(_x)))) { - const int PacketSize = unpacket_traits::size; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize]; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize]; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Eigen::numext::int32_t y_int2[PacketSize]; - pstoreu(vals, pabs(_x)); - pstoreu(x_cpy, x); - pstoreu(y_int2, y_int); - for (int k = 0; k < PacketSize; ++k) { - float val = vals[k]; - if (val >= huge_th && (numext::isfinite)(val)) x_cpy[k] = trig_reduce_huge(val, &y_int2[k]); - } - x = ploadu(x_cpy); - y_int = ploadu(y_int2); - } - - // Get the polynomial selection mask from the second bit of y_int - // We'll calculate both (sin and cos) polynomials and then select from the two. - Packet poly_mask = preinterpret(pcmp_eq(pand(y_int, csti_1), pzero(y_int))); - - Packet x2 = pmul(x, x); - - // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4) - Packet y1 = pset1(2.4372266125283204019069671630859375e-05f); - y1 = pmadd(y1, x2, pset1(-0.00138865201734006404876708984375f)); - y1 = pmadd(y1, x2, pset1(0.041666619479656219482421875f)); - y1 = pmadd(y1, x2, pset1(-0.5f)); - y1 = pmadd(y1, x2, pset1(1.f)); - - // Evaluate the sin(x) polynomial. (Pi/4 <= x <= Pi/4) - // octave/matlab code to compute those coefficients: - // x = (0:0.0001:pi/4)'; - // A = [x.^3 x.^5 x.^7]; - // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy - // c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1 - // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1)) - // - Packet y2 = pset1(-0.0001959234114083702898469196984621021329076029360294342041015625f); - y2 = pmadd(y2, x2, pset1(0.0083326873655616851693794799871284340042620897293090820312500000f)); - y2 = pmadd(y2, x2, pset1(-0.1666666203982298255503735617821803316473960876464843750000000000f)); - y2 = pmul(y2, x2); - y2 = pmadd(y2, x, x); - - // Select the correct result from the two polynomials. - // Compute the sign to apply to the polynomial. - // sin: sign = second_bit(y_int) xor signbit(_x) - // cos: sign = second_bit(y_int+1) - Packet sign_bit = (Func == TrigFunction::Sin) ? pxor(_x, preinterpret(plogical_shift_left<30>(y_int))) - : preinterpret(plogical_shift_left<30>(padd(y_int, csti_1))); - sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit - - if ((Func == TrigFunction::SinCos) || (Func == TrigFunction::Tan)) { - // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div. - Packet peven = peven_mask(x); - Packet ysin = pselect(poly_mask, y2, y1); - Packet ycos = pselect(poly_mask, y1, y2); - Packet sign_bit_sin = pxor(_x, preinterpret(plogical_shift_left<30>(y_int))); - Packet sign_bit_cos = preinterpret(plogical_shift_left<30>(padd(y_int, csti_1))); - sign_bit_sin = pand(sign_bit_sin, cst_sign_mask); // clear all but left most bit - sign_bit_cos = pand(sign_bit_cos, cst_sign_mask); // clear all but left most bit - y = (Func == TrigFunction::SinCos) ? pselect(peven, pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos)) - : pdiv(pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos)); - } else { - y = (Func == TrigFunction::Sin) ? pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2); - y = pxor(y, sign_bit); - } - return y; -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_float(const Packet& x) { - return psincos_float(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Packet& x) { - return psincos_float(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_float(const Packet& x) { - return psincos_float(x); -} - -// Trigonometric argument reduction for double for inputs smaller than 15. -// Reduces trigonometric arguments for double inputs where x < 15. Given an argument x and its corresponding quadrant -// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t. -template -Packet trig_reduce_small_double(const Packet& x, const Packet& q) { - // Pi/2 split into 2 values - const Packet cst_pio2_a = pset1(-1.570796325802803); - const Packet cst_pio2_b = pset1(-9.920935184482005e-10); - - Packet t; - t = pmadd(cst_pio2_a, q, x); - t = pmadd(cst_pio2_b, q, t); - return t; -} - -// Trigonometric argument reduction for double for inputs smaller than 1e14. -// Reduces trigonometric arguments for double inputs where x < 1e14. Given an argument x and its corresponding quadrant -// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t. -template -Packet trig_reduce_medium_double(const Packet& x, const Packet& q_high, const Packet& q_low) { - // Pi/2 split into 4 values - const Packet cst_pio2_a = pset1(-1.570796325802803); - const Packet cst_pio2_b = pset1(-9.920935184482005e-10); - const Packet cst_pio2_c = pset1(-6.123234014771656e-17); - const Packet cst_pio2_d = pset1(1.903488962019325e-25); - - Packet t; - t = pmadd(cst_pio2_a, q_high, x); - t = pmadd(cst_pio2_a, q_low, t); - t = pmadd(cst_pio2_b, q_high, t); - t = pmadd(cst_pio2_b, q_low, t); - t = pmadd(cst_pio2_c, q_high, t); - t = pmadd(cst_pio2_c, q_low, t); - t = pmadd(cst_pio2_d, padd(q_low, q_high), t); - return t; -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -#if EIGEN_COMP_GNUC_STRICT - __attribute__((optimize("-fno-unsafe-math-optimizations"))) -#endif - Packet - psincos_double(const Packet& x) { - typedef typename unpacket_traits::integer_packet PacketI; - typedef typename unpacket_traits::type ScalarI; - - const Packet cst_sign_mask = pset1frombits(static_cast(0x8000000000000000u)); - - // If the argument is smaller than this value, use a simpler argument reduction - const double small_th = 15; - // If the argument is bigger than this value, use the non-vectorized std version - const double huge_th = 1e14; - - const Packet cst_2oPI = pset1(0.63661977236758134307553505349006); // 2/PI - // Integer Packet constants - const PacketI cst_one = pset1(ScalarI(1)); - // Constant for splitting - const Packet cst_split = pset1(1 << 24); - - Packet x_abs = pabs(x); - - // Scale x by 2/Pi - PacketI q_int; - Packet s; - - // TODO Implement huge angle argument reduction - if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1(small_th), x_abs)))) { - Packet q_high = pmul(pfloor(pmul(x_abs, pdiv(cst_2oPI, cst_split))), cst_split); - Packet q_low_noround = psub(pmul(x_abs, cst_2oPI), q_high); - q_int = pcast(padd(q_low_noround, pset1(0.5))); - Packet q_low = pcast(q_int); - s = trig_reduce_medium_double(x_abs, q_high, q_low); - } else { - Packet qval_noround = pmul(x_abs, cst_2oPI); - q_int = pcast(padd(qval_noround, pset1(0.5))); - Packet q = pcast(q_int); - s = trig_reduce_small_double(x_abs, q); - } - - // All the upcoming approximating polynomials have even exponents - Packet ss = pmul(s, s); - - // Padé approximant of cos(x) - // Assuring < 1 ULP error on the interval [-pi/4, pi/4] - // cos(x) ~= (80737373*x^8 - 13853547000*x^6 + 727718024880*x^4 - 11275015752000*x^2 + 23594700729600)/(147173*x^8 + - // 39328920*x^6 + 5772800880*x^4 + 522334612800*x^2 + 23594700729600) - // MATLAB code to compute those coefficients: - // syms x; - // cosf = @(x) cos(x); - // pade_cosf = pade(cosf(x), x, 0, 'Order', 8) - const Packet cn4 = pset1(80737373); - const Packet cn3 = pset1(-13853547000); - const Packet cn2 = pset1(727718024880); - const Packet cn1 = pset1(-11275015752000); - const Packet cn0 = pset1(23594700729600); // shared with cd0 - const Packet cd3 = pset1(147173); - const Packet cd2 = pset1(39328920); - const Packet cd1 = pset1(5772800880); - const Packet cd0 = pset1(522334612800); - Packet sc1_num = pmadd(ss, cn4, cn3); - Packet sc2_num = pmadd(sc1_num, ss, cn2); - Packet sc3_num = pmadd(sc2_num, ss, cn1); - Packet sc4_num = pmadd(sc3_num, ss, cn0); - Packet sc1_denum = pmadd(ss, cd3, cd2); - Packet sc2_denum = pmadd(sc1_denum, ss, cd1); - Packet sc3_denum = pmadd(sc2_denum, ss, cd0); - Packet sc4_denum = pmadd(sc3_denum, ss, cn0); - Packet scos = pdiv(sc4_num, sc4_denum); - - // Padé approximant of sin(x) - // Assuring < 1 ULP error on the interval [-pi/4, pi/4] - // sin(x) ~= (x*(4585922449*x^8 - 1066023933480*x^6 + 83284044283440*x^4 - 2303682236856000*x^2 + - // 15605159573203200))/(45*(1029037*x^8 + 345207016*x^6 + 61570292784*x^4 + 6603948711360*x^2 + 346781323848960)) - // MATLAB code to compute those coefficients: - // syms x; - // sinf = @(x) sin(x); - // pade_sinf = pade(sinf(x), x, 0, 'Order', 8, 'OrderMode', 'relative') - const Packet sn4 = pset1(4585922449); - const Packet sn3 = pset1(-1066023933480); - const Packet sn2 = pset1(83284044283440); - const Packet sn1 = pset1(-2303682236856000); - const Packet sn0 = pset1(15605159573203200); - const Packet sd3 = pset1(1029037); - const Packet sd2 = pset1(345207016); - const Packet sd1 = pset1(61570292784); - const Packet sd0_inner = pset1(6603948711360); - const Packet sd0 = pset1(346781323848960); - const Packet cst_45 = pset1(45); - Packet ss1_num = pmadd(ss, sn4, sn3); - Packet ss2_num = pmadd(ss1_num, ss, sn2); - Packet ss3_num = pmadd(ss2_num, ss, sn1); - Packet ss4_num = pmadd(ss3_num, ss, sn0); - Packet ss1_denum = pmadd(ss, sd3, sd2); - Packet ss2_denum = pmadd(ss1_denum, ss, sd1); - Packet ss3_denum = pmadd(ss2_denum, ss, sd0_inner); - Packet ss4_denum = pmadd(ss3_denum, ss, sd0); - Packet ssin = pdiv(pmul(s, ss4_num), pmul(cst_45, ss4_denum)); - - Packet poly_mask = preinterpret(pcmp_eq(pand(q_int, cst_one), pzero(q_int))); - - Packet sign_sin = pxor(x, preinterpret(plogical_shift_left<62>(q_int))); - Packet sign_cos = preinterpret(plogical_shift_left<62>(padd(q_int, cst_one))); - Packet sign_bit, sFinalRes; - if (Func == TrigFunction::Sin) { - sign_bit = sign_sin; - sFinalRes = pselect(poly_mask, ssin, scos); - } else if (Func == TrigFunction::Cos) { - sign_bit = sign_cos; - sFinalRes = pselect(poly_mask, scos, ssin); - } else if (Func == TrigFunction::Tan) { - // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div. - sign_bit = pxor(sign_sin, sign_cos); - sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin)); - } else if (Func == TrigFunction::SinCos) { - Packet peven = peven_mask(x); - sign_bit = pselect(peven, sign_sin, sign_cos); - sFinalRes = pselect(pxor(peven, poly_mask), scos, ssin); - } - sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit - sFinalRes = pxor(sFinalRes, sign_bit); - - // If the inputs values are higher than that a value that the argument reduction can currently address, compute them - // using the C++ standard library. - // TODO Remove it when huge angle argument reduction is implemented - if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1(huge_th), x_abs)))) { - const int PacketSize = unpacket_traits::size; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double sincos_vals[PacketSize]; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double x_cpy[PacketSize]; - pstoreu(x_cpy, x); - pstoreu(sincos_vals, sFinalRes); - for (int k = 0; k < PacketSize; ++k) { - double val = x_cpy[k]; - if (std::abs(val) > huge_th && (numext::isfinite)(val)) { - if (Func == TrigFunction::Sin) { - sincos_vals[k] = std::sin(val); - } else if (Func == TrigFunction::Cos) { - sincos_vals[k] = std::cos(val); - } else if (Func == TrigFunction::Tan) { - sincos_vals[k] = std::tan(val); - } else if (Func == TrigFunction::SinCos) { - sincos_vals[k] = k % 2 == 0 ? std::sin(val) : std::cos(val); - } - } - } - sFinalRes = ploadu(sincos_vals); - } - return sFinalRes; -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x) { - return psincos_double(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x) { - return psincos_double(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_double(const Packet& x) { - return psincos_double(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS - std::enable_if_t::type, float>::value, Packet> - psincos_selector(const Packet& x) { - return psincos_float(x); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS - std::enable_if_t::type, double>::value, Packet> - psincos_selector(const Packet& x) { - return psincos_double(x); -} - -//---------------------------------------------------------------------- -// Inverse Trigonometric Functions -//---------------------------------------------------------------------- - -// Generic implementation of acos(x). -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pacos_float(const Packet& x_in) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be float"); - - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_pi = pset1(Scalar(EIGEN_PI)); - const Packet p6 = pset1(Scalar(2.36423197202384471893310546875e-3)); - const Packet p5 = pset1(Scalar(-1.1368644423782825469970703125e-2)); - const Packet p4 = pset1(Scalar(2.717843465507030487060546875e-2)); - const Packet p3 = pset1(Scalar(-4.8969544470310211181640625e-2)); - const Packet p2 = pset1(Scalar(8.8804088532924652099609375e-2)); - const Packet p1 = pset1(Scalar(-0.214591205120086669921875)); - const Packet p0 = pset1(Scalar(1.57079637050628662109375)); - - // For x in [0:1], we approximate acos(x)/sqrt(1-x), which is a smooth - // function, by a 6'th order polynomial. - // For x in [-1:0) we use that acos(-x) = pi - acos(x). - const Packet neg_mask = psignbit(x_in); - const Packet abs_x = pabs(x_in); - - // Evaluate the polynomial using Horner's rule: - // P(x) = p0 + x * (p1 + x * (p2 + ... (p5 + x * p6)) ... ) . - // We evaluate even and odd terms independently to increase - // instruction level parallelism. - Packet x2 = pmul(x_in, x_in); - Packet p_even = pmadd(p6, x2, p4); - Packet p_odd = pmadd(p5, x2, p3); - p_even = pmadd(p_even, x2, p2); - p_odd = pmadd(p_odd, x2, p1); - p_even = pmadd(p_even, x2, p0); - Packet p = pmadd(p_odd, abs_x, p_even); - - // The polynomial approximates acos(x)/sqrt(1-x), so - // multiply by sqrt(1-x) to get acos(x). - // Conveniently returns NaN for arguments outside [-1:1]. - Packet denom = psqrt(psub(cst_one, abs_x)); - Packet result = pmul(denom, p); - // Undo mapping for negative arguments. - return pselect(neg_mask, psub(cst_pi, result), result); -} - -// Generic implementation of asin(x). -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Packet& x_in) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be float"); - - constexpr float kPiOverTwo = static_cast(EIGEN_PI / 2); - - const Packet cst_half = pset1(0.5f); - const Packet cst_one = pset1(1.0f); - const Packet cst_two = pset1(2.0f); - const Packet cst_pi_over_two = pset1(kPiOverTwo); - - const Packet abs_x = pabs(x_in); - const Packet sign_mask = pandnot(x_in, abs_x); - const Packet invalid_mask = pcmp_lt(cst_one, abs_x); - - // For arguments |x| > 0.5, we map x back to [0:0.5] using - // the transformation x_large = sqrt(0.5*(1-x)), and use the - // identity - // asin(x) = pi/2 - 2 * asin( sqrt( 0.5 * (1 - x))) - - const Packet x_large = psqrt(pnmadd(cst_half, abs_x, cst_half)); - const Packet large_mask = pcmp_lt(cst_half, abs_x); - const Packet x = pselect(large_mask, x_large, abs_x); - const Packet x2 = pmul(x, x); - - // For |x| < 0.5 approximate asin(x)/x by an 8th order polynomial with - // even terms only. - constexpr float alpha[] = {5.08838854730129241943359375e-2f, 3.95139865577220916748046875e-2f, - 7.550220191478729248046875e-2f, 0.16664917767047882080078125f, 1.00000011920928955078125f}; - Packet p = ppolevl::run(x2, alpha); - p = pmul(p, x); - - const Packet p_large = pnmadd(cst_two, p, cst_pi_over_two); - p = pselect(large_mask, p_large, p); - // Flip the sign for negative arguments. - p = pxor(p, sign_mask); - // Return NaN for arguments outside [-1:1]. - return por(invalid_mask, p); -} - -template -struct patan_reduced { - template - static EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet run(const Packet& x); -}; - -template <> -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced::run(const Packet& x) { - constexpr double alpha[] = {2.6667153866462208e-05, 3.0917513112462781e-03, 5.2574296781008604e-02, - 3.0409318473444424e-01, 7.5365702534987022e-01, 8.2704055405494614e-01, - 3.3004361289279920e-01}; - - constexpr double beta[] = { - 2.7311202462436667e-04, 1.0899150928962708e-02, 1.1548932646420353e-01, 4.9716458728465573e-01, 1.0, - 9.3705509168587852e-01, 3.3004361289279920e-01}; - - Packet x2 = pmul(x, x); - Packet p = ppolevl::run(x2, alpha); - Packet q = ppolevl::run(x2, beta); - return pmul(x, pdiv(p, q)); -} - -// Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy. -template <> -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced::run(const Packet& x) { - constexpr float alpha[] = {1.12026982009410858154296875e-01f, 7.296695709228515625e-01f, 8.109951019287109375e-01f}; - - constexpr float beta[] = {1.00917108356952667236328125e-02f, 2.8318560123443603515625e-01f, 1.0f, - 8.109951019287109375e-01f}; - - Packet x2 = pmul(x, x); - Packet p = ppolevl::run(x2, alpha); - Packet q = ppolevl::run(x2, beta); - return pmul(x, pdiv(p, q)); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_atan(const Packet& x_in) { - typedef typename unpacket_traits::type Scalar; - - constexpr Scalar kPiOverTwo = static_cast(EIGEN_PI / 2); - - const Packet cst_signmask = pset1(Scalar(-0.0)); - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_pi_over_two = pset1(kPiOverTwo); - - // "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x). - // "Small": For |x| <= 1, approximate atan(x) directly by a polynomial - // calculated using Rminimax. - - const Packet abs_x = pabs(x_in); - const Packet x_signmask = pand(x_in, cst_signmask); - const Packet large_mask = pcmp_lt(cst_one, abs_x); - const Packet x = pselect(large_mask, preciprocal(abs_x), abs_x); - const Packet p = patan_reduced::run(x); - // Apply transformations according to the range reduction masks. - Packet result = pselect(large_mask, psub(cst_pi_over_two, p), p); - // Return correct sign - return pxor(result, x_signmask); -} - -//---------------------------------------------------------------------- -// Hyperbolic Functions -//---------------------------------------------------------------------- - -#ifdef EIGEN_FAST_MATH - -/** \internal \returns the hyperbolic tan of \a a (coeff-wise) - Doesn't do anything fancy, just a 9/8-degree rational interpolant which - is accurate up to a couple of ulps in the (approximate) range [-8, 8], - outside of which tanh(x) = +/-1 in single precision. The input is clamped - to the range [-c, c]. The value c is chosen as the smallest value where - the approximation evaluates to exactly 1. - - This implementation works on both scalars and packets. -*/ -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) { - // Clamp the inputs to the range [-c, c] and set everything - // outside that range to 1.0. The value c is chosen as the smallest - // floating point argument such that the approximation is exactly 1. - // This saves clamping the value at the end. -#ifdef EIGEN_VECTORIZE_FMA - const T plus_clamp = pset1(8.01773357391357422f); - const T minus_clamp = pset1(-8.01773357391357422f); -#else - const T plus_clamp = pset1(7.90738964080810547f); - const T minus_clamp = pset1(-7.90738964080810547f); -#endif - const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); - - // The following rational approximation was generated by rminimax - // (https://gitlab.inria.fr/sfilip/rminimax) using the following - // command: - // $ ratapprox --function="tanh(x)" --dom='[-8.67,8.67]' --num="odd" - // --den="even" --type="[9,8]" --numF="[SG]" --denF="[SG]" --log - // --output=tanhf.sollya --dispCoeff="dec" - - // The monomial coefficients of the numerator polynomial (odd). - constexpr float alpha[] = {1.394553628e-8f, 2.102733560e-5f, 3.520756727e-3f, 1.340216100e-1f}; - - // The monomial coefficients of the denominator polynomial (even). - constexpr float beta[] = {8.015776984e-7f, 3.326951409e-4f, 2.597254514e-2f, 4.673548340e-1f, 1.0f}; - - // Since the polynomials are odd/even, we need x^2. - const T x2 = pmul(x, x); - const T x3 = pmul(x2, x); - - T p = ppolevl::run(x2, alpha); - T q = ppolevl::run(x2, beta); - // Take advantage of the fact that the constant term in p is 1 to compute - // x*(x^2*p + 1) = x^3 * p + x. - p = pmadd(x3, p, x); - - // Divide the numerator by the denominator. - return pdiv(p, q); -} - -#else - -/** \internal \returns the hyperbolic tan of \a a (coeff-wise). - On the domain [-1.25:1.25] we use an approximation of the form - tanh(x) ~= x^3 * (P(x) / Q(x)) + x, where P and Q are polynomials in x^2. - For |x| > 1.25, tanh is implemented as tanh(x) = 1 - (2 / (1 + exp(2*x))). - - This implementation has a maximum error of 1 ULP (measured with AVX2+FMA). - - This implementation works on both scalars and packets. -*/ -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& x) { - // The polynomial coefficients were computed using Rminimax: - // % ./ratapprox --function="tanh(x)-x" --dom='[-1.25,1.25]' --num="[x^3,x^5]" --den="even" - // --type="[3,4]" --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec" --output=tanhf.solly - constexpr float alpha[] = {-1.46725140511989593505859375e-02f, -3.333333432674407958984375e-01f}; - constexpr float beta[] = {1.570280082523822784423828125e-02, 4.4401752948760986328125e-01, 1.0f}; - const T x2 = pmul(x, x); - const T x3 = pmul(x2, x); - const T p = ppolevl::run(x2, alpha); - const T q = ppolevl::run(x2, beta); - const T small_tanh = pmadd(x3, pdiv(p, q), x); - - const T sign_mask = pset1(-0.0f); - const T abs_x = pandnot(x, sign_mask); - constexpr float kSmallThreshold = 1.25f; - const T large_mask = pcmp_lt(pset1(kSmallThreshold), abs_x); - // Fast exit if all elements are small. - if (!predux_any(large_mask)) { - return small_tanh; - } - - // Compute as 1 - (2 / (1 + exp(2*x))) - const T one = pset1(1.0f); - const T two = pset1(2.0f); - const T s = pexp_float(pmul(two, abs_x)); - const T abs_tanh = psub(one, pdiv(two, padd(s, one))); - - // Handle infinite inputs and set sign bit. - constexpr float kHugeThreshold = 16.0f; - const T huge_mask = pcmp_lt(pset1(kHugeThreshold), abs_x); - const T x_sign = pand(sign_mask, x); - const T large_tanh = por(x_sign, pselect(huge_mask, one, abs_tanh)); - return pselect(large_mask, large_tanh, small_tanh); -} - -#endif // EIGEN_FAST_MATH - -/** \internal \returns the hyperbolic tan of \a a (coeff-wise) - This uses a 19/18-degree rational interpolant which - is accurate up to a couple of ulps in the (approximate) range [-18.7, 18.7], - outside of which tanh(x) = +/-1 in single precision. The input is clamped - to the range [-c, c]. The value c is chosen as the smallest value where - the approximation evaluates to exactly 1. - - This implementation works on both scalars and packets. -*/ -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_double(const T& a_x) { - // Clamp the inputs to the range [-c, c] and set everything - // outside that range to 1.0. The value c is chosen as the smallest - // floating point argument such that the approximation is exactly 1. - // This saves clamping the value at the end. -#ifdef EIGEN_VECTORIZE_FMA - const T plus_clamp = pset1(17.6610191624600077); - const T minus_clamp = pset1(-17.6610191624600077); -#else - const T plus_clamp = pset1(17.714196154005176); - const T minus_clamp = pset1(-17.714196154005176); -#endif - const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); - // The following rational approximation was generated by rminimax - // (https://gitlab.inria.fr/sfilip/rminimax) using the following - // command: - // $ ./ratapprox --function="tanh(x)" --dom='[-18.72,18.72]' - // --num="odd" --den="even" --type="[19,18]" --numF="[D]" - // --denF="[D]" --log --output=tanh.sollya --dispCoeff="dec" - - // The monomial coefficients of the numerator polynomial (odd). - constexpr double alpha[] = {2.6158007860482230e-23, 7.6534862268749319e-19, 3.1309488231386680e-15, - 4.2303918148209176e-12, 2.4618379131293676e-09, 6.8644367682497074e-07, - 9.3839087674268880e-05, 5.9809711724441161e-03, 1.5184719640284322e-01}; - - // The monomial coefficients of the denominator polynomial (even). - constexpr double beta[] = {6.463747022670968018e-21, 5.782506856739003571e-17, - 1.293019623712687916e-13, 1.123643448069621992e-10, - 4.492975677839633985e-08, 8.785185266237658698e-06, - 8.295161192716231542e-04, 3.437448108450402717e-02, - 4.851805297361760360e-01, 1.0}; - - // Since the polynomials are odd/even, we need x^2. - const T x2 = pmul(x, x); - const T x3 = pmul(x2, x); - - // Interleave the evaluation of the numerator polynomial p and - // denominator polynomial q. - T p = ppolevl::run(x2, alpha); - T q = ppolevl::run(x2, beta); - // Take advantage of the fact that the constant term in p is 1 to compute - // x*(x^2*p + 1) = x^3 * p + x. - p = pmadd(x3, p, x); - - // Divide the numerator by the denominator. - return pdiv(p, q); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be float"); - - // For |x| in [0:0.5] we use a polynomial approximation of the form - // P(x) = x + x^3*(alpha[4] + x^2 * (alpha[3] + x^2 * (... x^2 * alpha[0]) ... )). - constexpr float alpha[] = {0.1819281280040740966796875f, 8.2311116158962249755859375e-2f, - 0.14672131836414337158203125f, 0.1997792422771453857421875f, 0.3333373963832855224609375f}; - const Packet x2 = pmul(x, x); - const Packet x3 = pmul(x, x2); - Packet p = ppolevl::run(x2, alpha); - p = pmadd(x3, p, x); - - // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); - const Packet half = pset1(0.5f); - const Packet one = pset1(1.0f); - Packet r = pdiv(padd(one, x), psub(one, x)); - r = pmul(half, plog(r)); - - const Packet x_gt_half = pcmp_le(half, pabs(x)); - const Packet x_eq_one = pcmp_eq(one, pabs(x)); - const Packet x_gt_one = pcmp_lt(one, pabs(x)); - const Packet sign_mask = pset1(-0.0f); - const Packet x_sign = pand(sign_mask, x); - const Packet inf = pset1(std::numeric_limits::infinity()); - return por(x_gt_one, pselect(x_eq_one, por(x_sign, inf), pselect(x_gt_half, r, p))); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_double(const Packet& x) { - typedef typename unpacket_traits::type Scalar; - static_assert(std::is_same::value, "Scalar type must be double"); - // For x in [-0.5:0.5] we use a rational approximation of the form - // R(x) = x + x^3*P(x^2)/Q(x^2), where P is or order 4 and Q is of order 5. - constexpr double alpha[] = {3.3071338469301391e-03, -4.7129526768798737e-02, 1.8185306179826699e-01, - -2.5949536095445679e-01, 1.2306328729812676e-01}; - - constexpr double beta[] = {-3.8679974580640881e-03, 7.6391885763341910e-02, -4.2828141436397615e-01, - 9.8733495886883648e-01, -1.0000000000000000e+00, 3.6918986189438030e-01}; - - const Packet x2 = pmul(x, x); - const Packet x3 = pmul(x, x2); - Packet p = ppolevl::run(x2, alpha); - Packet q = ppolevl::run(x2, beta); - Packet y_small = pmadd(x3, pdiv(p, q), x); - - // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); - const Packet half = pset1(0.5); - const Packet one = pset1(1.0); - Packet y_large = pdiv(padd(one, x), psub(one, x)); - y_large = pmul(half, plog(y_large)); - - const Packet x_gt_half = pcmp_le(half, pabs(x)); - const Packet x_eq_one = pcmp_eq(one, pabs(x)); - const Packet x_gt_one = pcmp_lt(one, pabs(x)); - const Packet sign_mask = pset1(-0.0); - const Packet x_sign = pand(sign_mask, x); - const Packet inf = pset1(std::numeric_limits::infinity()); - return por(x_gt_one, pselect(x_eq_one, por(x_sign, inf), pselect(x_gt_half, y_large, y_small))); -} - -//---------------------------------------------------------------------- -// Complex Arithmetic and Functions -//---------------------------------------------------------------------- - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) { - typedef typename unpacket_traits::as_real RealPacket; - typedef typename unpacket_traits::type RealScalar; - // In the following we annotate the code for the case where the inputs - // are a pair length-2 SIMD vectors representing a single pair of complex - // numbers x = a + i*b, y = c + i*d. - const RealPacket one = pset1(RealScalar(1)); - const RealPacket y_flip = pcplxflip(y).v; - // We need to avoid dividing by Inf/Inf, so use a mask to carefully - // apply the scale. - const RealPacket mask = pcmp_lt(pabs(y.v), pabs(y_flip)); // |c| < |d| - const RealPacket y_scaled = pselect(mask, pdiv(y.v, y_flip), one); - RealPacket denom = pmul(y.v, y_scaled); - denom = padd(denom, pcplxflip(Packet(denom)).v); // c * c' + d * d' - Packet num = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c - return Packet(pdiv(num.v, denom)); -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pmul_complex(const Packet& x, const Packet& y) { - // In the following we annotate the code for the case where the inputs - // are a pair length-2 SIMD vectors representing a single pair of complex - // numbers x = a + i*b, y = c + i*d. - Packet x_re = pdupreal(x); // a, a - Packet x_im = pdupimag(x); // b, b - Packet tmp_re = Packet(pmul(x_re.v, y.v)); // a*c, a*d - Packet tmp_im = Packet(pmul(x_im.v, y.v)); // b*c, b*d - tmp_im = pcplxflip(pconj(tmp_im)); // -b*d, d*c - return padd(tmp_im, tmp_re); // a*c - b*d, a*d + b*c -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Packet& x) { - typedef typename unpacket_traits::type Scalar; - typedef typename Scalar::value_type RealScalar; - typedef typename unpacket_traits::as_real RealPacket; - - // Real part - RealPacket x_flip = pcplxflip(x).v; // b, a - Packet x_norm = phypot_complex(x); // sqrt(a^2 + b^2), sqrt(a^2 + b^2) - RealPacket xlogr = plog(x_norm.v); // log(sqrt(a^2 + b^2)), log(sqrt(a^2 + b^2)) - - // Imag part - RealPacket ximg = patan2(x.v, x_flip); // atan2(a, b), atan2(b, a) - - const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); - RealPacket x_abs = pabs(x.v); - RealPacket is_x_pos_inf = pcmp_eq(x_abs, cst_pos_inf); - RealPacket is_y_pos_inf = pcplxflip(Packet(is_x_pos_inf)).v; - RealPacket is_any_inf = por(is_x_pos_inf, is_y_pos_inf); - RealPacket xreal = pselect(is_any_inf, cst_pos_inf, xlogr); - - return Packet(pselect(peven_mask(xreal), xreal, ximg)); // log(sqrt(a^2 + b^2)), atan2(b, a) -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) { - typedef typename unpacket_traits::as_real RealPacket; - typedef typename unpacket_traits::type Scalar; - typedef typename Scalar::value_type RealScalar; - const RealPacket even_mask = peven_mask(a.v); - const RealPacket odd_mask = pcplxflip(Packet(even_mask)).v; - - // Let a = x + iy. - // exp(a) = exp(x) * cis(y), plus some special edge-case handling. - - // exp(x): - RealPacket x = pand(a.v, even_mask); - x = por(x, pcplxflip(Packet(x)).v); - RealPacket expx = pexp(x); // exp(x); - - // cis(y): - RealPacket y = pand(odd_mask, a.v); - y = por(y, pcplxflip(Packet(y)).v); - RealPacket cisy = psincos_selector(y); - cisy = pcplxflip(Packet(cisy)).v; // cos(y) + i * sin(y) - - const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); - const RealPacket cst_neg_inf = pset1(-NumTraits::infinity()); - - // If x is -inf, we know that cossin(y) is bounded, - // so the result is (0, +/-0), where the sign of the imaginary part comes - // from the sign of cossin(y). - RealPacket cisy_sign = por(pandnot(cisy, pabs(cisy)), pset1(RealScalar(1))); - cisy = pselect(pcmp_eq(x, cst_neg_inf), cisy_sign, cisy); - - // If x is inf, and cos(y) has unknown sign (y is inf or NaN), the result - // is (+/-inf, NaN), where the signs are undetermined (take the sign of y). - RealPacket y_sign = por(pandnot(y, pabs(y)), pset1(RealScalar(1))); - cisy = pselect(pand(pcmp_eq(x, cst_pos_inf), pisnan(cisy)), pand(y_sign, even_mask), cisy); - - // If exp(x) is +inf and y is finite, replace cisy with copysign(1, cisy) to - // prevent inf * 0 = NaN. The vectorized sincos may compute exact zero - // for near-zero values like cos(pi/2), and inf * +-1 = +-inf is correct. - // The y=0 case is handled separately below. - RealPacket cisy_sign_one = por(pand(cisy, pset1(RealScalar(-0.0))), pset1(RealScalar(1))); - RealPacket expx_inf_y_finite = pand(pcmp_eq(expx, cst_pos_inf), pcmp_lt(pabs(y), cst_pos_inf)); - cisy = pselect(expx_inf_y_finite, cisy_sign_one, cisy); - - Packet result = Packet(pmul(expx, cisy)); - - // If y is +/- 0, the input is real, so take the real result for consistency. - result = pselect(Packet(pcmp_eq(y, pzero(y))), Packet(por(pand(expx, even_mask), pand(y, odd_mask))), result); - - return result; -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a) { - typedef typename unpacket_traits::type Scalar; - typedef typename Scalar::value_type RealScalar; - typedef typename unpacket_traits::as_real RealPacket; - - // Computes the principal sqrt of the complex numbers in the input. - // - // For example, for packets containing 2 complex numbers stored in interleaved format - // a = [a0, a1] = [x0, y0, x1, y1], - // where x0 = real(a0), y0 = imag(a0) etc., this function returns - // b = [b0, b1] = [u0, v0, u1, v1], - // such that b0^2 = a0, b1^2 = a1. - // - // To derive the formula for the complex square roots, let's consider the equation for - // a single complex square root of the number x + i*y. We want to find real numbers - // u and v such that - // (u + i*v)^2 = x + i*y <=> - // u^2 - v^2 + i*2*u*v = x + i*v. - // By equating the real and imaginary parts we get: - // u^2 - v^2 = x - // 2*u*v = y. - // - // For x >= 0, this has the numerically stable solution - // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) - // v = 0.5 * (y / u) - // and for x < 0, - // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) - // u = 0.5 * (y / v) - // - // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as - // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , - - // In the following, without lack of generality, we have annotated the code, assuming - // that the input is a packet of 2 complex numbers. - // - // Step 1. Compute l = [l0, l0, l1, l1], where - // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) - // To avoid over- and underflow, we use the stable formula for each hypotenuse - // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), - // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. - - RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] - RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|] - RealPacket a_max = pmax(a_abs, a_abs_flip); - RealPacket a_min = pmin(a_abs, a_abs_flip); - RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); - RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); - RealPacket r = pdiv(a_min, a_max); - const RealPacket cst_one = pset1(RealScalar(1)); - RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] - // Set l to a_max if a_min is zero. - l = pselect(a_min_zero_mask, a_max, l); - - // Step 2. Compute [rho0, *, rho1, *], where - // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) - // We don't care about the imaginary parts computed here. They will be overwritten later. - const RealPacket cst_half = pset1(RealScalar(0.5)); - Packet rho; - rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); - - // Step 3. Compute [rho0, eta0, rho1, eta1], where - // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. - // set eta = 0 of input is 0 + i0. - RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); - RealPacket real_mask = peven_mask(a.v); - Packet positive_real_result; - // Compute result for inputs with positive real part. - positive_real_result.v = pselect(real_mask, rho.v, eta); - - // Step 4. Compute solution for inputs with negative real part: - // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] - const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), RealScalar(-0.0))).v; - RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); - Packet negative_real_result; - // Notice that rho is positive, so taking its absolute value is a noop. - negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); - - // Step 5. Select solution branch based on the sign of the real parts. - Packet negative_real_mask; - negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); - negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); - Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); - - // Step 6. Handle special cases for infinities: - // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN - // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN - // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y - // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y - const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); - Packet is_inf; - is_inf.v = pcmp_eq(a_abs, cst_pos_inf); - Packet is_real_inf; - is_real_inf.v = pand(is_inf.v, real_mask); - is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); - // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. - Packet real_inf_result; - real_inf_result.v = pmul(a_abs, pset1(Scalar(RealScalar(1.0), RealScalar(0.0))).v); - real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); - // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. - Packet is_imag_inf; - is_imag_inf.v = pandnot(is_inf.v, real_mask); - is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); - Packet imag_inf_result; - imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); - // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan - Packet result_is_nan = pisnan(result); - result = por(result_is_nan, result); - - return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); -} - -// \internal \returns the norm of a complex number z = x + i*y, defined as sqrt(x^2 + y^2). -// Implemented using the hypot(a,b) algorithm from https://doi.org/10.48550/arXiv.1904.09481 -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet phypot_complex(const Packet& a) { - typedef typename unpacket_traits::type Scalar; - typedef typename Scalar::value_type RealScalar; - typedef typename unpacket_traits::as_real RealPacket; - - const RealPacket cst_zero_rp = pset1(static_cast(0.0)); - const RealPacket cst_minus_one_rp = pset1(static_cast(-1.0)); - const RealPacket cst_two_rp = pset1(static_cast(2.0)); - const RealPacket evenmask = peven_mask(a.v); - - RealPacket a_abs = pabs(a.v); - RealPacket a_flip = pcplxflip(Packet(a_abs)).v; // |b|, |a| - RealPacket a_all = pselect(evenmask, a_abs, a_flip); // |a|, |a| - RealPacket b_all = pselect(evenmask, a_flip, a_abs); // |b|, |b| - - RealPacket a2 = pmul(a.v, a.v); // |a^2, b^2| - RealPacket a2_flip = pcplxflip(Packet(a2)).v; // |b^2, a^2| - RealPacket h = psqrt(padd(a2, a2_flip)); // |sqrt(a^2 + b^2), sqrt(a^2 + b^2)| - RealPacket h_sq = pmul(h, h); // |a^2 + b^2, a^2 + b^2| - RealPacket a_sq = pselect(evenmask, a2, a2_flip); // |a^2, a^2| - RealPacket m_h_sq = pmul(h_sq, cst_minus_one_rp); - RealPacket m_a_sq = pmul(a_sq, cst_minus_one_rp); - RealPacket x = psub(psub(pmadd(h, h, m_h_sq), pmadd(b_all, b_all, psub(a_sq, h_sq))), pmadd(a_all, a_all, m_a_sq)); - h = psub(h, pdiv(x, pmul(cst_two_rp, h))); // |h - x/(2*h), h - x/(2*h)| - - // handle zero-case - RealPacket iszero = pcmp_eq(por(a_abs, a_flip), cst_zero_rp); - - h = pandnot(h, iszero); // |sqrt(a^2+b^2), sqrt(a^2+b^2)| - return Packet(h); // |sqrt(a^2+b^2), sqrt(a^2+b^2)| -} +namespace Eigen { +namespace internal { //---------------------------------------------------------------------- // Sign Function @@ -1702,554 +506,6 @@ struct psign_impl::value && } }; -//---------------------------------------------------------------------- -// Power Functions (accurate_log2, generic_pow, unary_pow) -//---------------------------------------------------------------------- - -// This function computes log2(x) and returns the result as a double word. -template -struct accurate_log2 { - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { - log2_x_hi = plog2(x); - log2_x_lo = pzero(x); - } -}; - -// This specialization uses a more accurate algorithm to compute log2(x) for -// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.56508e-10. -// This additional accuracy is needed to counter the error-magnification -// inherent in multiplying by a potentially large exponent in pow(x,y). -// The minimax polynomial used was calculated using the Rminimax tool, -// see https://gitlab.inria.fr/sfilip/rminimax. -// Command line: -// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]' -// --type=[10,0] -// --numF="[D,D,SG]" --denF="[SG]" --log --dispCoeff="dec" -// -// The resulting implementation of pow(x,y) is accurate to 3 ulps. -template <> -struct accurate_log2 { - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) { - // Split the two lowest order constant coefficient into double-word representation. - constexpr double kC0 = 1.442695041742110273474963832995854318141937255859375e+00; - constexpr float kC0_hi = static_cast(kC0); - constexpr float kC0_lo = static_cast(kC0 - static_cast(kC0_hi)); - const Packet c0_hi = pset1(kC0_hi); - const Packet c0_lo = pset1(kC0_lo); - - constexpr double kC1 = -7.2134751588268664068692714863573201000690460205078125e-01; - constexpr float kC1_hi = static_cast(kC1); - constexpr float kC1_lo = static_cast(kC1 - static_cast(kC1_hi)); - const Packet c1_hi = pset1(kC1_hi); - const Packet c1_lo = pset1(kC1_lo); - - constexpr float c[] = { - 9.7010828554630279541015625e-02, -1.6896486282348632812500000e-01, 1.7200836539268493652343750e-01, - -1.7892272770404815673828125e-01, 2.0505344867706298828125000e-01, -2.4046677350997924804687500e-01, - 2.8857553005218505859375000e-01, -3.6067414283752441406250000e-01, 4.8089790344238281250000000e-01}; - - // Evaluate the higher order terms in the polynomial using - // standard arithmetic. - const Packet one = pset1(1.0f); - const Packet x = psub(z, one); - Packet p = ppolevl::run(x, c); - // Evaluate the final two step in Horner's rule using double-word - // arithmetic. - Packet p_hi, p_lo; - twoprod(x, p, p_hi, p_lo); - fast_twosum(c1_hi, c1_lo, p_hi, p_lo, p_hi, p_lo); - twoprod(p_hi, p_lo, x, p_hi, p_lo); - fast_twosum(c0_hi, c0_lo, p_hi, p_lo, p_hi, p_lo); - // Multiply by x to recover log2(z). - twoprod(p_hi, p_lo, x, log2_x_hi, log2_x_lo); - } -}; - -// This specialization uses a more accurate algorithm to compute log2(x) for -// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18. -// This additional accuracy is needed to counter the error-magnification -// inherent in multiplying by a potentially large exponent in pow(x,y). -// The minimax polynomial used was calculated using the Sollya tool. -// See sollya.org. - -template <> -struct accurate_log2 { - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { - // We use a transformation of variables: - // r = c * (x-1) / (x+1), - // such that - // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r). - // The function f(r) can be approximated well using an odd polynomial - // of the form - // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r, - // For the implementation of log2 here, Q is of degree 6 with - // coefficient represented in working precision (double), while C is a - // constant represented in extra precision as a double word to achieve - // full accuracy. - // - // The polynomial coefficients were computed by the Sollya script: - // - // c = 2 / log(2); - // trans = c * (x-1)/(x+1); - // itrans = (1+x/c)/(1-x/c); - // interval=[trans(sqrt(0.5)); trans(sqrt(2))]; - // print(interval); - // f = log2(itrans(x)); - // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating); - const Packet q12 = pset1(2.87074255468000586e-9); - const Packet q10 = pset1(2.38957980901884082e-8); - const Packet q8 = pset1(2.31032094540014656e-7); - const Packet q6 = pset1(2.27279857398537278e-6); - const Packet q4 = pset1(2.31271023278625638e-5); - const Packet q2 = pset1(2.47556738444535513e-4); - const Packet q0 = pset1(2.88543873228900172e-3); - const Packet C_hi = pset1(0.0400377511598501157); - const Packet C_lo = pset1(-4.77726582251425391e-19); - const Packet one = pset1(1.0); - - const Packet cst_2_log2e_hi = pset1(2.88539008177792677); - const Packet cst_2_log2e_lo = pset1(4.07660016854549667e-17); - // c * (x - 1) - Packet t_hi, t_lo; - // t = c * (x-1) - twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), t_hi, t_lo); - // r = c * (x-1) / (x+1), - Packet r_hi, r_lo; - doubleword_div_fp(t_hi, t_lo, padd(x, one), r_hi, r_lo); - - // r2 = r * r - Packet r2_hi, r2_lo; - twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo); - // r4 = r2 * r2 - Packet r4_hi, r4_lo; - twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo); - - // Evaluate Q(r^2) in working precision. We evaluate it in two parts - // (even and odd in r^2) to improve instruction level parallelism. - Packet q_even = pmadd(q12, r4_hi, q8); - Packet q_odd = pmadd(q10, r4_hi, q6); - q_even = pmadd(q_even, r4_hi, q4); - q_odd = pmadd(q_odd, r4_hi, q2); - q_even = pmadd(q_even, r4_hi, q0); - Packet q = pmadd(q_odd, r2_hi, q_even); - - // Now evaluate the low order terms of P(x) in double word precision. - // In the following, due to the increasing magnitude of the coefficients - // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead - // of the slower twosum. - // Q(r^2) * r^2 - Packet p_hi, p_lo; - twoprod(r2_hi, r2_lo, q, p_hi, p_lo); - // Q(r^2) * r^2 + C - Packet p1_hi, p1_lo; - fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo); - // (Q(r^2) * r^2 + C) * r^2 - Packet p2_hi, p2_lo; - twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo); - // ((Q(r^2) * r^2 + C) * r^2 + 1) - Packet p3_hi, p3_lo; - fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo); - - // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r - twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo); - } -}; - -// This function implements the non-trivial case of pow(x,y) where x is -// positive and y is (possibly) non-integer. -// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. -// TODO(rmlarsen): We should probably add this as a packet op 'ppow', to make it -// easier to specialize or turn off for specific types and/or backends. -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { - typedef typename unpacket_traits::type Scalar; - // Split x into exponent e_x and mantissa m_x. - Packet e_x; - Packet m_x = pfrexp(x, e_x); - - // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x). - constexpr Scalar sqrt_half = Scalar(0.70710678118654752440); - const Packet m_x_scale_mask = pcmp_lt(m_x, pset1(sqrt_half)); - m_x = pselect(m_x_scale_mask, pmul(pset1(Scalar(2)), m_x), m_x); - e_x = pselect(m_x_scale_mask, psub(e_x, pset1(Scalar(1))), e_x); - - // Compute log2(m_x) with 6 extra bits of accuracy. - Packet rx_hi, rx_lo; - accurate_log2()(m_x, rx_hi, rx_lo); - - // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled - // precision using double word arithmetic. - Packet f1_hi, f1_lo, f2_hi, f2_lo; - twoprod(e_x, y, f1_hi, f1_lo); - twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo); - // Sum the two terms in f using double word arithmetic. We know - // that |e_x| > |log2(m_x)|, except for the case where e_x==0. - // This means that we can use fast_twosum(f1,f2). - // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any - // accuracy by violating the assumption of fast_twosum, because - // it's a no-op. - Packet f_hi, f_lo; - fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo); - - // Split f into integer and fractional parts. - Packet n_z, r_z; - absolute_split(f_hi, n_z, r_z); - r_z = padd(r_z, f_lo); - Packet n_r; - absolute_split(r_z, n_r, r_z); - n_z = padd(n_z, n_r); - - // We now have an accurate split of f = n_z + r_z and can compute - // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. - // Multiplication by the second factor can be done exactly using pldexp(), since - // it is an integer power of 2. - const Packet e_r = generic_exp2(r_z); - - // Since we know that e_r is in [1/sqrt(2); sqrt(2)], we can use the fast version - // of pldexp to multiply by 2**{n_z} when |n_z| is sufficiently small. - constexpr Scalar kPldExpThresh = std::numeric_limits::max_exponent - 2; - const Packet pldexp_fast_unsafe = pcmp_lt(pset1(kPldExpThresh), pabs(n_z)); - if (predux_any(pldexp_fast_unsafe)) { - return pldexp(e_r, n_z); - } - return pldexp_fast(e_r, n_z); -} - -// Generic implementation of pow(x,y). -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Packet> generic_pow( - const Packet& x, const Packet& y) { - typedef typename unpacket_traits::type Scalar; - - const Packet cst_inf = pset1(NumTraits::infinity()); - const Packet cst_zero = pset1(Scalar(0)); - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_nan = pset1(NumTraits::quiet_NaN()); - - const Packet x_abs = pabs(x); - Packet pow = generic_pow_impl(x_abs, y); - - // In the following we enforce the special case handling prescribed in - // https://en.cppreference.com/w/cpp/numeric/math/pow. - - // Predicates for sign and magnitude of x. - const Packet x_is_negative = pcmp_lt(x, cst_zero); - const Packet x_is_zero = pcmp_eq(x, cst_zero); - const Packet x_is_one = pcmp_eq(x, cst_one); - const Packet x_has_signbit = psignbit(x); - const Packet x_abs_gt_one = pcmp_lt(cst_one, x_abs); - const Packet x_abs_is_inf = pcmp_eq(x_abs, cst_inf); - - // Predicates for sign and magnitude of y. - const Packet y_abs = pabs(y); - const Packet y_abs_is_inf = pcmp_eq(y_abs, cst_inf); - const Packet y_is_negative = pcmp_lt(y, cst_zero); - const Packet y_is_zero = pcmp_eq(y, cst_zero); - const Packet y_is_one = pcmp_eq(y, cst_one); - // Predicates for whether y is integer and odd/even. - const Packet y_is_int = pandnot(pcmp_eq(pfloor(y), y), y_abs_is_inf); - const Packet y_div_2 = pmul(y, pset1(Scalar(0.5))); - const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); - const Packet y_is_odd_int = pandnot(y_is_int, y_is_even); - // Smallest exponent for which (1 + epsilon) overflows to infinity. - constexpr Scalar huge_exponent = - (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits::epsilon(); - const Packet y_abs_is_huge = pcmp_le(pset1(huge_exponent), y_abs); - - // * pow(base, exp) returns NaN if base is finite and negative - // and exp is finite and non-integer. - pow = pselect(pandnot(x_is_negative, y_is_int), cst_nan, pow); - - // * pow(±0, exp), where exp is negative, finite, and is an even integer or - // a non-integer, returns +∞ - // * pow(±0, exp), where exp is positive non-integer or a positive even - // integer, returns +0 - // * pow(+0, exp), where exp is a negative odd integer, returns +∞ - // * pow(-0, exp), where exp is a negative odd integer, returns -∞ - // * pow(+0, exp), where exp is a positive odd integer, returns +0 - // * pow(-0, exp), where exp is a positive odd integer, returns -0 - // Sign is flipped by the rule below. - pow = pselect(x_is_zero, pselect(y_is_negative, cst_inf, cst_zero), pow); - - // pow(base, exp) returns -pow(abs(base), exp) if base has the sign bit set, - // and exp is an odd integer exponent. - pow = pselect(pand(x_has_signbit, y_is_odd_int), pnegate(pow), pow); - - // * pow(base, -∞) returns +∞ for any |base|<1 - // * pow(base, -∞) returns +0 for any |base|>1 - // * pow(base, +∞) returns +0 for any |base|<1 - // * pow(base, +∞) returns +∞ for any |base|>1 - // * pow(±0, -∞) returns +∞ - // * pow(-1, +-∞) = 1 - Packet inf_y_val = pselect(por(pand(y_is_negative, x_is_zero), pxor(y_is_negative, x_abs_gt_one)), cst_inf, cst_zero); - inf_y_val = pselect(pcmp_eq(x, pset1(Scalar(-1.0))), cst_one, inf_y_val); - pow = pselect(y_abs_is_huge, inf_y_val, pow); - - // * pow(+∞, exp) returns +0 for any negative exp - // * pow(+∞, exp) returns +∞ for any positive exp - // * pow(-∞, exp) returns -0 if exp is a negative odd integer. - // * pow(-∞, exp) returns +0 if exp is a negative non-integer or negative - // even integer. - // * pow(-∞, exp) returns -∞ if exp is a positive odd integer. - // * pow(-∞, exp) returns +∞ if exp is a positive non-integer or positive - // even integer. - auto x_pos_inf_value = pselect(y_is_negative, cst_zero, cst_inf); - auto x_neg_inf_value = pselect(y_is_odd_int, pnegate(x_pos_inf_value), x_pos_inf_value); - pow = pselect(x_abs_is_inf, pselect(x_is_negative, x_neg_inf_value, x_pos_inf_value), pow); - - // All cases of NaN inputs return NaN, except the two below. - pow = pselect(por(pisnan(x), pisnan(y)), cst_nan, pow); - - // * pow(base, 1) returns base. - // * pow(base, +/-0) returns 1, regardless of base, even NaN. - // * pow(+1, exp) returns 1, regardless of exponent, even NaN. - pow = pselect(y_is_one, x, pselect(por(x_is_one, y_is_zero), cst_one, pow)); - - return pow; -} - -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Scalar> generic_pow( - const Scalar& x, const Scalar& y) { - return numext::pow(x, y); -} - -namespace unary_pow { - -template ::IsInteger> -struct exponent_helper { - using safe_abs_type = ScalarExponent; - static constexpr ScalarExponent one_half = ScalarExponent(0.5); - // these routines assume that exp is an integer stored as a floating point type - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent safe_abs(const ScalarExponent& exp) { - return numext::abs(exp); - } - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const ScalarExponent& exp) { - eigen_assert(((numext::isfinite)(exp) && exp == numext::floor(exp)) && "exp must be an integer"); - ScalarExponent exp_div_2 = exp * one_half; - ScalarExponent floor_exp_div_2 = numext::floor(exp_div_2); - return exp_div_2 != floor_exp_div_2; - } - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent floor_div_two(const ScalarExponent& exp) { - ScalarExponent exp_div_2 = exp * one_half; - return numext::floor(exp_div_2); - } -}; - -template -struct exponent_helper { - // if `exp` is a signed integer type, cast it to its unsigned counterpart to safely store its absolute value - // consider the (rare) case where `exp` is an int32_t: abs(-2147483648) != 2147483648 - using safe_abs_type = typename numext::get_integer_by_size::unsigned_type; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type safe_abs(const ScalarExponent& exp) { - ScalarExponent mask = numext::signbit(exp); - safe_abs_type result = safe_abs_type(exp ^ mask); - return result + safe_abs_type(ScalarExponent(1) & mask); - } - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const safe_abs_type& exp) { - return exp % safe_abs_type(2) != safe_abs_type(0); - } - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type floor_div_two(const safe_abs_type& exp) { - return exp >> safe_abs_type(1); - } -}; - -template ::type>::IsInteger && NumTraits::IsSigned> -struct reciprocate { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - using Scalar = typename unpacket_traits::type; - const Packet cst_pos_one = pset1(Scalar(1)); - return exponent < 0 ? pdiv(cst_pos_one, x) : x; - } -}; - -template -struct reciprocate { - // pdiv not defined, nor necessary for integer base types - // if the exponent is unsigned, then the exponent cannot be negative - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; } -}; - -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { - using Scalar = typename unpacket_traits::type; - using ExponentHelper = exponent_helper; - using AbsExponentType = typename ExponentHelper::safe_abs_type; - const Packet cst_pos_one = pset1(Scalar(1)); - if (exponent == ScalarExponent(0)) return cst_pos_one; - - Packet result = reciprocate::run(x, exponent); - Packet y = cst_pos_one; - AbsExponentType m = ExponentHelper::safe_abs(exponent); - - while (m > 1) { - bool odd = ExponentHelper::is_odd(m); - if (odd) y = pmul(y, result); - result = pmul(result, result); - m = ExponentHelper::floor_div_two(m); - } - - return pmul(y, result); -} - -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> gen_pow( - const Packet& x, const typename unpacket_traits::type& exponent) { - const Packet exponent_packet = pset1(exponent); - return generic_pow_impl(x, exponent_packet); -} - -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Scalar> gen_pow( - const Scalar& x, const Scalar& exponent) { - return numext::pow(x, exponent); -} - -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, - const ScalarExponent& exponent) { - using Scalar = typename unpacket_traits::type; - - // non-integer base and exponent case - const Packet cst_pos_zero = pzero(x); - const Packet cst_pos_one = pset1(Scalar(1)); - const Packet cst_pos_inf = pset1(NumTraits::infinity()); - const Packet cst_true = ptrue(x); - - const bool exponent_is_not_fin = !(numext::isfinite)(exponent); - const bool exponent_is_neg = exponent < ScalarExponent(0); - const bool exponent_is_pos = exponent > ScalarExponent(0); - - const Packet exp_is_not_fin = exponent_is_not_fin ? cst_true : cst_pos_zero; - const Packet exp_is_neg = exponent_is_neg ? cst_true : cst_pos_zero; - const Packet exp_is_pos = exponent_is_pos ? cst_true : cst_pos_zero; - const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); - const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); - - const Packet x_is_le_zero = pcmp_le(x, cst_pos_zero); - const Packet x_is_ge_zero = pcmp_le(cst_pos_zero, x); - const Packet x_is_zero = pand(x_is_le_zero, x_is_ge_zero); - - const Packet abs_x = pabs(x); - const Packet abs_x_is_le_one = pcmp_le(abs_x, cst_pos_one); - const Packet abs_x_is_ge_one = pcmp_le(cst_pos_one, abs_x); - const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); - const Packet abs_x_is_one = pand(abs_x_is_le_one, abs_x_is_ge_one); - - Packet pow_is_inf_if_exp_is_neg = por(x_is_zero, pand(abs_x_is_le_one, exp_is_inf)); - Packet pow_is_inf_if_exp_is_pos = por(abs_x_is_inf, pand(abs_x_is_ge_one, exp_is_inf)); - Packet pow_is_one = pand(abs_x_is_one, por(exp_is_inf, x_is_ge_zero)); - - Packet result = powx; - result = por(x_is_le_zero, result); - result = pselect(pow_is_inf_if_exp_is_neg, pand(cst_pos_inf, exp_is_neg), result); - result = pselect(pow_is_inf_if_exp_is_pos, pand(cst_pos_inf, exp_is_pos), result); - result = por(exp_is_nan, result); - result = pselect(pow_is_one, cst_pos_one, result); - return result; -} - -template ::type>::IsSigned, bool> = true> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent& exponent) { - using Scalar = typename unpacket_traits::type; - - // signed integer base, signed integer exponent case - - // This routine handles negative exponents. - // The return value is either 0, 1, or -1. - const Packet cst_pos_one = pset1(Scalar(1)); - const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0); - const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x); - - const Packet abs_x = pabs(x); - const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); - - Packet result = pselect(exp_is_odd, x, abs_x); - result = pselect(abs_x_is_one, result, pzero(x)); - return result; -} - -template ::type>::IsSigned, bool> = true> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent&) { - using Scalar = typename unpacket_traits::type; - - // unsigned integer base, signed integer exponent case - - // This routine handles negative exponents. - // The return value is either 0 or 1 - - const Scalar pos_one = Scalar(1); - - const Packet cst_pos_one = pset1(pos_one); - - const Packet x_is_one = pcmp_eq(x, cst_pos_one); - - return pand(x_is_one, x); -} - -} // end namespace unary_pow - -template ::type>::IsInteger, - bool ExponentIsIntegerType = NumTraits::IsInteger, - bool ExponentIsSigned = NumTraits::IsSigned> -struct unary_pow_impl; - -template -struct unary_pow_impl { - typedef typename unpacket_traits::type Scalar; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; - if (exponent_is_integer) { - // The simple recursive doubling implementation is only accurate to 3 ulps - // for integer exponents in [-3:7]. Since this is a common case, we - // specialize it here. - bool use_repeated_squaring = - (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))); - return use_repeated_squaring ? unary_pow::int_pow(x, exponent) : generic_pow(x, pset1(exponent)); - } else { - Packet result = unary_pow::gen_pow(x, exponent); - result = unary_pow::handle_nonint_nonint_errors(x, result, exponent); - return result; - } - } -}; - -template -struct unary_pow_impl { - typedef typename unpacket_traits::type Scalar; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - return unary_pow::int_pow(x, exponent); - } -}; - -template -struct unary_pow_impl { - typedef typename unpacket_traits::type Scalar; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - if (exponent < ScalarExponent(0)) { - return unary_pow::handle_negative_exponent(x, exponent); - } else { - return unary_pow::int_pow(x, exponent); - } - } -}; - -template -struct unary_pow_impl { - typedef typename unpacket_traits::type Scalar; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - return unary_pow::int_pow(x, exponent); - } -}; - //---------------------------------------------------------------------- // Rounding Functions //---------------------------------------------------------------------- diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathPow.h b/Eigen/src/Core/arch/Default/GenericPacketMathPow.h new file mode 100644 index 000000000..08686f86b --- /dev/null +++ b/Eigen/src/Core/arch/Default/GenericPacketMathPow.h @@ -0,0 +1,710 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018-2025 Rasmus Munk Larsen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_POW_H +#define EIGEN_ARCH_GENERIC_PACKET_MATH_POW_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +//---------------------------------------------------------------------- +// Cubic Root Functions +//---------------------------------------------------------------------- + +// This function implements a single step of Halley's iteration for +// computing x = y^(1/3): +// x_{k+1} = x_k - (x_k^3 - y) x_k / (2x_k^3 + y) +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_halley_iteration_step(const Packet& x_k, + const Packet& y) { + typedef typename unpacket_traits::type Scalar; + Packet x_k_cb = pmul(x_k, pmul(x_k, x_k)); + Packet denom = pmadd(pset1(Scalar(2)), x_k_cb, y); + Packet num = psub(x_k_cb, y); + Packet r = pdiv(num, denom); + return pnmadd(x_k, r, x_k); +} + +// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the +// interval [0.125,1]. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_decompose(const Packet& x, Packet& e_div3) { + typedef typename unpacket_traits::type Scalar; + // Extract the significant s in the range [0.5,1) and exponent e, such that + // x = 2^e * s. + Packet e, s; + s = pfrexp(x, e); + + // Split the exponent into a part divisible by 3 and the remainder. + // e = 3*e_div3 + e_mod3. + constexpr Scalar kOneThird = Scalar(1) / 3; + e_div3 = pceil(pmul(e, pset1(kOneThird))); + Packet e_mod3 = pnmadd(pset1(Scalar(3)), e_div3, e); + + // Replace s by y = (s * 2^e_mod3). + return pldexp_fast(s, e_mod3); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_special_cases_and_sign(const Packet& x, + const Packet& abs_root) { + typedef typename unpacket_traits::type Scalar; + + // Set sign. + const Packet sign_mask = pset1(Scalar(-0.0)); + const Packet x_sign = pand(sign_mask, x); + Packet root = por(x_sign, abs_root); + + // Pass non-finite and zero values of x straight through. + const Packet is_not_finite = por(pisinf(x), pisnan(x)); + const Packet is_zero = pcmp_eq(pzero(x), x); + const Packet use_x = por(is_not_finite, is_zero); + return pselect(use_x, x, root); +} + +// Generic implementation of cbrt(x) for float. +// +// The algorithm computes the cubic root of the input by first +// decomposing it into a exponent and significant +// x = s * 2^e. +// +// We can then write the cube root as +// +// x^(1/3) = 2^(e/3) * s^(1/3) +// = 2^((3*e_div3 + e_mod3)/3) * s^(1/3) +// = 2^(e_div3) * 2^(e_mod3/3) * s^(1/3) +// = 2^(e_div3) * (s * 2^e_mod3)^(1/3) +// +// where e_div3 = ceil(e/3) and e_mod3 = e - 3*e_div3. +// +// The cube root of the second term y = (s * 2^e_mod3)^(1/3) is coarsely +// approximated using a cubic polynomial and subsequently refined using a +// single step of Halley's iteration, and finally the two terms are combined +// using pldexp_fast. +// +// Note: Many alternatives exist for implementing cbrt. See, for example, +// the excellent discussion in Kahan's note: +// https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf +// This particular implementation was found to be very fast and accurate +// among several alternatives tried, but is probably not "optimal" on all +// platforms. +// +// This is accurate to 2 ULP. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be float"); + + // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the + // interval [0.125,1]. + Packet e_div3; + const Packet y = cbrt_decompose(pabs(x), e_div3); + + // Compute initial approximation accurate to 5.22e-3. + // The polynomial was computed using Rminimax. + constexpr float alpha[] = {5.9220016002655029296875e-01f, -1.3859539031982421875e+00f, 1.4581282138824462890625e+00f, + 3.408401906490325927734375e-01f}; + Packet r = ppolevl::run(y, alpha); + + // Take one step of Halley's iteration. + r = cbrt_halley_iteration_step(r, y); + + // Finally multiply by 2^(e_div3) + r = pldexp_fast(r, e_div3); + + return cbrt_special_cases_and_sign(x, r); +} + +// Generic implementation of cbrt(x) for double. +// +// The algorithm is identical to the one for float except that a different initial +// approximation is used for y^(1/3) and two Halley iteration steps are performed. +// +// This is accurate to 1 ULP. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be double"); + + // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the + // interval [0.125,1]. + Packet e_div3; + const Packet y = cbrt_decompose(pabs(x), e_div3); + + // Compute initial approximation accurate to 0.016. + // The polynomial was computed using Rminimax. + constexpr double alpha[] = {-4.69470621553356115551736138513660989701747894287109375e-01, + 1.072314636518546304699839311069808900356292724609375e+00, + 3.81249427609571867048288140722434036433696746826171875e-01}; + Packet r = ppolevl::run(y, alpha); + + // Take two steps of Halley's iteration. + r = cbrt_halley_iteration_step(r, y); + r = cbrt_halley_iteration_step(r, y); + + // Finally multiply by 2^(e_div3). + r = pldexp_fast(r, e_div3); + return cbrt_special_cases_and_sign(x, r); +} + +//---------------------------------------------------------------------- +// Power Functions (accurate_log2, generic_pow, unary_pow) +//---------------------------------------------------------------------- + +// This function computes log2(x) and returns the result as a double word. +template +struct accurate_log2 { + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + log2_x_hi = plog2(x); + log2_x_lo = pzero(x); + } +}; + +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.56508e-10. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Rminimax tool, +// see https://gitlab.inria.fr/sfilip/rminimax. +// Command line: +// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]' +// --type=[10,0] +// --numF="[D,D,SG]" --denF="[SG]" --log --dispCoeff="dec" +// +// The resulting implementation of pow(x,y) is accurate to 3 ulps. +template <> +struct accurate_log2 { + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) { + // Split the two lowest order constant coefficient into double-word representation. + constexpr double kC0 = 1.442695041742110273474963832995854318141937255859375e+00; + constexpr float kC0_hi = static_cast(kC0); + constexpr float kC0_lo = static_cast(kC0 - static_cast(kC0_hi)); + const Packet c0_hi = pset1(kC0_hi); + const Packet c0_lo = pset1(kC0_lo); + + constexpr double kC1 = -7.2134751588268664068692714863573201000690460205078125e-01; + constexpr float kC1_hi = static_cast(kC1); + constexpr float kC1_lo = static_cast(kC1 - static_cast(kC1_hi)); + const Packet c1_hi = pset1(kC1_hi); + const Packet c1_lo = pset1(kC1_lo); + + constexpr float c[] = { + 9.7010828554630279541015625e-02, -1.6896486282348632812500000e-01, 1.7200836539268493652343750e-01, + -1.7892272770404815673828125e-01, 2.0505344867706298828125000e-01, -2.4046677350997924804687500e-01, + 2.8857553005218505859375000e-01, -3.6067414283752441406250000e-01, 4.8089790344238281250000000e-01}; + + // Evaluate the higher order terms in the polynomial using + // standard arithmetic. + const Packet one = pset1(1.0f); + const Packet x = psub(z, one); + Packet p = ppolevl::run(x, c); + // Evaluate the final two step in Horner's rule using double-word + // arithmetic. + Packet p_hi, p_lo; + twoprod(x, p, p_hi, p_lo); + fast_twosum(c1_hi, c1_lo, p_hi, p_lo, p_hi, p_lo); + twoprod(p_hi, p_lo, x, p_hi, p_lo); + fast_twosum(c0_hi, c0_lo, p_hi, p_lo, p_hi, p_lo); + // Multiply by x to recover log2(z). + twoprod(p_hi, p_lo, x, log2_x_hi, log2_x_lo); + } +}; + +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. + +template <> +struct accurate_log2 { + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + // We use a transformation of variables: + // r = c * (x-1) / (x+1), + // such that + // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r). + // The function f(r) can be approximated well using an odd polynomial + // of the form + // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r, + // For the implementation of log2 here, Q is of degree 6 with + // coefficient represented in working precision (double), while C is a + // constant represented in extra precision as a double word to achieve + // full accuracy. + // + // The polynomial coefficients were computed by the Sollya script: + // + // c = 2 / log(2); + // trans = c * (x-1)/(x+1); + // itrans = (1+x/c)/(1-x/c); + // interval=[trans(sqrt(0.5)); trans(sqrt(2))]; + // print(interval); + // f = log2(itrans(x)); + // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating); + const Packet q12 = pset1(2.87074255468000586e-9); + const Packet q10 = pset1(2.38957980901884082e-8); + const Packet q8 = pset1(2.31032094540014656e-7); + const Packet q6 = pset1(2.27279857398537278e-6); + const Packet q4 = pset1(2.31271023278625638e-5); + const Packet q2 = pset1(2.47556738444535513e-4); + const Packet q0 = pset1(2.88543873228900172e-3); + const Packet C_hi = pset1(0.0400377511598501157); + const Packet C_lo = pset1(-4.77726582251425391e-19); + const Packet one = pset1(1.0); + + const Packet cst_2_log2e_hi = pset1(2.88539008177792677); + const Packet cst_2_log2e_lo = pset1(4.07660016854549667e-17); + // c * (x - 1) + Packet t_hi, t_lo; + // t = c * (x-1) + twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), t_hi, t_lo); + // r = c * (x-1) / (x+1), + Packet r_hi, r_lo; + doubleword_div_fp(t_hi, t_lo, padd(x, one), r_hi, r_lo); + + // r2 = r * r + Packet r2_hi, r2_lo; + twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo); + // r4 = r2 * r2 + Packet r4_hi, r4_lo; + twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo); + + // Evaluate Q(r^2) in working precision. We evaluate it in two parts + // (even and odd in r^2) to improve instruction level parallelism. + Packet q_even = pmadd(q12, r4_hi, q8); + Packet q_odd = pmadd(q10, r4_hi, q6); + q_even = pmadd(q_even, r4_hi, q4); + q_odd = pmadd(q_odd, r4_hi, q2); + q_even = pmadd(q_even, r4_hi, q0); + Packet q = pmadd(q_odd, r2_hi, q_even); + + // Now evaluate the low order terms of P(x) in double word precision. + // In the following, due to the increasing magnitude of the coefficients + // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead + // of the slower twosum. + // Q(r^2) * r^2 + Packet p_hi, p_lo; + twoprod(r2_hi, r2_lo, q, p_hi, p_lo); + // Q(r^2) * r^2 + C + Packet p1_hi, p1_lo; + fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo); + // (Q(r^2) * r^2 + C) * r^2 + Packet p2_hi, p2_lo; + twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo); + // ((Q(r^2) * r^2 + C) * r^2 + 1) + Packet p3_hi, p3_lo; + fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo); + + // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r + twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo); + } +}; + +// This function implements the non-trivial case of pow(x,y) where x is +// positive and y is (possibly) non-integer. +// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. +// TODO(rmlarsen): We should probably add this as a packet op 'ppow', to make it +// easier to specialize or turn off for specific types and/or backends. +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { + typedef typename unpacket_traits::type Scalar; + // Split x into exponent e_x and mantissa m_x. + Packet e_x; + Packet m_x = pfrexp(x, e_x); + + // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x). + constexpr Scalar sqrt_half = Scalar(0.70710678118654752440); + const Packet m_x_scale_mask = pcmp_lt(m_x, pset1(sqrt_half)); + m_x = pselect(m_x_scale_mask, pmul(pset1(Scalar(2)), m_x), m_x); + e_x = pselect(m_x_scale_mask, psub(e_x, pset1(Scalar(1))), e_x); + + // Compute log2(m_x) with 6 extra bits of accuracy. + Packet rx_hi, rx_lo; + accurate_log2()(m_x, rx_hi, rx_lo); + + // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled + // precision using double word arithmetic. + Packet f1_hi, f1_lo, f2_hi, f2_lo; + twoprod(e_x, y, f1_hi, f1_lo); + twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo); + // Sum the two terms in f using double word arithmetic. We know + // that |e_x| > |log2(m_x)|, except for the case where e_x==0. + // This means that we can use fast_twosum(f1,f2). + // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any + // accuracy by violating the assumption of fast_twosum, because + // it's a no-op. + Packet f_hi, f_lo; + fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo); + + // Split f into integer and fractional parts. + Packet n_z, r_z; + absolute_split(f_hi, n_z, r_z); + r_z = padd(r_z, f_lo); + Packet n_r; + absolute_split(r_z, n_r, r_z); + n_z = padd(n_z, n_r); + + // We now have an accurate split of f = n_z + r_z and can compute + // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. + // Multiplication by the second factor can be done exactly using pldexp(), since + // it is an integer power of 2. + const Packet e_r = generic_exp2(r_z); + + // Since we know that e_r is in [1/sqrt(2); sqrt(2)], we can use the fast version + // of pldexp to multiply by 2**{n_z} when |n_z| is sufficiently small. + constexpr Scalar kPldExpThresh = std::numeric_limits::max_exponent - 2; + const Packet pldexp_fast_unsafe = pcmp_lt(pset1(kPldExpThresh), pabs(n_z)); + if (predux_any(pldexp_fast_unsafe)) { + return pldexp(e_r, n_z); + } + return pldexp_fast(e_r, n_z); +} + +// Generic implementation of pow(x,y). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Packet> generic_pow( + const Packet& x, const Packet& y) { + typedef typename unpacket_traits::type Scalar; + + const Packet cst_inf = pset1(NumTraits::infinity()); + const Packet cst_zero = pset1(Scalar(0)); + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_nan = pset1(NumTraits::quiet_NaN()); + + const Packet x_abs = pabs(x); + Packet pow = generic_pow_impl(x_abs, y); + + // In the following we enforce the special case handling prescribed in + // https://en.cppreference.com/w/cpp/numeric/math/pow. + + // Predicates for sign and magnitude of x. + const Packet x_is_negative = pcmp_lt(x, cst_zero); + const Packet x_is_zero = pcmp_eq(x, cst_zero); + const Packet x_is_one = pcmp_eq(x, cst_one); + const Packet x_has_signbit = psignbit(x); + const Packet x_abs_gt_one = pcmp_lt(cst_one, x_abs); + const Packet x_abs_is_inf = pcmp_eq(x_abs, cst_inf); + + // Predicates for sign and magnitude of y. + const Packet y_abs = pabs(y); + const Packet y_abs_is_inf = pcmp_eq(y_abs, cst_inf); + const Packet y_is_negative = pcmp_lt(y, cst_zero); + const Packet y_is_zero = pcmp_eq(y, cst_zero); + const Packet y_is_one = pcmp_eq(y, cst_one); + // Predicates for whether y is integer and odd/even. + const Packet y_is_int = pandnot(pcmp_eq(pfloor(y), y), y_abs_is_inf); + const Packet y_div_2 = pmul(y, pset1(Scalar(0.5))); + const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); + const Packet y_is_odd_int = pandnot(y_is_int, y_is_even); + // Smallest exponent for which (1 + epsilon) overflows to infinity. + constexpr Scalar huge_exponent = + (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits::epsilon(); + const Packet y_abs_is_huge = pcmp_le(pset1(huge_exponent), y_abs); + + // * pow(base, exp) returns NaN if base is finite and negative + // and exp is finite and non-integer. + pow = pselect(pandnot(x_is_negative, y_is_int), cst_nan, pow); + + // * pow(±0, exp), where exp is negative, finite, and is an even integer or + // a non-integer, returns +∞ + // * pow(±0, exp), where exp is positive non-integer or a positive even + // integer, returns +0 + // * pow(+0, exp), where exp is a negative odd integer, returns +∞ + // * pow(-0, exp), where exp is a negative odd integer, returns -∞ + // * pow(+0, exp), where exp is a positive odd integer, returns +0 + // * pow(-0, exp), where exp is a positive odd integer, returns -0 + // Sign is flipped by the rule below. + pow = pselect(x_is_zero, pselect(y_is_negative, cst_inf, cst_zero), pow); + + // pow(base, exp) returns -pow(abs(base), exp) if base has the sign bit set, + // and exp is an odd integer exponent. + pow = pselect(pand(x_has_signbit, y_is_odd_int), pnegate(pow), pow); + + // * pow(base, -∞) returns +∞ for any |base|<1 + // * pow(base, -∞) returns +0 for any |base|>1 + // * pow(base, +∞) returns +0 for any |base|<1 + // * pow(base, +∞) returns +∞ for any |base|>1 + // * pow(±0, -∞) returns +∞ + // * pow(-1, +-∞) = 1 + Packet inf_y_val = pselect(por(pand(y_is_negative, x_is_zero), pxor(y_is_negative, x_abs_gt_one)), cst_inf, cst_zero); + inf_y_val = pselect(pcmp_eq(x, pset1(Scalar(-1.0))), cst_one, inf_y_val); + pow = pselect(y_abs_is_huge, inf_y_val, pow); + + // * pow(+∞, exp) returns +0 for any negative exp + // * pow(+∞, exp) returns +∞ for any positive exp + // * pow(-∞, exp) returns -0 if exp is a negative odd integer. + // * pow(-∞, exp) returns +0 if exp is a negative non-integer or negative + // even integer. + // * pow(-∞, exp) returns -∞ if exp is a positive odd integer. + // * pow(-∞, exp) returns +∞ if exp is a positive non-integer or positive + // even integer. + auto x_pos_inf_value = pselect(y_is_negative, cst_zero, cst_inf); + auto x_neg_inf_value = pselect(y_is_odd_int, pnegate(x_pos_inf_value), x_pos_inf_value); + pow = pselect(x_abs_is_inf, pselect(x_is_negative, x_neg_inf_value, x_pos_inf_value), pow); + + // All cases of NaN inputs return NaN, except the two below. + pow = pselect(por(pisnan(x), pisnan(y)), cst_nan, pow); + + // * pow(base, 1) returns base. + // * pow(base, +/-0) returns 1, regardless of base, even NaN. + // * pow(+1, exp) returns 1, regardless of exponent, even NaN. + pow = pselect(y_is_one, x, pselect(por(x_is_one, y_is_zero), cst_one, pow)); + + return pow; +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Scalar> generic_pow( + const Scalar& x, const Scalar& y) { + return numext::pow(x, y); +} + +namespace unary_pow { + +template ::IsInteger> +struct exponent_helper { + using safe_abs_type = ScalarExponent; + static constexpr ScalarExponent one_half = ScalarExponent(0.5); + // these routines assume that exp is an integer stored as a floating point type + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent safe_abs(const ScalarExponent& exp) { + return numext::abs(exp); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const ScalarExponent& exp) { + eigen_assert(((numext::isfinite)(exp) && exp == numext::floor(exp)) && "exp must be an integer"); + ScalarExponent exp_div_2 = exp * one_half; + ScalarExponent floor_exp_div_2 = numext::floor(exp_div_2); + return exp_div_2 != floor_exp_div_2; + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent floor_div_two(const ScalarExponent& exp) { + ScalarExponent exp_div_2 = exp * one_half; + return numext::floor(exp_div_2); + } +}; + +template +struct exponent_helper { + // if `exp` is a signed integer type, cast it to its unsigned counterpart to safely store its absolute value + // consider the (rare) case where `exp` is an int32_t: abs(-2147483648) != 2147483648 + using safe_abs_type = typename numext::get_integer_by_size::unsigned_type; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type safe_abs(const ScalarExponent& exp) { + ScalarExponent mask = numext::signbit(exp); + safe_abs_type result = safe_abs_type(exp ^ mask); + return result + safe_abs_type(ScalarExponent(1) & mask); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const safe_abs_type& exp) { + return exp % safe_abs_type(2) != safe_abs_type(0); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type floor_div_two(const safe_abs_type& exp) { + return exp >> safe_abs_type(1); + } +}; + +template ::type>::IsInteger && NumTraits::IsSigned> +struct reciprocate { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; + const Packet cst_pos_one = pset1(Scalar(1)); + return exponent < 0 ? pdiv(cst_pos_one, x) : x; + } +}; + +template +struct reciprocate { + // pdiv not defined, nor necessary for integer base types + // if the exponent is unsigned, then the exponent cannot be negative + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; } +}; + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; + using ExponentHelper = exponent_helper; + using AbsExponentType = typename ExponentHelper::safe_abs_type; + const Packet cst_pos_one = pset1(Scalar(1)); + if (exponent == ScalarExponent(0)) return cst_pos_one; + + Packet result = reciprocate::run(x, exponent); + Packet y = cst_pos_one; + AbsExponentType m = ExponentHelper::safe_abs(exponent); + + while (m > 1) { + bool odd = ExponentHelper::is_odd(m); + if (odd) y = pmul(y, result); + result = pmul(result, result); + m = ExponentHelper::floor_div_two(m); + } + + return pmul(y, result); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> gen_pow( + const Packet& x, const typename unpacket_traits::type& exponent) { + const Packet exponent_packet = pset1(exponent); + return generic_pow_impl(x, exponent_packet); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Scalar> gen_pow( + const Scalar& x, const Scalar& exponent) { + return numext::pow(x, exponent); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, + const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; + + // non-integer base and exponent case + const Packet cst_pos_zero = pzero(x); + const Packet cst_pos_one = pset1(Scalar(1)); + const Packet cst_pos_inf = pset1(NumTraits::infinity()); + const Packet cst_true = ptrue(x); + + const bool exponent_is_not_fin = !(numext::isfinite)(exponent); + const bool exponent_is_neg = exponent < ScalarExponent(0); + const bool exponent_is_pos = exponent > ScalarExponent(0); + + const Packet exp_is_not_fin = exponent_is_not_fin ? cst_true : cst_pos_zero; + const Packet exp_is_neg = exponent_is_neg ? cst_true : cst_pos_zero; + const Packet exp_is_pos = exponent_is_pos ? cst_true : cst_pos_zero; + const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); + const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); + + const Packet x_is_le_zero = pcmp_le(x, cst_pos_zero); + const Packet x_is_ge_zero = pcmp_le(cst_pos_zero, x); + const Packet x_is_zero = pand(x_is_le_zero, x_is_ge_zero); + + const Packet abs_x = pabs(x); + const Packet abs_x_is_le_one = pcmp_le(abs_x, cst_pos_one); + const Packet abs_x_is_ge_one = pcmp_le(cst_pos_one, abs_x); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + const Packet abs_x_is_one = pand(abs_x_is_le_one, abs_x_is_ge_one); + + Packet pow_is_inf_if_exp_is_neg = por(x_is_zero, pand(abs_x_is_le_one, exp_is_inf)); + Packet pow_is_inf_if_exp_is_pos = por(abs_x_is_inf, pand(abs_x_is_ge_one, exp_is_inf)); + Packet pow_is_one = pand(abs_x_is_one, por(exp_is_inf, x_is_ge_zero)); + + Packet result = powx; + result = por(x_is_le_zero, result); + result = pselect(pow_is_inf_if_exp_is_neg, pand(cst_pos_inf, exp_is_neg), result); + result = pselect(pow_is_inf_if_exp_is_pos, pand(cst_pos_inf, exp_is_pos), result); + result = por(exp_is_nan, result); + result = pselect(pow_is_one, cst_pos_one, result); + return result; +} + +template ::type>::IsSigned, bool> = true> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; + + // signed integer base, signed integer exponent case + + // This routine handles negative exponents. + // The return value is either 0, 1, or -1. + const Packet cst_pos_one = pset1(Scalar(1)); + const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0); + const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x); + + const Packet abs_x = pabs(x); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); + + Packet result = pselect(exp_is_odd, x, abs_x); + result = pselect(abs_x_is_one, result, pzero(x)); + return result; +} + +template ::type>::IsSigned, bool> = true> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent&) { + using Scalar = typename unpacket_traits::type; + + // unsigned integer base, signed integer exponent case + + // This routine handles negative exponents. + // The return value is either 0 or 1 + + const Scalar pos_one = Scalar(1); + + const Packet cst_pos_one = pset1(pos_one); + + const Packet x_is_one = pcmp_eq(x, cst_pos_one); + + return pand(x_is_one, x); +} + +} // end namespace unary_pow + +template ::type>::IsInteger, + bool ExponentIsIntegerType = NumTraits::IsInteger, + bool ExponentIsSigned = NumTraits::IsSigned> +struct unary_pow_impl; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; + if (exponent_is_integer) { + // The simple recursive doubling implementation is only accurate to 3 ulps + // for integer exponents in [-3:7]. Since this is a common case, we + // specialize it here. + bool use_repeated_squaring = + (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))); + return use_repeated_squaring ? unary_pow::int_pow(x, exponent) : generic_pow(x, pset1(exponent)); + } else { + Packet result = unary_pow::gen_pow(x, exponent); + result = unary_pow::handle_nonint_nonint_errors(x, result, exponent); + return result; + } + } +}; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + return unary_pow::int_pow(x, exponent); + } +}; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + if (exponent < ScalarExponent(0)) { + return unary_pow::handle_negative_exponent(x, exponent); + } else { + return unary_pow::int_pow(x, exponent); + } + } +}; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + return unary_pow::int_pow(x, exponent); + } +}; + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_POW_H diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h b/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h new file mode 100644 index 000000000..1d7c594bb --- /dev/null +++ b/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h @@ -0,0 +1,833 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2009-2019 Gael Guennebaud +// Copyright (C) 2018-2025 Rasmus Munk Larsen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_TRIG_H +#define EIGEN_ARCH_GENERIC_PACKET_MATH_TRIG_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +//---------------------------------------------------------------------- +// Trigonometric Functions +//---------------------------------------------------------------------- + +// Enum for selecting which function to compute. SinCos is intended to compute +// pairs of Sin and Cos of the even entries in the packet, e.g. +// SinCos([a, *, b, *]) = [sin(a), cos(a), sin(b), cos(b)]. +enum class TrigFunction : uint8_t { Sin, Cos, Tan, SinCos }; + +// The following code is inspired by the following stack-overflow answer: +// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751 +// It has been largely optimized: +// - By-pass calls to frexp. +// - Aligned loads of required 96 bits of 2/pi. This is accomplished by +// (1) balancing the mantissa and exponent to the required bits of 2/pi are +// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi. +// - Avoid a branch in rounding and extraction of the remaining fractional part. +// Overall, I measured a speed up higher than x2 on x86-64. +inline float trig_reduce_huge(float xf, Eigen::numext::int32_t* quadrant) { + using Eigen::numext::int32_t; + using Eigen::numext::int64_t; + using Eigen::numext::uint32_t; + using Eigen::numext::uint64_t; + + const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62 + const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point format + + // 192 bits of 2/pi for Payne-Hanek reduction + // Bits are introduced by packet of 8 to enable aligned reads. + static const uint32_t two_over_pi[] = { + 0x00000028, 0x000028be, 0x0028be60, 0x28be60db, 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a, 0x91054a7f, + 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4, 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770, 0x4d377036, 0x377036d8, + 0x7036d8a5, 0x36d8a566, 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410, 0x10e41000, 0xe4100000}; + + uint32_t xi = numext::bit_cast(xf); + // Below, -118 = -126 + 8. + // -126 is to get the exponent, + // +8 is to enable alignment of 2/pi's bits on 8 bits. + // This is possible because the fractional part of x as only 24 meaningful bits. + uint32_t e = (xi >> 23) - 118; + // Extract the mantissa and shift it to align it wrt the exponent + xi = ((xi & 0x007fffffu) | 0x00800000u) << (e & 0x7); + + uint32_t i = e >> 3; + uint32_t twoopi_1 = two_over_pi[i - 1]; + uint32_t twoopi_2 = two_over_pi[i + 3]; + uint32_t twoopi_3 = two_over_pi[i + 7]; + + // Compute x * 2/pi in 2.62-bit fixed-point format. + uint64_t p; + p = uint64_t(xi) * twoopi_3; + p = uint64_t(xi) * twoopi_2 + (p >> 32); + p = (uint64_t(xi * twoopi_1) << 32) + p; + + // Round to nearest: add 0.5 and extract integral part. + uint64_t q = (p + zero_dot_five) >> 62; + *quadrant = int(q); + // Now it remains to compute "r = x - q*pi/2" with high accuracy, + // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as: + // r = (p-q)*pi/2, + // where the product can be be carried out with sufficient accuracy using double precision. + p -= q << 62; + return float(double(int64_t(p)) * pio2_62); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +#if EIGEN_COMP_GNUC_STRICT + __attribute__((optimize("-fno-unsafe-math-optimizations"))) +#endif + Packet + psincos_float(const Packet& _x) { + typedef typename unpacket_traits::integer_packet PacketI; + + const Packet cst_2oPI = pset1(0.636619746685028076171875f); // 2/PI + const Packet cst_rounding_magic = pset1(12582912); // 2^23 for rounding + const PacketI csti_1 = pset1(1); + const Packet cst_sign_mask = pset1frombits(static_cast(0x80000000u)); + + Packet x = pabs(_x); + + // Scale x by 2/Pi to find x's octant. + Packet y = pmul(x, cst_2oPI); + + // Rounding trick to find nearest integer: + Packet y_round = padd(y, cst_rounding_magic); + EIGEN_OPTIMIZATION_BARRIER(y_round) + PacketI y_int = preinterpret(y_round); // last 23 digits represent integer (if abs(x)<2^24) + y = psub(y_round, cst_rounding_magic); // nearest integer to x * (2/pi) + +// Subtract y * Pi/2 to reduce x to the interval -Pi/4 <= x <= +Pi/4 +// using "Extended precision modular arithmetic" +#if defined(EIGEN_VECTORIZE_FMA) + // This version requires true FMA for high accuracy. + // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08): + constexpr float huge_th = (Func == TrigFunction::Sin) ? 117435.992f : 71476.0625f; + x = pmadd(y, pset1(-1.57079601287841796875f), x); + x = pmadd(y, pset1(-3.1391647326017846353352069854736328125e-07f), x); + x = pmadd(y, pset1(-5.390302529957764765544681040410068817436695098876953125e-15f), x); +#else + // Without true FMA, the previous set of coefficients maintain 1ULP accuracy + // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7. + // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs. + + // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively. + // and 2 ULP up to: + constexpr float huge_th = (Func == TrigFunction::Sin) ? 25966.f : 18838.f; + x = pmadd(y, pset1(-1.5703125), x); // = 0xbfc90000 + EIGEN_OPTIMIZATION_BARRIER(x) + x = pmadd(y, pset1(-0.000483989715576171875), x); // = 0xb9fdc000 + EIGEN_OPTIMIZATION_BARRIER(x) + x = pmadd(y, pset1(1.62865035235881805419921875e-07), x); // = 0x342ee000 + x = pmadd(y, pset1(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee + +// For the record, the following set of coefficients maintain 2ULP up +// to a slightly larger range: +// const float huge_th = ComputeSine ? 51981.f : 39086.125f; +// but it slightly fails to maintain 1ULP for two values of sin below pi. +// x = pmadd(y, pset1(-3.140625/2.), x); +// x = pmadd(y, pset1(-0.00048351287841796875), x); +// x = pmadd(y, pset1(-3.13855707645416259765625e-07), x); +// x = pmadd(y, pset1(-6.0771006282767103812147979624569416046142578125e-11), x); + +// For the record, with only 3 iterations it is possible to maintain +// 1 ULP up to 3PI (maybe more) and 2ULP up to 255. +// The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee +#endif + + if (predux_any(pcmp_le(pset1(huge_th), pabs(_x)))) { + const int PacketSize = unpacket_traits::size; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize]; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize]; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Eigen::numext::int32_t y_int2[PacketSize]; + pstoreu(vals, pabs(_x)); + pstoreu(x_cpy, x); + pstoreu(y_int2, y_int); + for (int k = 0; k < PacketSize; ++k) { + float val = vals[k]; + if (val >= huge_th && (numext::isfinite)(val)) x_cpy[k] = trig_reduce_huge(val, &y_int2[k]); + } + x = ploadu(x_cpy); + y_int = ploadu(y_int2); + } + + // Get the polynomial selection mask from the second bit of y_int + // We'll calculate both (sin and cos) polynomials and then select from the two. + Packet poly_mask = preinterpret(pcmp_eq(pand(y_int, csti_1), pzero(y_int))); + + Packet x2 = pmul(x, x); + + // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4) + Packet y1 = pset1(2.4372266125283204019069671630859375e-05f); + y1 = pmadd(y1, x2, pset1(-0.00138865201734006404876708984375f)); + y1 = pmadd(y1, x2, pset1(0.041666619479656219482421875f)); + y1 = pmadd(y1, x2, pset1(-0.5f)); + y1 = pmadd(y1, x2, pset1(1.f)); + + // Evaluate the sin(x) polynomial. (Pi/4 <= x <= Pi/4) + // octave/matlab code to compute those coefficients: + // x = (0:0.0001:pi/4)'; + // A = [x.^3 x.^5 x.^7]; + // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy + // c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1 + // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1)) + // + Packet y2 = pset1(-0.0001959234114083702898469196984621021329076029360294342041015625f); + y2 = pmadd(y2, x2, pset1(0.0083326873655616851693794799871284340042620897293090820312500000f)); + y2 = pmadd(y2, x2, pset1(-0.1666666203982298255503735617821803316473960876464843750000000000f)); + y2 = pmul(y2, x2); + y2 = pmadd(y2, x, x); + + // Select the correct result from the two polynomials. + // Compute the sign to apply to the polynomial. + // sin: sign = second_bit(y_int) xor signbit(_x) + // cos: sign = second_bit(y_int+1) + Packet sign_bit = (Func == TrigFunction::Sin) ? pxor(_x, preinterpret(plogical_shift_left<30>(y_int))) + : preinterpret(plogical_shift_left<30>(padd(y_int, csti_1))); + sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit + + if ((Func == TrigFunction::SinCos) || (Func == TrigFunction::Tan)) { + // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div. + Packet peven = peven_mask(x); + Packet ysin = pselect(poly_mask, y2, y1); + Packet ycos = pselect(poly_mask, y1, y2); + Packet sign_bit_sin = pxor(_x, preinterpret(plogical_shift_left<30>(y_int))); + Packet sign_bit_cos = preinterpret(plogical_shift_left<30>(padd(y_int, csti_1))); + sign_bit_sin = pand(sign_bit_sin, cst_sign_mask); // clear all but left most bit + sign_bit_cos = pand(sign_bit_cos, cst_sign_mask); // clear all but left most bit + y = (Func == TrigFunction::SinCos) ? pselect(peven, pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos)) + : pdiv(pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos)); + } else { + y = (Func == TrigFunction::Sin) ? pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2); + y = pxor(y, sign_bit); + } + return y; +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_float(const Packet& x) { + return psincos_float(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Packet& x) { + return psincos_float(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_float(const Packet& x) { + return psincos_float(x); +} + +// Trigonometric argument reduction for double for inputs smaller than 15. +// Reduces trigonometric arguments for double inputs where x < 15. Given an argument x and its corresponding quadrant +// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t. +template +Packet trig_reduce_small_double(const Packet& x, const Packet& q) { + // Pi/2 split into 2 values + const Packet cst_pio2_a = pset1(-1.570796325802803); + const Packet cst_pio2_b = pset1(-9.920935184482005e-10); + + Packet t; + t = pmadd(cst_pio2_a, q, x); + t = pmadd(cst_pio2_b, q, t); + return t; +} + +// Trigonometric argument reduction for double for inputs smaller than 1e14. +// Reduces trigonometric arguments for double inputs where x < 1e14. Given an argument x and its corresponding quadrant +// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t. +template +Packet trig_reduce_medium_double(const Packet& x, const Packet& q_high, const Packet& q_low) { + // Pi/2 split into 4 values + const Packet cst_pio2_a = pset1(-1.570796325802803); + const Packet cst_pio2_b = pset1(-9.920935184482005e-10); + const Packet cst_pio2_c = pset1(-6.123234014771656e-17); + const Packet cst_pio2_d = pset1(1.903488962019325e-25); + + Packet t; + t = pmadd(cst_pio2_a, q_high, x); + t = pmadd(cst_pio2_a, q_low, t); + t = pmadd(cst_pio2_b, q_high, t); + t = pmadd(cst_pio2_b, q_low, t); + t = pmadd(cst_pio2_c, q_high, t); + t = pmadd(cst_pio2_c, q_low, t); + t = pmadd(cst_pio2_d, padd(q_low, q_high), t); + return t; +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +#if EIGEN_COMP_GNUC_STRICT + __attribute__((optimize("-fno-unsafe-math-optimizations"))) +#endif + Packet + psincos_double(const Packet& x) { + typedef typename unpacket_traits::integer_packet PacketI; + typedef typename unpacket_traits::type ScalarI; + + const Packet cst_sign_mask = pset1frombits(static_cast(0x8000000000000000u)); + + // If the argument is smaller than this value, use a simpler argument reduction + const double small_th = 15; + // If the argument is bigger than this value, use the non-vectorized std version + const double huge_th = 1e14; + + const Packet cst_2oPI = pset1(0.63661977236758134307553505349006); // 2/PI + // Integer Packet constants + const PacketI cst_one = pset1(ScalarI(1)); + // Constant for splitting + const Packet cst_split = pset1(1 << 24); + + Packet x_abs = pabs(x); + + // Scale x by 2/Pi + PacketI q_int; + Packet s; + + // TODO Implement huge angle argument reduction + if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1(small_th), x_abs)))) { + Packet q_high = pmul(pfloor(pmul(x_abs, pdiv(cst_2oPI, cst_split))), cst_split); + Packet q_low_noround = psub(pmul(x_abs, cst_2oPI), q_high); + q_int = pcast(padd(q_low_noround, pset1(0.5))); + Packet q_low = pcast(q_int); + s = trig_reduce_medium_double(x_abs, q_high, q_low); + } else { + Packet qval_noround = pmul(x_abs, cst_2oPI); + q_int = pcast(padd(qval_noround, pset1(0.5))); + Packet q = pcast(q_int); + s = trig_reduce_small_double(x_abs, q); + } + + // All the upcoming approximating polynomials have even exponents + Packet ss = pmul(s, s); + + // Padé approximant of cos(x) + // Assuring < 1 ULP error on the interval [-pi/4, pi/4] + // cos(x) ~= (80737373*x^8 - 13853547000*x^6 + 727718024880*x^4 - 11275015752000*x^2 + 23594700729600)/(147173*x^8 + + // 39328920*x^6 + 5772800880*x^4 + 522334612800*x^2 + 23594700729600) + // MATLAB code to compute those coefficients: + // syms x; + // cosf = @(x) cos(x); + // pade_cosf = pade(cosf(x), x, 0, 'Order', 8) + const Packet cn4 = pset1(80737373); + const Packet cn3 = pset1(-13853547000); + const Packet cn2 = pset1(727718024880); + const Packet cn1 = pset1(-11275015752000); + const Packet cn0 = pset1(23594700729600); // shared with cd0 + const Packet cd3 = pset1(147173); + const Packet cd2 = pset1(39328920); + const Packet cd1 = pset1(5772800880); + const Packet cd0 = pset1(522334612800); + Packet sc1_num = pmadd(ss, cn4, cn3); + Packet sc2_num = pmadd(sc1_num, ss, cn2); + Packet sc3_num = pmadd(sc2_num, ss, cn1); + Packet sc4_num = pmadd(sc3_num, ss, cn0); + Packet sc1_denum = pmadd(ss, cd3, cd2); + Packet sc2_denum = pmadd(sc1_denum, ss, cd1); + Packet sc3_denum = pmadd(sc2_denum, ss, cd0); + Packet sc4_denum = pmadd(sc3_denum, ss, cn0); + Packet scos = pdiv(sc4_num, sc4_denum); + + // Padé approximant of sin(x) + // Assuring < 1 ULP error on the interval [-pi/4, pi/4] + // sin(x) ~= (x*(4585922449*x^8 - 1066023933480*x^6 + 83284044283440*x^4 - 2303682236856000*x^2 + + // 15605159573203200))/(45*(1029037*x^8 + 345207016*x^6 + 61570292784*x^4 + 6603948711360*x^2 + 346781323848960)) + // MATLAB code to compute those coefficients: + // syms x; + // sinf = @(x) sin(x); + // pade_sinf = pade(sinf(x), x, 0, 'Order', 8, 'OrderMode', 'relative') + const Packet sn4 = pset1(4585922449); + const Packet sn3 = pset1(-1066023933480); + const Packet sn2 = pset1(83284044283440); + const Packet sn1 = pset1(-2303682236856000); + const Packet sn0 = pset1(15605159573203200); + const Packet sd3 = pset1(1029037); + const Packet sd2 = pset1(345207016); + const Packet sd1 = pset1(61570292784); + const Packet sd0_inner = pset1(6603948711360); + const Packet sd0 = pset1(346781323848960); + const Packet cst_45 = pset1(45); + Packet ss1_num = pmadd(ss, sn4, sn3); + Packet ss2_num = pmadd(ss1_num, ss, sn2); + Packet ss3_num = pmadd(ss2_num, ss, sn1); + Packet ss4_num = pmadd(ss3_num, ss, sn0); + Packet ss1_denum = pmadd(ss, sd3, sd2); + Packet ss2_denum = pmadd(ss1_denum, ss, sd1); + Packet ss3_denum = pmadd(ss2_denum, ss, sd0_inner); + Packet ss4_denum = pmadd(ss3_denum, ss, sd0); + Packet ssin = pdiv(pmul(s, ss4_num), pmul(cst_45, ss4_denum)); + + Packet poly_mask = preinterpret(pcmp_eq(pand(q_int, cst_one), pzero(q_int))); + + Packet sign_sin = pxor(x, preinterpret(plogical_shift_left<62>(q_int))); + Packet sign_cos = preinterpret(plogical_shift_left<62>(padd(q_int, cst_one))); + Packet sign_bit, sFinalRes; + if (Func == TrigFunction::Sin) { + sign_bit = sign_sin; + sFinalRes = pselect(poly_mask, ssin, scos); + } else if (Func == TrigFunction::Cos) { + sign_bit = sign_cos; + sFinalRes = pselect(poly_mask, scos, ssin); + } else if (Func == TrigFunction::Tan) { + // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div. + sign_bit = pxor(sign_sin, sign_cos); + sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin)); + } else if (Func == TrigFunction::SinCos) { + Packet peven = peven_mask(x); + sign_bit = pselect(peven, sign_sin, sign_cos); + sFinalRes = pselect(pxor(peven, poly_mask), scos, ssin); + } + sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit + sFinalRes = pxor(sFinalRes, sign_bit); + + // If the inputs values are higher than that a value that the argument reduction can currently address, compute them + // using the C++ standard library. + // TODO Remove it when huge angle argument reduction is implemented + if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1(huge_th), x_abs)))) { + const int PacketSize = unpacket_traits::size; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double sincos_vals[PacketSize]; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double x_cpy[PacketSize]; + pstoreu(x_cpy, x); + pstoreu(sincos_vals, sFinalRes); + for (int k = 0; k < PacketSize; ++k) { + double val = x_cpy[k]; + if (std::abs(val) > huge_th && (numext::isfinite)(val)) { + if (Func == TrigFunction::Sin) { + sincos_vals[k] = std::sin(val); + } else if (Func == TrigFunction::Cos) { + sincos_vals[k] = std::cos(val); + } else if (Func == TrigFunction::Tan) { + sincos_vals[k] = std::tan(val); + } else if (Func == TrigFunction::SinCos) { + sincos_vals[k] = k % 2 == 0 ? std::sin(val) : std::cos(val); + } + } + } + sFinalRes = ploadu(sincos_vals); + } + return sFinalRes; +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x) { + return psincos_double(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x) { + return psincos_double(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_double(const Packet& x) { + return psincos_double(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS + std::enable_if_t::type, float>::value, Packet> + psincos_selector(const Packet& x) { + return psincos_float(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS + std::enable_if_t::type, double>::value, Packet> + psincos_selector(const Packet& x) { + return psincos_double(x); +} + +//---------------------------------------------------------------------- +// Inverse Trigonometric Functions +//---------------------------------------------------------------------- + +// Generic implementation of acos(x). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pacos_float(const Packet& x_in) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be float"); + + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_pi = pset1(Scalar(EIGEN_PI)); + const Packet p6 = pset1(Scalar(2.36423197202384471893310546875e-3)); + const Packet p5 = pset1(Scalar(-1.1368644423782825469970703125e-2)); + const Packet p4 = pset1(Scalar(2.717843465507030487060546875e-2)); + const Packet p3 = pset1(Scalar(-4.8969544470310211181640625e-2)); + const Packet p2 = pset1(Scalar(8.8804088532924652099609375e-2)); + const Packet p1 = pset1(Scalar(-0.214591205120086669921875)); + const Packet p0 = pset1(Scalar(1.57079637050628662109375)); + + // For x in [0:1], we approximate acos(x)/sqrt(1-x), which is a smooth + // function, by a 6'th order polynomial. + // For x in [-1:0) we use that acos(-x) = pi - acos(x). + const Packet neg_mask = psignbit(x_in); + const Packet abs_x = pabs(x_in); + + // Evaluate the polynomial using Horner's rule: + // P(x) = p0 + x * (p1 + x * (p2 + ... (p5 + x * p6)) ... ) . + // We evaluate even and odd terms independently to increase + // instruction level parallelism. + Packet x2 = pmul(x_in, x_in); + Packet p_even = pmadd(p6, x2, p4); + Packet p_odd = pmadd(p5, x2, p3); + p_even = pmadd(p_even, x2, p2); + p_odd = pmadd(p_odd, x2, p1); + p_even = pmadd(p_even, x2, p0); + Packet p = pmadd(p_odd, abs_x, p_even); + + // The polynomial approximates acos(x)/sqrt(1-x), so + // multiply by sqrt(1-x) to get acos(x). + // Conveniently returns NaN for arguments outside [-1:1]. + Packet denom = psqrt(psub(cst_one, abs_x)); + Packet result = pmul(denom, p); + // Undo mapping for negative arguments. + return pselect(neg_mask, psub(cst_pi, result), result); +} + +// Generic implementation of asin(x). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Packet& x_in) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be float"); + + constexpr float kPiOverTwo = static_cast(EIGEN_PI / 2); + + const Packet cst_half = pset1(0.5f); + const Packet cst_one = pset1(1.0f); + const Packet cst_two = pset1(2.0f); + const Packet cst_pi_over_two = pset1(kPiOverTwo); + + const Packet abs_x = pabs(x_in); + const Packet sign_mask = pandnot(x_in, abs_x); + const Packet invalid_mask = pcmp_lt(cst_one, abs_x); + + // For arguments |x| > 0.5, we map x back to [0:0.5] using + // the transformation x_large = sqrt(0.5*(1-x)), and use the + // identity + // asin(x) = pi/2 - 2 * asin( sqrt( 0.5 * (1 - x))) + + const Packet x_large = psqrt(pnmadd(cst_half, abs_x, cst_half)); + const Packet large_mask = pcmp_lt(cst_half, abs_x); + const Packet x = pselect(large_mask, x_large, abs_x); + const Packet x2 = pmul(x, x); + + // For |x| < 0.5 approximate asin(x)/x by an 8th order polynomial with + // even terms only. + constexpr float alpha[] = {5.08838854730129241943359375e-2f, 3.95139865577220916748046875e-2f, + 7.550220191478729248046875e-2f, 0.16664917767047882080078125f, 1.00000011920928955078125f}; + Packet p = ppolevl::run(x2, alpha); + p = pmul(p, x); + + const Packet p_large = pnmadd(cst_two, p, cst_pi_over_two); + p = pselect(large_mask, p_large, p); + // Flip the sign for negative arguments. + p = pxor(p, sign_mask); + // Return NaN for arguments outside [-1:1]. + return por(invalid_mask, p); +} + +template +struct patan_reduced { + template + static EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet run(const Packet& x); +}; + +template <> +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced::run(const Packet& x) { + constexpr double alpha[] = {2.6667153866462208e-05, 3.0917513112462781e-03, 5.2574296781008604e-02, + 3.0409318473444424e-01, 7.5365702534987022e-01, 8.2704055405494614e-01, + 3.3004361289279920e-01}; + + constexpr double beta[] = { + 2.7311202462436667e-04, 1.0899150928962708e-02, 1.1548932646420353e-01, 4.9716458728465573e-01, 1.0, + 9.3705509168587852e-01, 3.3004361289279920e-01}; + + Packet x2 = pmul(x, x); + Packet p = ppolevl::run(x2, alpha); + Packet q = ppolevl::run(x2, beta); + return pmul(x, pdiv(p, q)); +} + +// Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy. +template <> +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced::run(const Packet& x) { + constexpr float alpha[] = {1.12026982009410858154296875e-01f, 7.296695709228515625e-01f, 8.109951019287109375e-01f}; + + constexpr float beta[] = {1.00917108356952667236328125e-02f, 2.8318560123443603515625e-01f, 1.0f, + 8.109951019287109375e-01f}; + + Packet x2 = pmul(x, x); + Packet p = ppolevl::run(x2, alpha); + Packet q = ppolevl::run(x2, beta); + return pmul(x, pdiv(p, q)); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_atan(const Packet& x_in) { + typedef typename unpacket_traits::type Scalar; + + constexpr Scalar kPiOverTwo = static_cast(EIGEN_PI / 2); + + const Packet cst_signmask = pset1(Scalar(-0.0)); + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_pi_over_two = pset1(kPiOverTwo); + + // "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x). + // "Small": For |x| <= 1, approximate atan(x) directly by a polynomial + // calculated using Rminimax. + + const Packet abs_x = pabs(x_in); + const Packet x_signmask = pand(x_in, cst_signmask); + const Packet large_mask = pcmp_lt(cst_one, abs_x); + const Packet x = pselect(large_mask, preciprocal(abs_x), abs_x); + const Packet p = patan_reduced::run(x); + // Apply transformations according to the range reduction masks. + Packet result = pselect(large_mask, psub(cst_pi_over_two, p), p); + // Return correct sign + return pxor(result, x_signmask); +} + +//---------------------------------------------------------------------- +// Hyperbolic Functions +//---------------------------------------------------------------------- + +#ifdef EIGEN_FAST_MATH + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) + Doesn't do anything fancy, just a 9/8-degree rational interpolant which + is accurate up to a couple of ulps in the (approximate) range [-8, 8], + outside of which tanh(x) = +/-1 in single precision. The input is clamped + to the range [-c, c]. The value c is chosen as the smallest value where + the approximation evaluates to exactly 1. + + This implementation works on both scalars and packets. +*/ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) { + // Clamp the inputs to the range [-c, c] and set everything + // outside that range to 1.0. The value c is chosen as the smallest + // floating point argument such that the approximation is exactly 1. + // This saves clamping the value at the end. +#ifdef EIGEN_VECTORIZE_FMA + const T plus_clamp = pset1(8.01773357391357422f); + const T minus_clamp = pset1(-8.01773357391357422f); +#else + const T plus_clamp = pset1(7.90738964080810547f); + const T minus_clamp = pset1(-7.90738964080810547f); +#endif + const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); + + // The following rational approximation was generated by rminimax + // (https://gitlab.inria.fr/sfilip/rminimax) using the following + // command: + // $ ratapprox --function="tanh(x)" --dom='[-8.67,8.67]' --num="odd" + // --den="even" --type="[9,8]" --numF="[SG]" --denF="[SG]" --log + // --output=tanhf.sollya --dispCoeff="dec" + + // The monomial coefficients of the numerator polynomial (odd). + constexpr float alpha[] = {1.394553628e-8f, 2.102733560e-5f, 3.520756727e-3f, 1.340216100e-1f}; + + // The monomial coefficients of the denominator polynomial (even). + constexpr float beta[] = {8.015776984e-7f, 3.326951409e-4f, 2.597254514e-2f, 4.673548340e-1f, 1.0f}; + + // Since the polynomials are odd/even, we need x^2. + const T x2 = pmul(x, x); + const T x3 = pmul(x2, x); + + T p = ppolevl::run(x2, alpha); + T q = ppolevl::run(x2, beta); + // Take advantage of the fact that the constant term in p is 1 to compute + // x*(x^2*p + 1) = x^3 * p + x. + p = pmadd(x3, p, x); + + // Divide the numerator by the denominator. + return pdiv(p, q); +} + +#else + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise). + On the domain [-1.25:1.25] we use an approximation of the form + tanh(x) ~= x^3 * (P(x) / Q(x)) + x, where P and Q are polynomials in x^2. + For |x| > 1.25, tanh is implemented as tanh(x) = 1 - (2 / (1 + exp(2*x))). + + This implementation has a maximum error of 1 ULP (measured with AVX2+FMA). + + This implementation works on both scalars and packets. +*/ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& x) { + // The polynomial coefficients were computed using Rminimax: + // % ./ratapprox --function="tanh(x)-x" --dom='[-1.25,1.25]' --num="[x^3,x^5]" --den="even" + // --type="[3,4]" --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec" --output=tanhf.solly + constexpr float alpha[] = {-1.46725140511989593505859375e-02f, -3.333333432674407958984375e-01f}; + constexpr float beta[] = {1.570280082523822784423828125e-02, 4.4401752948760986328125e-01, 1.0f}; + const T x2 = pmul(x, x); + const T x3 = pmul(x2, x); + const T p = ppolevl::run(x2, alpha); + const T q = ppolevl::run(x2, beta); + const T small_tanh = pmadd(x3, pdiv(p, q), x); + + const T sign_mask = pset1(-0.0f); + const T abs_x = pandnot(x, sign_mask); + constexpr float kSmallThreshold = 1.25f; + const T large_mask = pcmp_lt(pset1(kSmallThreshold), abs_x); + // Fast exit if all elements are small. + if (!predux_any(large_mask)) { + return small_tanh; + } + + // Compute as 1 - (2 / (1 + exp(2*x))) + const T one = pset1(1.0f); + const T two = pset1(2.0f); + const T s = pexp_float(pmul(two, abs_x)); + const T abs_tanh = psub(one, pdiv(two, padd(s, one))); + + // Handle infinite inputs and set sign bit. + constexpr float kHugeThreshold = 16.0f; + const T huge_mask = pcmp_lt(pset1(kHugeThreshold), abs_x); + const T x_sign = pand(sign_mask, x); + const T large_tanh = por(x_sign, pselect(huge_mask, one, abs_tanh)); + return pselect(large_mask, large_tanh, small_tanh); +} + +#endif // EIGEN_FAST_MATH + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) + This uses a 19/18-degree rational interpolant which + is accurate up to a couple of ulps in the (approximate) range [-18.7, 18.7], + outside of which tanh(x) = +/-1 in single precision. The input is clamped + to the range [-c, c]. The value c is chosen as the smallest value where + the approximation evaluates to exactly 1. + + This implementation works on both scalars and packets. +*/ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_double(const T& a_x) { + // Clamp the inputs to the range [-c, c] and set everything + // outside that range to 1.0. The value c is chosen as the smallest + // floating point argument such that the approximation is exactly 1. + // This saves clamping the value at the end. +#ifdef EIGEN_VECTORIZE_FMA + const T plus_clamp = pset1(17.6610191624600077); + const T minus_clamp = pset1(-17.6610191624600077); +#else + const T plus_clamp = pset1(17.714196154005176); + const T minus_clamp = pset1(-17.714196154005176); +#endif + const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); + // The following rational approximation was generated by rminimax + // (https://gitlab.inria.fr/sfilip/rminimax) using the following + // command: + // $ ./ratapprox --function="tanh(x)" --dom='[-18.72,18.72]' + // --num="odd" --den="even" --type="[19,18]" --numF="[D]" + // --denF="[D]" --log --output=tanh.sollya --dispCoeff="dec" + + // The monomial coefficients of the numerator polynomial (odd). + constexpr double alpha[] = {2.6158007860482230e-23, 7.6534862268749319e-19, 3.1309488231386680e-15, + 4.2303918148209176e-12, 2.4618379131293676e-09, 6.8644367682497074e-07, + 9.3839087674268880e-05, 5.9809711724441161e-03, 1.5184719640284322e-01}; + + // The monomial coefficients of the denominator polynomial (even). + constexpr double beta[] = {6.463747022670968018e-21, 5.782506856739003571e-17, + 1.293019623712687916e-13, 1.123643448069621992e-10, + 4.492975677839633985e-08, 8.785185266237658698e-06, + 8.295161192716231542e-04, 3.437448108450402717e-02, + 4.851805297361760360e-01, 1.0}; + + // Since the polynomials are odd/even, we need x^2. + const T x2 = pmul(x, x); + const T x3 = pmul(x2, x); + + // Interleave the evaluation of the numerator polynomial p and + // denominator polynomial q. + T p = ppolevl::run(x2, alpha); + T q = ppolevl::run(x2, beta); + // Take advantage of the fact that the constant term in p is 1 to compute + // x*(x^2*p + 1) = x^3 * p + x. + p = pmadd(x3, p, x); + + // Divide the numerator by the denominator. + return pdiv(p, q); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be float"); + + // For |x| in [0:0.5] we use a polynomial approximation of the form + // P(x) = x + x^3*(alpha[4] + x^2 * (alpha[3] + x^2 * (... x^2 * alpha[0]) ... )). + constexpr float alpha[] = {0.1819281280040740966796875f, 8.2311116158962249755859375e-2f, + 0.14672131836414337158203125f, 0.1997792422771453857421875f, 0.3333373963832855224609375f}; + const Packet x2 = pmul(x, x); + const Packet x3 = pmul(x, x2); + Packet p = ppolevl::run(x2, alpha); + p = pmadd(x3, p, x); + + // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); + const Packet half = pset1(0.5f); + const Packet one = pset1(1.0f); + Packet r = pdiv(padd(one, x), psub(one, x)); + r = pmul(half, plog(r)); + + const Packet x_gt_half = pcmp_le(half, pabs(x)); + const Packet x_eq_one = pcmp_eq(one, pabs(x)); + const Packet x_gt_one = pcmp_lt(one, pabs(x)); + const Packet sign_mask = pset1(-0.0f); + const Packet x_sign = pand(sign_mask, x); + const Packet inf = pset1(std::numeric_limits::infinity()); + return por(x_gt_one, pselect(x_eq_one, por(x_sign, inf), pselect(x_gt_half, r, p))); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_double(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be double"); + // For x in [-0.5:0.5] we use a rational approximation of the form + // R(x) = x + x^3*P(x^2)/Q(x^2), where P is or order 4 and Q is of order 5. + constexpr double alpha[] = {3.3071338469301391e-03, -4.7129526768798737e-02, 1.8185306179826699e-01, + -2.5949536095445679e-01, 1.2306328729812676e-01}; + + constexpr double beta[] = {-3.8679974580640881e-03, 7.6391885763341910e-02, -4.2828141436397615e-01, + 9.8733495886883648e-01, -1.0000000000000000e+00, 3.6918986189438030e-01}; + + const Packet x2 = pmul(x, x); + const Packet x3 = pmul(x, x2); + Packet p = ppolevl::run(x2, alpha); + Packet q = ppolevl::run(x2, beta); + Packet y_small = pmadd(x3, pdiv(p, q), x); + + // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); + const Packet half = pset1(0.5); + const Packet one = pset1(1.0); + Packet y_large = pdiv(padd(one, x), psub(one, x)); + y_large = pmul(half, plog(y_large)); + + const Packet x_gt_half = pcmp_le(half, pabs(x)); + const Packet x_eq_one = pcmp_eq(one, pabs(x)); + const Packet x_gt_one = pcmp_lt(one, pabs(x)); + const Packet sign_mask = pset1(-0.0); + const Packet x_sign = pand(sign_mask, x); + const Packet inf = pset1(std::numeric_limits::infinity()); + return por(x_gt_one, pselect(x_eq_one, por(x_sign, inf), pselect(x_gt_half, y_large, y_small))); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_TRIG_H