diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 381d8fff5..8a07d50fe 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -709,33 +709,21 @@ EIGEN_DEVICE_FUNC inline Packet parg(const Packet& a) { } /** \internal \returns \a a arithmetically shifted by N bits to the right */ -template -EIGEN_DEVICE_FUNC inline int parithmetic_shift_right(const int& a) { - return a >> N; -} -template -EIGEN_DEVICE_FUNC inline long int parithmetic_shift_right(const long int& a) { - return a >> N; +template +EIGEN_DEVICE_FUNC inline T parithmetic_shift_right(const T& a) { + return numext::arithmetic_shift_right(a, N); } /** \internal \returns \a a logically shifted by N bits to the right */ -template -EIGEN_DEVICE_FUNC inline int plogical_shift_right(const int& a) { - return static_cast(static_cast(a) >> N); -} -template -EIGEN_DEVICE_FUNC inline long int plogical_shift_right(const long int& a) { - return static_cast(static_cast(a) >> N); +template +EIGEN_DEVICE_FUNC inline T plogical_shift_right(const T& a) { + return numext::logical_shift_right(a, N); } /** \internal \returns \a a shifted by N bits to the left */ -template -EIGEN_DEVICE_FUNC inline int plogical_shift_left(const int& a) { - return a << N; -} -template -EIGEN_DEVICE_FUNC inline long int plogical_shift_left(const long int& a) { - return a << N; +template +EIGEN_DEVICE_FUNC inline T plogical_shift_left(const T& a) { + return numext::logical_shift_left(a, N); } /** \internal \returns the significant and exponent of the underlying floating point numbers diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 6bb9a1202..d42fc93cc 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -1746,6 +1746,23 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double fmod(const double& a, const double& #undef SYCL_SPECIALIZE_BINARY_FUNC #endif +template ::value>> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar logical_shift_left(const Scalar& a, int n) { + return a << n; +} + +template ::value>> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar logical_shift_right(const Scalar& a, int n) { + using UnsignedScalar = typename numext::get_integer_by_size::unsigned_type; + return bit_cast(bit_cast(a) >> n); +} + +template ::value>> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar& a, int n) { + using SignedScalar = typename numext::get_integer_by_size::signed_type; + return bit_cast(bit_cast(a) >> n); +} + } // end namespace numext namespace internal { diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h index 2848b7896..a6e2de477 100644 --- a/Eigen/src/Core/NumTraits.h +++ b/Eigen/src/Core/NumTraits.h @@ -101,10 +101,10 @@ namespace numext { template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) { // The behaviour of memcpy is not specified for non-trivially copyable types - EIGEN_STATIC_ASSERT(std::is_trivially_copyable::value, THIS_TYPE_IS_NOT_SUPPORTED); + EIGEN_STATIC_ASSERT(std::is_trivially_copyable::value, THIS_TYPE_IS_NOT_SUPPORTED) EIGEN_STATIC_ASSERT(std::is_trivially_copyable::value && std::is_default_constructible::value, - THIS_TYPE_IS_NOT_SUPPORTED); - EIGEN_STATIC_ASSERT(sizeof(Src) == sizeof(Tgt), THIS_TYPE_IS_NOT_SUPPORTED); + THIS_TYPE_IS_NOT_SUPPORTED) + EIGEN_STATIC_ASSERT(sizeof(Src) == sizeof(Tgt), THIS_TYPE_IS_NOT_SUPPORTED) Tgt tgt; // Load src into registers first. This allows the memcpy to be elided by CUDA. diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 2b0c05ce4..c1bbc7c28 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -219,7 +219,9 @@ struct functor_traits> { */ template struct scalar_shift_right_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return a >> N; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return numext::arithmetic_shift_right(a); + } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { return internal::parithmetic_shift_right(a); @@ -237,7 +239,9 @@ struct functor_traits> { */ template struct scalar_shift_left_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return a << N; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return numext::logical_shift_left(a); + } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { return internal::plogical_shift_left(a); diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index b5ad3c46f..3b36328fa 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -1068,24 +1068,45 @@ void min_max(const ArrayType& m) { } } -template -struct shift_left { - template - Scalar operator()(const Scalar& v) const { - return (v << N); +template +struct shift_imm_traits { + enum { Cost = 1, PacketAccess = internal::packet_traits::HasShift }; +}; + +template +struct logical_left_shift_op { + Scalar operator()(const Scalar& v) const { return numext::logical_shift_left(v, N); } + template + Packet packetOp(const Packet& v) const { + return internal::plogical_shift_left(v); + } +}; +template +struct logical_right_shift_op { + Scalar operator()(const Scalar& v) const { return numext::logical_shift_right(v, N); } + template + Packet packetOp(const Packet& v) const { + return internal::plogical_shift_right(v); + } +}; +template +struct arithmetic_right_shift_op { + Scalar operator()(const Scalar& v) const { return numext::arithmetic_shift_right(v, N); } + template + Packet packetOp(const Packet& v) const { + return internal::parithmetic_shift_right(v); } }; -template -struct arithmetic_shift_right { - template - Scalar operator()(const Scalar& v) const { - return (v >> N); - } -}; +template +struct internal::functor_traits> : shift_imm_traits {}; +template +struct internal::functor_traits> : shift_imm_traits {}; +template +struct internal::functor_traits> : shift_imm_traits {}; template -struct signed_shift_test_impl { +struct shift_test_impl { typedef typename ArrayType::Scalar Scalar; static constexpr size_t Size = sizeof(Scalar); static constexpr size_t MaxShift = (CHAR_BIT * Size) - 1; @@ -1099,20 +1120,24 @@ struct signed_shift_test_impl { ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols), m3(rows, cols); - m2 = m1.unaryExpr(internal::scalar_shift_right_op()); - m3 = m1.unaryExpr(arithmetic_shift_right()); + m2 = m1.unaryExpr([](const Scalar& v) { return numext::logical_shift_left(v, N); }); + m3 = m1.unaryExpr(logical_left_shift_op()); VERIFY_IS_CWISE_EQUAL(m2, m3); - m2 = m1.unaryExpr(internal::scalar_shift_left_op()); - m3 = m1.unaryExpr(shift_left()); + m2 = m1.unaryExpr([](const Scalar& v) { return numext::logical_shift_right(v, N); }); + m3 = m1.unaryExpr(logical_right_shift_op()); + VERIFY_IS_CWISE_EQUAL(m2, m3); + + m2 = m1.unaryExpr([](const Scalar& v) { return numext::arithmetic_shift_right(v, N); }); + m3 = m1.unaryExpr(arithmetic_right_shift_op()); VERIFY_IS_CWISE_EQUAL(m2, m3); run(m); } }; template -void signed_shift_test(const ArrayType& m) { - signed_shift_test_impl::run(m); +void shift_test(const ArrayType& m) { + shift_test_impl::run(m); } template @@ -1361,10 +1386,10 @@ EIGEN_DECLARE_TEST(array_cwise) { ArrayXXi(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_7(array_generic(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); - CALL_SUBTEST_8(signed_shift_test( + CALL_SUBTEST_8(shift_test( ArrayXXi(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); - CALL_SUBTEST_9(signed_shift_test(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), - internal::random(1, EIGEN_TEST_MAX_SIZE)))); + CALL_SUBTEST_9(shift_test(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), + internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_10(array_generic(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_11(array_generic(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), diff --git a/test/numext.cpp b/test/numext.cpp index a2d511bcb..ebe9fb069 100644 --- a/test/numext.cpp +++ b/test/numext.cpp @@ -292,6 +292,27 @@ void check_signbit() { check_signbit_impl::run(); } +template +void check_shift() { + using SignedT = typename numext::get_integer_by_size::signed_type; + using UnsignedT = typename numext::get_integer_by_size::unsigned_type; + constexpr int kNumBits = CHAR_BIT * sizeof(T); + for (int i = 0; i < 1000; ++i) { + const T a = internal::random(); + for (int s = 1; s < kNumBits; s++) { + T a_bsll = numext::logical_shift_left(a, s); + T a_bsll_ref = a << s; + VERIFY_IS_EQUAL(a_bsll, a_bsll_ref); + T a_bsrl = numext::logical_shift_right(a, s); + T a_bsrl_ref = numext::bit_cast(numext::bit_cast(a) >> s); + VERIFY_IS_EQUAL(a_bsrl, a_bsrl_ref); + T a_bsra = numext::arithmetic_shift_right(a, s); + T a_bsra_ref = numext::bit_cast(numext::bit_cast(a) >> s); + VERIFY_IS_EQUAL(a_bsra, a_bsra_ref); + } + } +} + EIGEN_DECLARE_TEST(numext) { for (int k = 0; k < g_repeat; ++k) { CALL_SUBTEST(check_negate()); @@ -354,5 +375,15 @@ EIGEN_DECLARE_TEST(numext) { CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); + + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); + + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); + CALL_SUBTEST(check_shift()); } }