mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Speed up exp(x).
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
@@ -146,7 +145,6 @@ struct packet_traits<double> : default_packet_traits {
|
||||
#endif
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasLog = 1,
|
||||
HasErf = 1,
|
||||
HasErfc = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
@@ -1935,6 +1933,22 @@ EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d&
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4d pldexp_fast<Packet4d>(const Packet4d& a, const Packet4d& exponent) {
|
||||
// Clamp exponent to [-1024, 1024]
|
||||
const Packet4d min_exponent = pset1<Packet4d>(-1023.0);
|
||||
const Packet4d max_exponent = pset1<Packet4d>(1024.0);
|
||||
const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, min_exponent), max_exponent));
|
||||
const Packet4i bias = pset1<Packet4i>(1023);
|
||||
|
||||
// 2^e
|
||||
Packet4i hi = vec4i_swizzle1(padd(e, bias), 0, 2, 1, 3);
|
||||
const Packet4i lo = _mm_slli_epi64(hi, 52);
|
||||
hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
|
||||
const Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
|
||||
return pmul(a, c); // a * 2^e
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE float predux<Packet8f>(const Packet8f& a) {
|
||||
return predux(Packet4f(_mm_add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))));
|
||||
|
||||
@@ -274,22 +274,20 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con
|
||||
//
|
||||
// Assumes IEEE floating point format
|
||||
template <typename Packet>
|
||||
struct pldexp_fast_impl {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT, MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = TotalBits - MantissaBits - 1;
|
||||
|
||||
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet run(const Packet& a, const Packet& exponent) {
|
||||
const Packet bias = pset1<Packet>(Scalar((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1))); // 127
|
||||
const Packet limit = pset1<Packet>(Scalar((ScalarI(1) << ExponentBits) - ScalarI(1))); // 255
|
||||
// restrict biased exponent between 0 and 255 for float.
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
|
||||
// return a * (2^e)
|
||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
|
||||
}
|
||||
};
|
||||
const Packet bias = pset1<Packet>(Scalar((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1))); // 127
|
||||
const Packet limit = pset1<Packet>(Scalar((ScalarI(1) << ExponentBits) - ScalarI(1))); // 255
|
||||
// restrict biased exponent between 0 and 255 for float.
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
|
||||
// return a * (2^e)
|
||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
|
||||
}
|
||||
|
||||
// Natural or base 2 logarithm.
|
||||
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
|
||||
@@ -549,7 +547,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_float(const Pack
|
||||
y = pmadd(r2, y, p_low);
|
||||
|
||||
// Return 2^m * exp(r).
|
||||
// TODO: replace pldexp with faster implementation since y in [-1, 1).
|
||||
const Packet fast_pldexp_unsafe = pandnot(pcmp_lt(x, pset1<Packet>(-87.0)), zero_mask);
|
||||
if (!predux_any(fast_pldexp_unsafe)) {
|
||||
// For x >= -87, we can safely use the fast version of pldexp.
|
||||
return pselect(zero_mask, cst_zero, pmax(pldexp_fast(y, m), _x));
|
||||
}
|
||||
return pselect(zero_mask, cst_zero, pmax(pldexp(y, m), _x));
|
||||
}
|
||||
|
||||
@@ -562,7 +564,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_double(const Pac
|
||||
const Packet cst_half = pset1<Packet>(0.5);
|
||||
|
||||
const Packet cst_exp_hi = pset1<Packet>(709.784);
|
||||
const Packet cst_exp_lo = pset1<Packet>(-709.784);
|
||||
const Packet cst_exp_lo = pset1<Packet>(-745.519);
|
||||
|
||||
const Packet cst_cephes_LOG2EF = pset1<Packet>(1.4426950408889634073599);
|
||||
const Packet cst_cephes_exp_p0 = pset1<Packet>(1.26177193074810590878e-4);
|
||||
@@ -616,7 +618,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_double(const Pac
|
||||
|
||||
// Construct the result 2^n * exp(g) = e * x. The max is used to catch
|
||||
// non-finite values in the input.
|
||||
// TODO: replace pldexp with faster implementation since x in [-1, 1).
|
||||
const Packet fast_pldexp_unsafe = pandnot(pcmp_lt(_x, pset1<Packet>(-708.0)), zero_mask);
|
||||
if (!predux_any(fast_pldexp_unsafe)) {
|
||||
// For x >= -708, we can safely use the fast version of pldexp.
|
||||
return pselect(zero_mask, cst_zero, pmax(pldexp_fast(x, fx), _x));
|
||||
}
|
||||
return pselect(zero_mask, cst_zero, pmax(pldexp(x, fx), _x));
|
||||
}
|
||||
|
||||
|
||||
@@ -42,6 +42,18 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic_get_biased_exponent(
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, const Packet& exponent);
|
||||
|
||||
// Explicitly multiplies
|
||||
// a * (2^e)
|
||||
// clamping e to the range
|
||||
// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
|
||||
//
|
||||
// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
|
||||
// if 2^e doesn't fit into a normal floating-point Scalar.
|
||||
//
|
||||
// Assumes IEEE floating point format
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent);
|
||||
|
||||
/** \internal \returns log(x) for single precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_float(const Packet _x);
|
||||
|
||||
@@ -217,7 +217,6 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasLog = 1,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
@@ -1767,7 +1766,6 @@ EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f&
|
||||
|
||||
// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well
|
||||
// supported by SSE, and has more range than is needed for exponents.
|
||||
// TODO(rmlarsen): Remove this specialization once Packet2l has support or casting.
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
|
||||
// Clamp exponent to [-2099, 2099]
|
||||
@@ -1788,6 +1786,24 @@ EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d&
|
||||
return out;
|
||||
}
|
||||
|
||||
// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well
|
||||
// supported by SSE, and has more range than is needed for exponents.
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2d pldexp_fast<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
|
||||
// Clamp exponent to [-1023, 1024]
|
||||
const Packet2d min_exponent = pset1<Packet2d>(-1023.0);
|
||||
const Packet2d max_exponent = pset1<Packet2d>(1024.0);
|
||||
const Packet2d e = pmin(pmax(exponent, min_exponent), max_exponent);
|
||||
|
||||
// Convert e to integer and swizzle to low-order bits.
|
||||
const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3);
|
||||
|
||||
// Compute 2^e multiply:
|
||||
const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023);
|
||||
const Packet2d c = _mm_castsi128_pd(_mm_slli_epi64(padd(ei, bias), 52)); // 2^e
|
||||
return pmul(a, c);
|
||||
}
|
||||
|
||||
// with AVX, the default implementations based on pload1 are faster
|
||||
#ifndef __AVX__
|
||||
template <>
|
||||
|
||||
@@ -1292,7 +1292,7 @@ struct scalar_logistic_op<float> {
|
||||
p = pmadd(r2, p, p_low);
|
||||
|
||||
// 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r).
|
||||
Packet e = pldexp_fast_impl<Packet>::run(p, m);
|
||||
Packet e = pldexp_fast(p, m);
|
||||
|
||||
// 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2.
|
||||
e = pmul(e, e);
|
||||
|
||||
Reference in New Issue
Block a user