diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index e3dcfae70..7e4f054dd 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -1,4 +1,3 @@ - // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // @@ -146,7 +145,6 @@ struct packet_traits : default_packet_traits { #endif HasTanh = EIGEN_FAST_MATH, HasLog = 1, - HasErf = 1, HasErfc = 1, HasExp = 1, HasSqrt = 1, @@ -1935,6 +1933,22 @@ EIGEN_STRONG_INLINE Packet4d pldexp(const Packet4d& a, const Packet4d& return out; } +template <> +EIGEN_STRONG_INLINE Packet4d pldexp_fast(const Packet4d& a, const Packet4d& exponent) { + // Clamp exponent to [-1024, 1024] + const Packet4d min_exponent = pset1(-1023.0); + const Packet4d max_exponent = pset1(1024.0); + const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, min_exponent), max_exponent)); + const Packet4i bias = pset1(1023); + + // 2^e + Packet4i hi = vec4i_swizzle1(padd(e, bias), 0, 2, 1, 3); + const Packet4i lo = _mm_slli_epi64(hi, 52); + hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52); + const Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1)); + return pmul(a, c); // a * 2^e +} + template <> EIGEN_STRONG_INLINE float predux(const Packet8f& a) { return predux(Packet4f(_mm_add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)))); diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 4e441b498..8b7d762ff 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -274,22 +274,20 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con // // Assumes IEEE floating point format template -struct pldexp_fast_impl { +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent) { typedef typename unpacket_traits::integer_packet PacketI; typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::type ScalarI; static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT, MantissaBits = numext::numeric_limits::digits - 1, ExponentBits = TotalBits - MantissaBits - 1; - static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet run(const Packet& a, const Packet& exponent) { - const Packet bias = pset1(Scalar((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1))); // 127 - const Packet limit = pset1(Scalar((ScalarI(1) << ExponentBits) - ScalarI(1))); // 255 - // restrict biased exponent between 0 and 255 for float. - const PacketI e = pcast(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 - // return a * (2^e) - return pmul(a, preinterpret(plogical_shift_left(e))); - } -}; + const Packet bias = pset1(Scalar((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1))); // 127 + const Packet limit = pset1(Scalar((ScalarI(1) << ExponentBits) - ScalarI(1))); // 255 + // restrict biased exponent between 0 and 255 for float. + const PacketI e = pcast(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 + // return a * (2^e) + return pmul(a, preinterpret(plogical_shift_left(e))); +} // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) @@ -549,7 +547,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_float(const Pack y = pmadd(r2, y, p_low); // Return 2^m * exp(r). - // TODO: replace pldexp with faster implementation since y in [-1, 1). + const Packet fast_pldexp_unsafe = pandnot(pcmp_lt(x, pset1(-87.0)), zero_mask); + if (!predux_any(fast_pldexp_unsafe)) { + // For x >= -87, we can safely use the fast version of pldexp. + return pselect(zero_mask, cst_zero, pmax(pldexp_fast(y, m), _x)); + } return pselect(zero_mask, cst_zero, pmax(pldexp(y, m), _x)); } @@ -562,7 +564,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_double(const Pac const Packet cst_half = pset1(0.5); const Packet cst_exp_hi = pset1(709.784); - const Packet cst_exp_lo = pset1(-709.784); + const Packet cst_exp_lo = pset1(-745.519); const Packet cst_cephes_LOG2EF = pset1(1.4426950408889634073599); const Packet cst_cephes_exp_p0 = pset1(1.26177193074810590878e-4); @@ -616,7 +618,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_double(const Pac // Construct the result 2^n * exp(g) = e * x. The max is used to catch // non-finite values in the input. - // TODO: replace pldexp with faster implementation since x in [-1, 1). + const Packet fast_pldexp_unsafe = pandnot(pcmp_lt(_x, pset1(-708.0)), zero_mask); + if (!predux_any(fast_pldexp_unsafe)) { + // For x >= -708, we can safely use the fast version of pldexp. + return pselect(zero_mask, cst_zero, pmax(pldexp_fast(x, fx), _x)); + } return pselect(zero_mask, cst_zero, pmax(pldexp(x, fx), _x)); } diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 3b362f4f6..ac0e2cfd3 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -42,6 +42,18 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic_get_biased_exponent( template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, const Packet& exponent); +// Explicitly multiplies +// a * (2^e) +// clamping e to the range +// [NumTraits::min_exponent()-2, NumTraits::max_exponent()] +// +// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow +// if 2^e doesn't fit into a normal floating-point Scalar. +// +// Assumes IEEE floating point format +template +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent); + /** \internal \returns log(x) for single precision float */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_float(const Packet _x); diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index b3c526fbe..f29400950 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -217,7 +217,6 @@ struct packet_traits : default_packet_traits { HasCos = EIGEN_FAST_MATH, HasTanh = EIGEN_FAST_MATH, HasLog = 1, - HasErf = EIGEN_FAST_MATH, HasErfc = EIGEN_FAST_MATH, HasExp = 1, HasSqrt = 1, @@ -1767,7 +1766,6 @@ EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& // We specialize pldexp here, since the generic implementation uses Packet2l, which is not well // supported by SSE, and has more range than is needed for exponents. -// TODO(rmlarsen): Remove this specialization once Packet2l has support or casting. template <> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) { // Clamp exponent to [-2099, 2099] @@ -1788,6 +1786,24 @@ EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& return out; } +// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well +// supported by SSE, and has more range than is needed for exponents. +template <> +EIGEN_STRONG_INLINE Packet2d pldexp_fast(const Packet2d& a, const Packet2d& exponent) { + // Clamp exponent to [-1023, 1024] + const Packet2d min_exponent = pset1(-1023.0); + const Packet2d max_exponent = pset1(1024.0); + const Packet2d e = pmin(pmax(exponent, min_exponent), max_exponent); + + // Convert e to integer and swizzle to low-order bits. + const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3); + + // Compute 2^e multiply: + const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023); + const Packet2d c = _mm_castsi128_pd(_mm_slli_epi64(padd(ei, bias), 52)); // 2^e + return pmul(a, c); +} + // with AVX, the default implementations based on pload1 are faster #ifndef __AVX__ template <> diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index defd3c2a1..03542e331 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1292,7 +1292,7 @@ struct scalar_logistic_op { p = pmadd(r2, p, p_low); // 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r). - Packet e = pldexp_fast_impl::run(p, m); + Packet e = pldexp_fast(p, m); // 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2. e = pmul(e, e);