mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix unary pow error handling and test
This commit is contained in:
committed by
Rasmus Munk Larsen
parent
7ac8897431
commit
b7151ffaab
@@ -1970,24 +1970,49 @@ struct pchebevl {
|
||||
};
|
||||
|
||||
namespace unary_pow {
|
||||
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct is_odd {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) {
|
||||
ScalarExponent xdiv2 = x / ScalarExponent(2);
|
||||
ScalarExponent floorxdiv2 = numext::floor(xdiv2);
|
||||
return xdiv2 != floorxdiv2;
|
||||
|
||||
template <typename ScalarExponent, bool IsInteger = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct exponent_helper {
|
||||
using safe_abs_type = ScalarExponent;
|
||||
static constexpr ScalarExponent one_half = ScalarExponent(0.5);
|
||||
// these routines assume that exp is an integer stored as a floating point type
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent safe_abs(const ScalarExponent& exp) {
|
||||
return numext::abs(exp);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const ScalarExponent& exp) {
|
||||
eigen_assert(((numext::isfinite)(exp) && exp == numext::floor(exp)) && "exp must be an integer");
|
||||
ScalarExponent exp_div_2 = exp * one_half;
|
||||
ScalarExponent floor_exp_div_2 = numext::floor(exp_div_2);
|
||||
return exp_div_2 != floor_exp_div_2;
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent floor_div_two(const ScalarExponent& exp) {
|
||||
ScalarExponent exp_div_2 = exp * one_half;
|
||||
return numext::floor(exp_div_2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ScalarExponent>
|
||||
struct is_odd<ScalarExponent, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) {
|
||||
return x % ScalarExponent(2) != 0;
|
||||
struct exponent_helper<ScalarExponent, true> {
|
||||
// if `exp` is a signed integer type, cast it to its unsigned counterpart to safely store its absolute value
|
||||
// consider the (rare) case where `exp` is an int32_t: abs(-2147483648) != 2147483648
|
||||
using safe_abs_type = typename numext::get_integer_by_size<sizeof(ScalarExponent)>::unsigned_type;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type safe_abs(const ScalarExponent& exp) {
|
||||
ScalarExponent mask = exp ^ numext::abs(exp);
|
||||
safe_abs_type result = static_cast<safe_abs_type>(exp);
|
||||
return result ^ mask;
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const safe_abs_type& exp) {
|
||||
return exp % safe_abs_type(2) != safe_abs_type(0);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type floor_div_two(const safe_abs_type& exp) {
|
||||
return exp >> safe_abs_type(1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
|
||||
struct do_div {
|
||||
bool ReciprocateIfExponentIsNegative =
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger && NumTraits<ScalarExponent>::IsSigned>
|
||||
struct reciprocate {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
@@ -1996,41 +2021,43 @@ struct do_div {
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct do_div<Packet, ScalarExponent, true> {
|
||||
struct reciprocate<Packet, ScalarExponent, false> {
|
||||
// pdiv not defined, nor necessary for integer base types
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) {
|
||||
return x;
|
||||
}
|
||||
// if the exponent is unsigned, then the exponent cannot be negative
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; }
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
using ExponentHelper = exponent_helper<ScalarExponent>;
|
||||
using AbsExponentType = typename ExponentHelper::safe_abs_type;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
if (exponent == 0) return cst_pos_one;
|
||||
Packet result = x;
|
||||
if (exponent == ScalarExponent(0)) return cst_pos_one;
|
||||
|
||||
Packet result = reciprocate<Packet, ScalarExponent>::run(x, exponent);
|
||||
Packet y = cst_pos_one;
|
||||
ScalarExponent m = numext::abs(exponent);
|
||||
AbsExponentType m = ExponentHelper::safe_abs(exponent);
|
||||
|
||||
while (m > 1) {
|
||||
bool odd = is_odd<ScalarExponent>::run(m);
|
||||
bool odd = ExponentHelper::is_odd(m);
|
||||
if (odd) y = pmul(y, result);
|
||||
result = pmul(result, result);
|
||||
m = numext::floor(m / ScalarExponent(2));
|
||||
m = ExponentHelper::floor_div_two(m);
|
||||
}
|
||||
result = pmul(y, result);
|
||||
result = do_div<Packet, ScalarExponent>::run(result, exponent);
|
||||
return result;
|
||||
|
||||
return pmul(y, result);
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
|
||||
const typename unpacket_traits<Packet>::type& exponent) {
|
||||
const Packet exponent_packet = pset1<Packet>(exponent);
|
||||
return generic_pow_impl(x, exponent_packet);
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
|
||||
@@ -2045,36 +2072,45 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||
|
||||
const bool exponent_is_nan = (numext::isnan)(exponent);
|
||||
const bool exponent_is_fin = (numext::isfinite)(exponent);
|
||||
const bool exponent_is_not_fin = !(numext::isfinite)(exponent);
|
||||
const bool exponent_is_neg = exponent < ScalarExponent(0);
|
||||
const bool exponent_is_pos = exponent > ScalarExponent(0);
|
||||
|
||||
const Packet exp_is_nan = pset1<Packet>(exponent_is_nan ? all_ones : pos_zero);
|
||||
const Packet exp_is_fin = pset1<Packet>(exponent_is_fin ? all_ones : pos_zero);
|
||||
const Packet exp_is_not_fin = pset1<Packet>(exponent_is_not_fin ? all_ones : pos_zero);
|
||||
const Packet exp_is_neg = pset1<Packet>(exponent_is_neg ? all_ones : pos_zero);
|
||||
const Packet exp_is_pos = pset1<Packet>(exponent_is_pos ? all_ones : pos_zero);
|
||||
const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
|
||||
const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
|
||||
|
||||
const Packet x_is_gt_one = pcmp_lt(cst_pos_one, x);
|
||||
const Packet x_is_lt_one = pcmp_lt(x, cst_pos_one);
|
||||
const Packet x_is_zero = pcmp_eq(x, cst_pos_zero);
|
||||
const Packet x_is_not_one = por(x_is_gt_one, x_is_lt_one);
|
||||
const Packet x_is_le_zero = pcmp_le(x, cst_pos_zero);
|
||||
const Packet x_is_ge_zero = pcmp_le(cst_pos_zero, x);
|
||||
const Packet x_is_zero = pand(x_is_le_zero, x_is_ge_zero);
|
||||
|
||||
const Packet inf_if_neg_exp = pand(cst_pos_inf, exp_is_neg);
|
||||
const Packet inf_if_pos_exp = pandnot(cst_pos_inf, exp_is_neg);
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_le_one = pcmp_le(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_ge_one = pcmp_le(cst_pos_one, abs_x);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
const Packet abs_x_is_one = pand(abs_x_is_le_one, abs_x_is_ge_one);
|
||||
|
||||
Packet pow_is_inf_if_exp_is_neg = por(x_is_zero, pand(abs_x_is_le_one, exp_is_inf));
|
||||
Packet pow_is_inf_if_exp_is_pos = por(abs_x_is_inf, pand(abs_x_is_ge_one, exp_is_inf));
|
||||
Packet pow_is_one = pand(abs_x_is_one, por(exp_is_inf, x_is_ge_zero));
|
||||
|
||||
Packet result = powx;
|
||||
result = pselect(x_is_zero, inf_if_neg_exp, result);
|
||||
result = pselect(pandnot(x_is_gt_one, exp_is_fin), inf_if_pos_exp, result);
|
||||
result = pselect(pandnot(x_is_lt_one, exp_is_fin), inf_if_neg_exp, result);
|
||||
result = por(x_is_le_zero, result);
|
||||
result = pselect(pow_is_inf_if_exp_is_neg, pand(cst_pos_inf, exp_is_neg), result);
|
||||
result = pselect(pow_is_inf_if_exp_is_pos, pand(cst_pos_inf, exp_is_pos), result);
|
||||
result = por(exp_is_nan, result);
|
||||
result = pselect(x_is_not_one, result, cst_pos_one);
|
||||
result = pselect(pow_is_one, cst_pos_one, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) {
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
|
||||
// integer base, integer exponent case
|
||||
// singed integer base, signed integer exponent case
|
||||
|
||||
// This routine handles negative exponents.
|
||||
// The return value is either 0, 1, or -1.
|
||||
@@ -2085,7 +2121,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet&
|
||||
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
|
||||
const bool exponent_is_odd = unary_pow::is_odd<ScalarExponent>::run(exponent);
|
||||
const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0);
|
||||
|
||||
const Packet exp_is_odd = pset1<Packet>(exponent_is_odd ? all_ones : pos_zero);
|
||||
|
||||
@@ -2097,15 +2133,36 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet&
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent&) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
|
||||
// unsigned integer base, signed integer exponent case
|
||||
|
||||
// This routine handles negative exponents.
|
||||
// The return value is either 0 or 1
|
||||
|
||||
const Scalar pos_one = Scalar(1);
|
||||
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
|
||||
const Packet x_is_one = pcmp_eq(x, cst_pos_one);
|
||||
|
||||
return pand(x_is_one, x);
|
||||
}
|
||||
|
||||
|
||||
} // end namespace unary_pow
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger,
|
||||
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
|
||||
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger,
|
||||
bool ExponentIsSigned = NumTraits<ScalarExponent>::IsSigned>
|
||||
struct unary_pow_impl;
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, false> {
|
||||
template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, false, ExponentIsSigned> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||
@@ -2119,8 +2176,8 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||
template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true, ExponentIsSigned> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
@@ -2128,17 +2185,25 @@ struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, true, true> {
|
||||
struct unary_pow_impl<Packet, ScalarExponent, true, true, true> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
if (exponent < ScalarExponent(0)) {
|
||||
return unary_pow::handle_int_int(x, exponent);
|
||||
return unary_pow::handle_negative_exponent(x, exponent);
|
||||
} else {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, true, true, false> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
|
||||
Reference in New Issue
Block a user