diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h index 7d062f3ad..edafb6610 100644 --- a/Eigen/src/Core/arch/GPU/PacketMath.h +++ b/Eigen/src/Core/arch/GPU/PacketMath.h @@ -548,6 +548,15 @@ EIGEN_DEVICE_FUNC inline double2 ptrunc(const double2& a) { return make_double2(trunc(a.x), trunc(a.y)); } +template <> +EIGEN_DEVICE_FUNC inline float4 pround(const float4& a) { + return make_float4(roundf(a.x), roundf(a.y), roundf(a.z), roundf(a.w)); +} +template <> +EIGEN_DEVICE_FUNC inline double2 pround(const double2& a) { + return make_double2(round(a.x), round(a.y)); +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { float tmp = kernel.packet[0].y; kernel.packet[0].y = kernel.packet[1].x; diff --git a/Eigen/src/Core/arch/LSX/TypeCasting.h b/Eigen/src/Core/arch/LSX/TypeCasting.h index cda868067..0b2906b8b 100644 --- a/Eigen/src/Core/arch/LSX/TypeCasting.h +++ b/Eigen/src/Core/arch/LSX/TypeCasting.h @@ -18,6 +18,192 @@ namespace Eigen { namespace internal { +//============================================================================== +// type_casting_traits +//============================================================================== + +// float <-> double +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// float <-> integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// double <-> integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// int8_t <-> other integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// uint8_t <-> other integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// int16_t <-> wider integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// uint16_t <-> wider integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// int32_t <-> 64-bit integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + +// uint32_t <-> 64-bit integer types +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; +template <> +struct type_casting_traits : vectorized_type_casting_traits {}; + //============================================================================== // preinterpret //============================================================================== @@ -93,42 +279,42 @@ EIGEN_STRONG_INLINE Packet2ul preinterpret(const Packet2l& template <> EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { Packet2d tmp = __lsx_vfcvtl_d_s(a); - return __lsx_vftint_l_d(tmp); + return __lsx_vftintrz_l_d(tmp); } template <> EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4f& a) { Packet2d tmp = __lsx_vfcvtl_d_s(a); - return __lsx_vftint_lu_d(tmp); + return __lsx_vftintrz_lu_d(tmp); } template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet4f& a) { - return __lsx_vftint_w_s(a); + return __lsx_vftintrz_w_s(a); } template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4f& a) { - return __lsx_vftint_wu_s(a); + return __lsx_vftintrz_wu_s(a); } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet4f& a, const Packet4f& b) { - return __lsx_vssrlni_h_w(__lsx_vftint_w_s(a), __lsx_vftint_w_s(b), 0); + return __lsx_vpickev_h(__lsx_vftintrz_w_s(b), __lsx_vftintrz_w_s(a)); } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet4f& a, const Packet4f& b) { - return __lsx_vssrlni_hu_w(__lsx_vftint_wu_s(a), __lsx_vftint_wu_s(b), 0); + return __lsx_vpickev_h(__lsx_vftintrz_wu_s(b), __lsx_vftintrz_wu_s(a)); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet4f& a, const Packet4f& b, const Packet4f& c, const Packet4f& d) { - Packet8s tmp1 = __lsx_vssrlni_h_w(__lsx_vftint_w_s(a), __lsx_vftint_w_s(b), 0); - Packet8s tmp2 = __lsx_vssrlni_h_w(__lsx_vftint_w_s(c), __lsx_vftint_w_s(d), 0); - return __lsx_vssrlni_b_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8s tmp1 = __lsx_vpickev_h(__lsx_vftintrz_w_s(b), __lsx_vftintrz_w_s(a)); + Packet8s tmp2 = __lsx_vpickev_h(__lsx_vftintrz_w_s(d), __lsx_vftintrz_w_s(c)); + return __lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4f& a, const Packet4f& b, const Packet4f& c, const Packet4f& d) { - Packet8us tmp1 = __lsx_vssrlni_hu_w(__lsx_vftint_wu_s(a), __lsx_vftint_wu_s(b), 0); - Packet8us tmp2 = __lsx_vssrlni_hu_w(__lsx_vftint_wu_s(c), __lsx_vftint_wu_s(d), 0); - return __lsx_vssrlni_bu_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8us tmp1 = __lsx_vpickev_h(__lsx_vftintrz_wu_s(b), __lsx_vftintrz_wu_s(a)); + Packet8us tmp2 = __lsx_vpickev_h(__lsx_vftintrz_wu_s(d), __lsx_vftintrz_wu_s(c)); + return __lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> @@ -230,11 +416,11 @@ EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8s& a) { } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet8s& a, const Packet8s& b) { - return __lsx_vssrlni_b_h((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_b((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8s& a, const Packet8s& b) { - return (Packet16uc)__lsx_vssrlni_b_h((__m128i)a, (__m128i)b, 0); + return (Packet16uc)__lsx_vpickev_b((__m128i)b, (__m128i)a); } template <> @@ -262,11 +448,11 @@ EIGEN_STRONG_INLINE Packet4i pcast(const Packet8us& a) { } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8us& a, const Packet8us& b) { - return __lsx_vssrlni_bu_h((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_b((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet8us& a, const Packet8us& b) { - return (Packet16c)__lsx_vssrlni_bu_h((__m128i)a, (__m128i)b, 0); + return (Packet16c)__lsx_vpickev_b((__m128i)b, (__m128i)a); } template <> @@ -283,25 +469,25 @@ EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4i& a) { } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet4i& a, const Packet4i& b) { - return __lsx_vssrlni_h_w((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_h((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet4i& a, const Packet4i& b) { - return (Packet8us)__lsx_vssrlni_h_w((__m128i)a, (__m128i)b, 0); + return (Packet8us)__lsx_vpickev_h((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet4i& a, const Packet4i& b, const Packet4i& c, const Packet4i& d) { - Packet8s tmp1 = __lsx_vssrlni_h_w((__m128i)a, (__m128i)b, 0); - Packet8s tmp2 = __lsx_vssrlni_h_w((__m128i)c, (__m128i)d, 0); - return __lsx_vssrlni_b_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8s tmp1 = __lsx_vpickev_h((__m128i)b, (__m128i)a); + Packet8s tmp2 = __lsx_vpickev_h((__m128i)d, (__m128i)c); + return __lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4i& a, const Packet4i& b, const Packet4i& c, const Packet4i& d) { - Packet8s tmp1 = __lsx_vssrlni_h_w((__m128i)a, (__m128i)b, 0); - Packet8s tmp2 = __lsx_vssrlni_h_w((__m128i)c, (__m128i)d, 0); - return (Packet16uc)__lsx_vssrlni_b_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8s tmp1 = __lsx_vpickev_h((__m128i)b, (__m128i)a); + Packet8s tmp2 = __lsx_vpickev_h((__m128i)d, (__m128i)c); + return (Packet16uc)__lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> @@ -318,52 +504,52 @@ EIGEN_STRONG_INLINE Packet2l pcast(const Packet4ui& a) { } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet4ui& a, const Packet4ui& b) { - return __lsx_vssrlni_hu_w((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_h((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet4ui& a, const Packet4ui& b) { - return (Packet8s)__lsx_vssrlni_hu_w((__m128i)a, (__m128i)b, 0); + return (Packet8s)__lsx_vpickev_h((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c, const Packet4ui& d) { - Packet8us tmp1 = __lsx_vssrlni_hu_w((__m128i)a, (__m128i)b, 0); - Packet8us tmp2 = __lsx_vssrlni_hu_w((__m128i)c, (__m128i)d, 0); - return __lsx_vssrlni_bu_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8us tmp1 = __lsx_vpickev_h((__m128i)b, (__m128i)a); + Packet8us tmp2 = __lsx_vpickev_h((__m128i)d, (__m128i)c); + return __lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c, const Packet4ui& d) { - Packet8us tmp1 = __lsx_vssrlni_hu_w((__m128i)a, (__m128i)b, 0); - Packet8us tmp2 = __lsx_vssrlni_hu_w((__m128i)c, (__m128i)d, 0); - return (Packet16c)__lsx_vssrlni_bu_h((__m128i)tmp1, (__m128i)tmp2, 0); + Packet8us tmp1 = __lsx_vpickev_h((__m128i)b, (__m128i)a); + Packet8us tmp2 = __lsx_vpickev_h((__m128i)d, (__m128i)c); + return (Packet16c)__lsx_vpickev_b((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2l& a, const Packet2l& b) { - return __lsx_vffint_s_w(__lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0)); + return __lsx_vfcvt_s_d(__lsx_vffint_d_l(b), __lsx_vffint_d_l(a)); } template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2l& a, const Packet2l& b) { - return __lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_w((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2l& a, const Packet2l& b) { - return (Packet4ui)__lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0); + return (Packet4ui)__lsx_vpickev_w((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, const Packet2l& d) { - Packet4i tmp1 = __lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0); - Packet4i tmp2 = __lsx_vssrlni_w_d((__m128i)c, (__m128i)d, 0); - return __lsx_vssrlni_h_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4i tmp1 = __lsx_vpickev_w((__m128i)b, (__m128i)a); + Packet4i tmp2 = __lsx_vpickev_w((__m128i)d, (__m128i)c); + return __lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, const Packet2l& d) { - Packet4i tmp1 = __lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0); - Packet4i tmp2 = __lsx_vssrlni_w_d((__m128i)c, (__m128i)d, 0); - return (Packet8us)__lsx_vssrlni_h_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4i tmp1 = __lsx_vpickev_w((__m128i)b, (__m128i)a); + Packet4i tmp2 = __lsx_vpickev_w((__m128i)d, (__m128i)c); + return (Packet8us)__lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, @@ -371,7 +557,7 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2l& a, cons const Packet2l& g, const Packet2l& h) { const Packet8s abcd = pcast(a, b, c, d); const Packet8s efgh = pcast(e, f, g, h); - return __lsx_vssrlni_b_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, @@ -379,34 +565,34 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2l& a, co const Packet2l& g, const Packet2l& h) { const Packet8us abcd = pcast(a, b, c, d); const Packet8us efgh = pcast(e, f, g, h); - return __lsx_vssrlni_bu_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2ul& a, const Packet2ul& b) { - return __lsx_vffint_s_wu(__lsx_vssrlni_w_d((__m128i)a, (__m128i)b, 0)); + return __lsx_vfcvt_s_d(__lsx_vffint_d_lu(b), __lsx_vffint_d_lu(a)); } template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2ul& a, const Packet2ul& b) { - return __lsx_vssrlni_wu_d((__m128i)a, (__m128i)b, 0); + return __lsx_vpickev_w((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2ul& a, const Packet2ul& b) { - return (Packet4i)__lsx_vssrlni_wu_d((__m128i)a, (__m128i)b, 0); + return (Packet4i)__lsx_vpickev_w((__m128i)b, (__m128i)a); } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, const Packet2ul& d) { - Packet4ui tmp1 = __lsx_vssrlni_wu_d((__m128i)a, (__m128i)b, 0); - Packet4ui tmp2 = __lsx_vssrlni_wu_d((__m128i)c, (__m128i)d, 0); - return __lsx_vssrlni_hu_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4ui tmp1 = __lsx_vpickev_w((__m128i)b, (__m128i)a); + Packet4ui tmp2 = __lsx_vpickev_w((__m128i)d, (__m128i)c); + return __lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, const Packet2ul& d) { - Packet4ui tmp1 = __lsx_vssrlni_wu_d((__m128i)a, (__m128i)b, 0); - Packet4ui tmp2 = __lsx_vssrlni_wu_d((__m128i)c, (__m128i)d, 0); - return (Packet8s)__lsx_vssrlni_hu_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4ui tmp1 = __lsx_vpickev_w((__m128i)b, (__m128i)a); + Packet4ui tmp2 = __lsx_vpickev_w((__m128i)d, (__m128i)c); + return (Packet8s)__lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, @@ -414,7 +600,7 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2ul& a, const Packet2ul& g, const Packet2ul& h) { const Packet8s abcd = pcast(a, b, c, d); const Packet8s efgh = pcast(e, f, g, h); - return __lsx_vssrlni_b_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, @@ -422,7 +608,7 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2ul& a, co const Packet2ul& g, const Packet2ul& h) { const Packet8us abcd = pcast(a, b, c, d); const Packet8us efgh = pcast(e, f, g, h); - return __lsx_vssrlni_bu_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> @@ -431,33 +617,33 @@ EIGEN_STRONG_INLINE Packet4f pcast(const Packet2d& a, const } template <> EIGEN_STRONG_INLINE Packet2l pcast(const Packet2d& a) { - return __lsx_vftint_l_d(a); + return __lsx_vftintrz_l_d(a); } template <> EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2d& a) { - return __lsx_vftint_lu_d(a); + return __lsx_vftintrz_lu_d(a); } template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2d& a, const Packet2d& b) { - return __lsx_vssrlni_w_d(__lsx_vftint_l_d(a), __lsx_vftint_l_d(b), 0); + return __lsx_vpickev_w(__lsx_vftintrz_l_d(b), __lsx_vftintrz_l_d(a)); } template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2d& a, const Packet2d& b) { - return __lsx_vssrlni_wu_d(__lsx_vftint_lu_d(a), __lsx_vftint_lu_d(b), 0); + return __lsx_vpickev_w(__lsx_vftintrz_lu_d(b), __lsx_vftintrz_lu_d(a)); } template <> EIGEN_STRONG_INLINE Packet8s pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, const Packet2d& d) { - Packet4i tmp1 = __lsx_vssrlni_w_d(__lsx_vftint_l_d(a), __lsx_vftint_l_d(b), 0); - Packet4i tmp2 = __lsx_vssrlni_w_d(__lsx_vftint_l_d(c), __lsx_vftint_l_d(d), 0); - return __lsx_vssrlni_h_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4i tmp1 = __lsx_vpickev_w(__lsx_vftintrz_l_d(b), __lsx_vftintrz_l_d(a)); + Packet4i tmp2 = __lsx_vpickev_w(__lsx_vftintrz_l_d(d), __lsx_vftintrz_l_d(c)); + return __lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, const Packet2d& d) { - Packet4ui tmp1 = __lsx_vssrlni_wu_d(__lsx_vftint_lu_d(a), __lsx_vftint_lu_d(b), 0); - Packet4ui tmp2 = __lsx_vssrlni_wu_d(__lsx_vftint_lu_d(c), __lsx_vftint_lu_d(d), 0); - return __lsx_vssrlni_hu_w((__m128i)tmp1, (__m128i)tmp2, 0); + Packet4ui tmp1 = __lsx_vpickev_w(__lsx_vftintrz_lu_d(b), __lsx_vftintrz_lu_d(a)); + Packet4ui tmp2 = __lsx_vpickev_w(__lsx_vftintrz_lu_d(d), __lsx_vftintrz_lu_d(c)); + return __lsx_vpickev_h((__m128i)tmp2, (__m128i)tmp1); } template <> EIGEN_STRONG_INLINE Packet16c pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, @@ -465,7 +651,7 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2d& a, cons const Packet2d& g, const Packet2d& h) { const Packet8s abcd = pcast(a, b, c, d); const Packet8s efgh = pcast(e, f, g, h); - return __lsx_vssrlni_b_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, @@ -473,7 +659,7 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2d& a, co const Packet2d& g, const Packet2d& h) { const Packet8us abcd = pcast(a, b, c, d); const Packet8us efgh = pcast(e, f, g, h); - return __lsx_vssrlni_bu_h((__m128i)abcd, (__m128i)efgh, 0); + return __lsx_vpickev_b((__m128i)efgh, (__m128i)abcd); } template <> diff --git a/Eigen/src/Core/arch/MSA/Complex.h b/Eigen/src/Core/arch/MSA/Complex.h index fbba642eb..a0771a198 100644 --- a/Eigen/src/Core/arch/MSA/Complex.h +++ b/Eigen/src/Core/arch/MSA/Complex.h @@ -100,6 +100,9 @@ struct packet_traits > : default_packet_traits { HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, + HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -393,6 +396,8 @@ struct packet_traits > : default_packet_traits { HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, + HasLog = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -606,6 +611,9 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { kernel.packet[1].v = v2; } +EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS(Packet2cf) +EIGEN_INSTANTIATE_COMPLEX_MATH_FUNCS_NO_EXP(Packet1cd) + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/MSA/PacketMath.h b/Eigen/src/Core/arch/MSA/PacketMath.h index 558893ee6..40e4ccdc2 100644 --- a/Eigen/src/Core/arch/MSA/PacketMath.h +++ b/Eigen/src/Core/arch/MSA/PacketMath.h @@ -789,6 +789,39 @@ EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a) { return v; } +template <> +EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { + // frint.w uses the current rounding mode (default: round to nearest, ties to even). + Packet4f v = a; + asm volatile("frint.w %w[v], %w[v]\n" : [v] "+f"(v)); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4f ptrunc(const Packet4f& a) { + Packet4f v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 2\n" // 1 = round toward zero. + "ctcmsa $1, %[new_mode]\n" + "frint.w %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { + return (Packet4f)__builtin_msa_fcult_w(a, b); +} + //---------- double ---------- typedef v2f64 Packet2d; @@ -1192,6 +1225,39 @@ EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) { return v; } +template <> +EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) { + // frint.d uses the current rounding mode (default: round to nearest, ties to even). + Packet2d v = a; + asm volatile("frint.d %w[v], %w[v]\n" : [v] "+f"(v)); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet2d ptrunc(const Packet2d& a) { + Packet2d v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 2\n" // 1 = round toward zero. + "ctcmsa $1, %[new_mode]\n" + "frint.d %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { + return (Packet2d)__builtin_msa_fcult_d(a, b); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h index 310a40cbe..611ba793b 100644 --- a/Eigen/src/Core/arch/SVE/PacketMath.h +++ b/Eigen/src/Core/arch/SVE/PacketMath.h @@ -497,6 +497,22 @@ template <> EIGEN_STRONG_INLINE PacketXf pfloor(const PacketXf& a) { return svrintm_f32_x(svptrue_b32(), a); } +template <> +EIGEN_STRONG_INLINE PacketXf pceil(const PacketXf& a) { + return svrintp_f32_x(svptrue_b32(), a); +} +template <> +EIGEN_STRONG_INLINE PacketXf print(const PacketXf& a) { + return svrintn_f32_x(svptrue_b32(), a); +} +template <> +EIGEN_STRONG_INLINE PacketXf ptrunc(const PacketXf& a) { + return svrintz_f32_x(svptrue_b32(), a); +} +template <> +EIGEN_STRONG_INLINE PacketXf pround(const PacketXf& a) { + return svrinta_f32_x(svptrue_b32(), a); +} template <> EIGEN_STRONG_INLINE PacketXf ptrue(const PacketXf& /*a*/) { diff --git a/Eigen/src/Core/arch/ZVector/PacketMath.h b/Eigen/src/Core/arch/ZVector/PacketMath.h index 4ccda873d..e16660934 100644 --- a/Eigen/src/Core/arch/ZVector/PacketMath.h +++ b/Eigen/src/Core/arch/ZVector/PacketMath.h @@ -549,6 +549,18 @@ template <> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) { return vec_floor(a); } +template <> +EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) { + return __builtin_s390_vfidb(a, 4, 5); +} +template <> +EIGEN_STRONG_INLINE Packet2d ptrunc(const Packet2d& a) { + return __builtin_s390_vfidb(a, 4, 4); +} +template <> +EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { + return pnot(pcmp_le(b, a)); +} template <> EIGEN_STRONG_INLINE Packet4i ploadu(const int* from) { @@ -939,6 +951,20 @@ EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) { res.v4f[1] = vec_floor(a.v4f[1]); return res; } +template <> +EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { + Packet4f res; + res.v4f[0] = print(a.v4f[0]); + res.v4f[1] = print(a.v4f[1]); + return res; +} +template <> +EIGEN_STRONG_INLINE Packet4f ptrunc(const Packet4f& a) { + Packet4f res; + res.v4f[0] = ptrunc(a.v4f[0]); + res.v4f[1] = ptrunc(a.v4f[1]); + return res; +} template <> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) { @@ -1063,6 +1089,13 @@ Packet4f EIGEN_STRONG_INLINE pcmp_eq(const Packet4f& a, const Packet4f res.v4f[1] = pcmp_eq(a.v4f[1], b.v4f[1]); return res; } +template <> +Packet4f EIGEN_STRONG_INLINE pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { + Packet4f res; + res.v4f[0] = pcmp_lt_or_nan(a.v4f[0], b.v4f[0]); + res.v4f[1] = pcmp_lt_or_nan(a.v4f[1], b.v4f[1]); + return res; +} #else template <> @@ -1177,6 +1210,18 @@ EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) { return vec_floor(a); } template <> +EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { + return __builtin_s390_vfisb(a, 4, 5); +} +template <> +EIGEN_STRONG_INLINE Packet4f ptrunc(const Packet4f& a) { + return __builtin_s390_vfisb(a, 4, 4); +} +template <> +EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { + return pnot(pcmp_le(b, a)); +} +template <> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }