mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Refactor special case handling in pow(x,y) and revert to repeated squaring for <float,int>
This commit is contained in:
@@ -1843,7 +1843,8 @@ struct accurate_log2 {
|
||||
// The minimax polynomial used was calculated using the Rminimax tool,
|
||||
// see https://gitlab.inria.fr/sfilip/rminimax.
|
||||
// Command line:
|
||||
// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]' --type=[10,0]
|
||||
// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]'
|
||||
// --type=[10,0]
|
||||
// --numF="[D,D,SG]" --denF="[SG]" --log --dispCoeff="dec"
|
||||
//
|
||||
// The resulting implementation of pow(x,y) is accurate to 3 ulps.
|
||||
@@ -1851,7 +1852,7 @@ template <>
|
||||
struct accurate_log2<float> {
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
|
||||
// // Split the two lowest order constant coefficient into double-word representation.
|
||||
// Split the two lowest order constant coefficient into double-word representation.
|
||||
constexpr double kC0 = 1.442695041742110273474963832995854318141937255859375e+00;
|
||||
constexpr float kC0_hi = static_cast<float>(kC0);
|
||||
constexpr float kC0_lo = static_cast<float>(kC0 - static_cast<double>(kC0_hi));
|
||||
@@ -1874,7 +1875,8 @@ struct accurate_log2<float> {
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
const Packet x = psub(z, one);
|
||||
Packet p = ppolevl<Packet, 8>::run(x, c);
|
||||
// Evaluate the final two step in Horner's rule using double-word arithmetic.
|
||||
// Evaluate the final two step in Horner's rule using double-word
|
||||
// arithmetic.
|
||||
Packet p_hi, p_lo;
|
||||
twoprod(x, p, p_hi, p_lo);
|
||||
fast_twosum(c1_hi, c1_lo, p_hi, p_lo, p_hi, p_lo);
|
||||
@@ -2041,69 +2043,91 @@ template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||
const Packet cst_neg_inf = pset1<Packet>(-NumTraits<Scalar>::infinity());
|
||||
const Packet cst_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||
const Packet cst_zero = pset1<Packet>(Scalar(0));
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet x_abs = pabs(x);
|
||||
Packet pow = generic_pow_impl(x_abs, y);
|
||||
|
||||
// In the following we enforce the special case handling prescribed in
|
||||
// https://en.cppreference.com/w/cpp/numeric/math/pow.
|
||||
|
||||
// Predicates for sign and magnitude of x.
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero);
|
||||
const Packet x_is_negative = pcmp_lt(x, cst_zero);
|
||||
const Packet x_is_zero = pcmp_eq(x, cst_zero);
|
||||
const Packet x_is_one = pcmp_eq(x, cst_one);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
|
||||
const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
|
||||
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
|
||||
const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_nan = pisnan(x);
|
||||
const Packet x_abs_gt_one = pcmp_lt(cst_one, x_abs);
|
||||
const Packet x_abs_is_inf = pcmp_eq(x_abs, cst_inf);
|
||||
|
||||
// Predicates for sign and magnitude of y.
|
||||
const Packet abs_y = pabs(y);
|
||||
const Packet y_abs = pabs(y);
|
||||
const Packet y_abs_is_inf = pcmp_eq(y_abs, cst_inf);
|
||||
const Packet y_is_negative = pcmp_lt(y, cst_zero);
|
||||
const Packet y_is_zero = pcmp_eq(y, cst_zero);
|
||||
const Packet y_is_one = pcmp_eq(y, cst_one);
|
||||
const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero);
|
||||
const Packet y_is_neg = pcmp_lt(y, cst_zero);
|
||||
const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg));
|
||||
const Packet y_is_nan = pisnan(y);
|
||||
const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf);
|
||||
EIGEN_CONSTEXPR Scalar huge_exponent =
|
||||
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon();
|
||||
const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
|
||||
|
||||
// Predicates for whether y is integer and/or even.
|
||||
const Packet y_is_int = pcmp_eq(pfloor(y), y);
|
||||
// Predicates for whether y is integer and odd/even.
|
||||
const Packet y_is_int = pandnot(pcmp_eq(pfloor(y), y), y_abs_is_inf);
|
||||
const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
|
||||
const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
|
||||
const Packet y_is_odd_int = pandnot(y_is_int, y_is_even);
|
||||
// Smallest exponent for which (1 + epsilon) overflows to infinity.
|
||||
EIGEN_CONSTEXPR Scalar huge_exponent =
|
||||
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon();
|
||||
const Packet y_abs_is_huge = pcmp_le(pset1<Packet>(huge_exponent), y_abs);
|
||||
|
||||
// Predicates encoding special cases for the value of pow(x,y)
|
||||
const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
|
||||
const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
|
||||
const Packet pow_is_one =
|
||||
por(por(x_is_one, abs_y_is_zero), pand(x_is_neg_one, por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
|
||||
const Packet pow_is_zero = por(por(por(pand(abs_x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg));
|
||||
const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
|
||||
const Packet pow_is_neg_zero = pand(pandnot(y_is_int, y_is_even),
|
||||
por(pand(y_is_neg, pand(abs_x_is_inf, x_is_neg)), pand(y_is_pos, x_is_neg_zero)));
|
||||
const Packet inf_val =
|
||||
pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even),
|
||||
cst_neg_inf, cst_pos_inf);
|
||||
// General computation of pow(x,y) for positive x or negative x and integer y.
|
||||
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
||||
const Packet pow_abs = generic_pow_impl(abs_x, y);
|
||||
return pselect(y_is_one, x,
|
||||
pselect(pow_is_one, cst_one,
|
||||
pselect(pow_is_nan, cst_nan,
|
||||
pselect(pow_is_inf, inf_val,
|
||||
pselect(pow_is_neg_zero, pnegate(cst_zero),
|
||||
pselect(pow_is_zero, cst_zero,
|
||||
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))));
|
||||
// * pow(base, exp) returns NaN if base is finite and negative
|
||||
// and exp is finite and non-integer.
|
||||
pow = pselect(pandnot(x_is_negative, y_is_int), cst_nan, pow);
|
||||
|
||||
// * pow(±0, exp), where exp is negative, finite, and is an even integer or
|
||||
// a non-integer, returns +∞
|
||||
// * pow(±0, exp), where exp is positive non-integer or a positive even
|
||||
// integer, returns +0
|
||||
// * pow(+0, exp), where exp is a negative odd integer, returns +∞
|
||||
// * pow(-0, exp), where exp is a negative odd integer, returns -∞
|
||||
// * pow(+0, exp), where exp is a positive odd integer, returns +0
|
||||
// * pow(-0, exp), where exp is a positive odd integer, returns -0
|
||||
// Sign is flipped by the rule below.
|
||||
pow = pselect(x_is_zero, pselect(y_is_negative, cst_inf, cst_zero), pow);
|
||||
|
||||
// pow(base, exp) returns -pow(abs(base), exp) if base has the sign bit set,
|
||||
// and exp is an odd integer exponent.
|
||||
pow = pselect(pand(x_has_signbit, y_is_odd_int), pnegate(pow), pow);
|
||||
|
||||
// * pow(base, -∞) returns +∞ for any |base|<1
|
||||
// * pow(base, -∞) returns +0 for any |base|>1
|
||||
// * pow(base, +∞) returns +0 for any |base|<1
|
||||
// * pow(base, +∞) returns +∞ for any |base|>1
|
||||
// * pow(±0, -∞) returns +∞
|
||||
// * pow(-1, +-∞) = 1
|
||||
Packet inf_y_val = pselect(por(pand(y_is_negative, x_is_zero), pxor(y_is_negative, x_abs_gt_one)), cst_inf, cst_zero);
|
||||
inf_y_val = pselect(pcmp_eq(x, pset1<Packet>(Scalar(-1.0))), cst_one, inf_y_val);
|
||||
pow = pselect(y_abs_is_huge, inf_y_val, pow);
|
||||
|
||||
// * pow(+∞, exp) returns +0 for any negative exp
|
||||
// * pow(+∞, exp) returns +∞ for any positive exp
|
||||
// * pow(-∞, exp) returns -0 if exp is a negative odd integer.
|
||||
// * pow(-∞, exp) returns +0 if exp is a negative non-integer or negative
|
||||
// even integer.
|
||||
// * pow(-∞, exp) returns -∞ if exp is a positive odd integer.
|
||||
// * pow(-∞, exp) returns +∞ if exp is a positive non-integer or positive
|
||||
// even integer.
|
||||
auto x_pos_inf_value = pselect(y_is_negative, cst_zero, cst_inf);
|
||||
auto x_neg_inf_value = pselect(y_is_odd_int, pnegate(x_pos_inf_value), x_pos_inf_value);
|
||||
pow = pselect(x_abs_is_inf, pselect(x_is_negative, x_neg_inf_value, x_pos_inf_value), pow);
|
||||
|
||||
// All cases of NaN inputs return NaN, except the two below.
|
||||
pow = pselect(por(pisnan(x), pisnan(y)), cst_nan, pow);
|
||||
|
||||
// * pow(base, 1) returns base.
|
||||
// * pow(base, +/-0) returns 1, regardless of base, even NaN.
|
||||
// * pow(+1, exp) returns 1, regardless of exponent, even NaN.
|
||||
pow = pselect(y_is_one, x, pselect(por(x_is_one, y_is_zero), cst_one, pow));
|
||||
|
||||
return pow;
|
||||
}
|
||||
|
||||
namespace unary_pow {
|
||||
@@ -2303,13 +2327,12 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false, ExponentIsSigned> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||
if (exponent_is_integer) {
|
||||
// The simple recursive doubling implementation is only accurate to 3 ulps for
|
||||
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
|
||||
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
// TODO(rmlarsen): Implement more efficient special case handling.
|
||||
return generic_pow(x, pset1<Packet>(exponent));
|
||||
// The simple recursive doubling implementation is only accurate to 3 ulps
|
||||
// for integer exponents in [-3:7]. Since this is a common case, we
|
||||
// specialize it here.
|
||||
bool use_repeated_squaring =
|
||||
(exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3)));
|
||||
return use_repeated_squaring ? unary_pow::int_pow(x, exponent) : generic_pow(x, pset1<Packet>(exponent));
|
||||
} else {
|
||||
Packet result = unary_pow::gen_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
|
||||
@@ -2322,13 +2345,7 @@ template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true, ExponentIsSigned> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
// The simple recursive doubling implementation is only sufficiently accurate to 3 ulps for
|
||||
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
|
||||
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
// TODO(rmlarsen): Implement more efficient special case handling.
|
||||
return generic_pow<Packet>(x, pset1<Packet>(Scalar(exponent)));
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user