Add numext::fma and missing pmadd implementations.

This commit is contained in:
Antonio Sánchez
2025-03-23 01:05:53 +00:00
parent 754bd24f5e
commit d935916ac6
11 changed files with 319 additions and 105 deletions

View File

@@ -24,21 +24,55 @@ template <typename T>
inline T REF_MUL(const T& a, const T& b) {
return a * b;
}
// Generic implementation of the multiply-add family, written with the plain
// arithmetic operators of Scalar. This primary template is selected for every
// Scalar that is not matched by a more specific madd_impl specialization.
//
//   madd(a, b, c)  ==  a * b + c
//   msub(a, b, c)  ==  a * b - c
//   nmadd(a, b, c) ==  c - a * b
//   nmsub(a, b, c) ==  0 - (a * b + c)
//
// The second template parameter exists only as an SFINAE hook for
// specializations and defaults to void.
template <typename Scalar, typename EnableIf = void>
struct madd_impl {
// Returns a * b + c.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
return a * b + c;
}
// Returns a * b - c.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
return a * b - c;
}
// Returns c - a * b, i.e. the negated product plus c.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return c - a * b;
}
// Returns the negation of (a * b + c).
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
// The negation is spelled as a subtraction from Scalar(0) instead of a unary minus.
// NOTE(review): presumably so this path also works for unsigned scalars - confirm.
return Scalar(0) - (a * b + c);
}
};
// Specialization of madd_impl for types that satisfy both
// Eigen::internal::is_scalar<Scalar> and Eigen::NumTraits<Scalar>::IsSigned.
// Every operation is expressed as a single numext::fma call, with the sign
// changes applied to an operand or to the result:
//
//   madd(a, b, c)  ==  fma( a, b,  c)
//   msub(a, b, c)  ==  fma( a, b, -c)
//   nmadd(a, b, c) ==  fma(-a, b,  c)
//   nmsub(a, b, c) == -fma( a, b,  c)
template <typename Scalar>
struct madd_impl<Scalar,
std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
// Returns numext::fma(a, b, c).
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, c);
}
// Returns a * b - c, computed as an fma with the addend negated.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
// The negated addend is converted back to Scalar before the call.
// NOTE(review): presumably because unary minus may yield a promoted type - confirm.
return numext::fma(a, b, Scalar(-c));
}
// Returns c - a * b, computed as an fma with the first factor negated.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(Scalar(-a), b, c);
}
// Returns -(a * b + c): the fma result is converted to Scalar and then negated.
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return -Scalar(numext::fma(a, b, c));
}
};
// Reference value for the multiply-add a * b + c.
//
// Delegates to madd_impl<T>::madd so that the reference is produced by the
// scalar helper defined in this file and does not depend on internal::pmadd.
// The function has exactly one return statement; a second return after the
// first would be unreachable.
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
  return madd_impl<T>::madd(a, b, c);
}
// Reference value for the multiply-subtract a * b - c.
//
// Delegates to madd_impl<T>::msub so that the reference is produced by the
// scalar helper defined in this file and does not depend on internal::pmsub.
// The function has exactly one return statement; a second return after the
// first would be unreachable.
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
  return madd_impl<T>::msub(a, b, c);
}
// Reference value for the negated multiply-add c - a * b.
//
// Delegates to madd_impl<T>::nmadd so that the reference is produced by the
// scalar helper defined in this file and does not depend on internal::pnmadd.
// The function has exactly one return statement; a second return after the
// first would be unreachable.
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
  return madd_impl<T>::nmadd(a, b, c);
}
// Reference value for the negated multiply-subtract -(a * b + c).
//
// Delegates to madd_impl<T>::nmsub so that the reference is produced by the
// scalar helper defined in this file and does not depend on internal::pnmsub.
// The function has exactly one return statement; a second return after the
// first would be unreachable.
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
  return madd_impl<T>::nmsub(a, b, c);
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {
@@ -70,6 +104,14 @@ template <>
inline bool REF_MADD(const bool& a, const bool& b, const bool& c) {
return (a && b) || c;
}
// bool specialization of REF_DIV: the result is the logical AND of the two
// operands, so no arithmetic division is performed. When b is false the
// result is false; otherwise the result is a.
template <>
inline bool REF_DIV(const bool& a, const bool& b) {
  if (!b) {
    return false;
  }
  return a;
}
// bool specialization of REF_RECIPROCAL: the input is returned unchanged,
// so true maps to true and false maps to false.
template <>
inline bool REF_RECIPROCAL(const bool& a) {
  return a;
}
template <typename T>
inline T REF_FREXP(const T& x, T& exp) {
@@ -501,8 +543,8 @@ void packetmath() {
eigen_optimization_barrier_test<Scalar>::run();
for (int i = 0; i < size; ++i) {
data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data2[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[i] = internal::random<Scalar>();
data2[i] = internal::random<Scalar>();
refvalue = (std::max)(refvalue, numext::abs(data1[i]));
}
@@ -522,8 +564,8 @@ void packetmath() {
for (int M = 0; M < PacketSize; ++M) {
for (int N = 0; N <= PacketSize; ++N) {
for (int j = 0; j < size; ++j) {
data1[j] = internal::random<Scalar>() / RealScalar(PacketSize);
data2[j] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[j] = internal::random<Scalar>();
data2[j] = internal::random<Scalar>();
refvalue = (std::max)(refvalue, numext::abs(data1[j]));
}
@@ -652,11 +694,11 @@ void packetmath() {
// Avoid overflows.
if (NumTraits<Scalar>::IsInteger && NumTraits<Scalar>::IsSigned &&
Eigen::internal::unpacket_traits<Packet>::size > 1) {
Scalar limit =
static_cast<Scalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size)));
Scalar limit = static_cast<Scalar>(
static_cast<RealScalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size))));
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<Scalar>(-limit, limit);
data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
}
}
ref[0] = Scalar(1);
@@ -1683,7 +1725,7 @@ void packetmath_scatter_gather() {
for (Index N = 0; N <= PacketSize; ++N) {
for (Index i = 0; i < N; ++i) {
data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[i] = internal::random<Scalar>();
}
for (Index i = 0; i < N * 20; ++i) {
@@ -1702,7 +1744,7 @@ void packetmath_scatter_gather() {
}
for (Index i = 0; i < N * 7; ++i) {
buffer[i] = internal::random<Scalar>() / RealScalar(PacketSize);
buffer[i] = internal::random<Scalar>();
}
packet = internal::pgather_partial<Scalar, Packet>(buffer, 7, N);
internal::pstore_partial(data1, packet, N);