From 9148c47d67a94c9557c29132956da8a7ebc4f5b1 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 29 May 2024 00:20:12 +0000 Subject: [PATCH] Vectorize isfinite and isinf. --- Eigen/src/Core/GenericPacketMath.h | 20 ++++--- Eigen/src/Core/functors/UnaryFunctors.h | 49 +++++++++++++---- .../Eigen/CXX11/src/Tensor/TensorBase.h | 9 ++-- unsupported/test/cxx11_tensor_comparisons.cpp | 53 +++++++++++++++++++ 4 files changed, 110 insertions(+), 21 deletions(-) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 8a07d50fe..e1347b985 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -579,12 +579,6 @@ EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); } -/** \internal \returns isnan(a) */ -template -EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) { - return pandnot(ptrue(a), pcmp_eq(a, a)); -} - // In the general case, use bitwise select. template struct pselect_impl { @@ -1002,6 +996,20 @@ EIGEN_DEVICE_FUNC inline Packet pcplxflip(const Packet& a) { * Special math functions ***************************/ +/** \internal \returns isnan(a) */ +template +EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) { + return pandnot(ptrue(a), pcmp_eq(a, a)); +} + +/** \internal \returns isinf(a) */ +template +EIGEN_DEVICE_FUNC inline Packet pisinf(const Packet& a) { + using Scalar = typename unpacket_traits::type; + constexpr Scalar inf = NumTraits::infinity(); + return pcmp_eq(pabs(a), pset1(inf)); +} + /** \internal \returns the sine of \a a (coeff-wise) */ template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin(const Packet& a) { diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index c1bbc7c28..5059a5408 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -989,10 +989,9 @@ struct functor_traits> { * \brief Template functor to check whether a scalar is +/-inf * \sa class CwiseUnaryOp, ArrayBase::isinf() */ -template +template struct scalar_isinf_op { - typedef bool result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { #if defined(SYCL_DEVICE_ONLY) return numext::isinf(a); #else @@ -1000,19 +999,33 @@ struct scalar_isinf_op { #endif } }; + template -struct functor_traits> { - enum { Cost = NumTraits::MulCost, PacketAccess = false }; +struct scalar_isinf_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const { +#if defined(SYCL_DEVICE_ONLY) + return (numext::isinf(a) ? ptrue(a) : pzero(a)); +#else + return (numext::isinf EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a)); +#endif + } + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { + return pisinf(a); + } +}; +template +struct functor_traits> { + enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasCmp && UseTypedPredicate }; }; /** \internal * \brief Template functor to check whether a scalar has a finite value * \sa class CwiseUnaryOp, ArrayBase::isfinite() */ -template +template struct scalar_isfinite_op { - typedef bool result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { #if defined(SYCL_DEVICE_ONLY) return numext::isfinite(a); #else @@ -1020,9 +1033,25 @@ struct scalar_isfinite_op { #endif } }; + template -struct functor_traits> { - enum { Cost = NumTraits::MulCost, PacketAccess = false }; +struct scalar_isfinite_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const { +#if defined(SYCL_DEVICE_ONLY) + return (numext::isfinite(a) ? ptrue(a) : pzero(a)); +#else + return (numext::isfinite EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a)); +#endif + } + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { + constexpr Scalar inf = NumTraits::infinity(); + return pcmp_lt(pabs(a), pset1(inf)); + } +}; +template +struct functor_traits> { + enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasCmp && UseTypedPredicate }; }; /** \internal diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index f88793ef2..2c2c7810e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -618,16 +618,15 @@ class TensorBase (isnan)() const { return unaryExpr(internal::scalar_isnan_op()).template cast(); } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + EIGEN_STRONG_INLINE const TensorConversionOp, const Derived>> (isinf)() const { - return unaryExpr(internal::scalar_isinf_op()); + return unaryExpr(internal::scalar_isinf_op()).template cast(); } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + EIGEN_STRONG_INLINE const TensorConversionOp, const Derived>> (isfinite)() const { - return unaryExpr(internal::scalar_isfinite_op()); + return unaryExpr(internal::scalar_isfinite_op()).template cast(); } // Coefficient-wise ternary operators. diff --git a/unsupported/test/cxx11_tensor_comparisons.cpp b/unsupported/test/cxx11_tensor_comparisons.cpp index 3565b624f..17a177607 100644 --- a/unsupported/test/cxx11_tensor_comparisons.cpp +++ b/unsupported/test/cxx11_tensor_comparisons.cpp @@ -132,8 +132,61 @@ static void test_isnan() { } } +static void test_isinf() { + Tensor mat(2, 3, 7); + + mat.setRandom(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + if (internal::random()) { + mat(i, j, k) = std::numeric_limits::infinity(); + } + } + } + } + Tensor inf(2, 3, 7); + inf = (mat.isinf)(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(inf(i, j, k), (std::isinf)(mat(i, j, k))); + } + } + } +} + +static void test_isfinite() { + Tensor mat(2, 3, 7); + + mat.setRandom(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + if (internal::random()) { + mat(i, j, k) = std::numeric_limits::infinity(); + } + if (internal::random()) { + mat(i, j, k) = std::numeric_limits::quiet_NaN(); + } + } + } + } + Tensor inf(2, 3, 7); + inf = (mat.isfinite)(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(inf(i, j, k), (std::isfinite)(mat(i, j, k))); + } + } + } +} + EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) { CALL_SUBTEST(test_orderings()); CALL_SUBTEST(test_equality()); CALL_SUBTEST(test_isnan()); + CALL_SUBTEST(test_isinf()); + CALL_SUBTEST(test_isfinite()); }