From 7aea350ba1177db2ec7ae3bbb615450b0f74279f Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 9 Jan 2026 12:16:28 -0500 Subject: [PATCH] Fix more packetmath issues for RVV libeigen/eigen!2105 --- Eigen/src/Core/arch/RVV10/PacketMath.h | 19 +++++++--- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 40 ++++++++++++++-------- Eigen/src/Core/arch/RVV10/PacketMathFP16.h | 38 ++++++++++++-------- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMath.h b/Eigen/src/Core/arch/RVV10/PacketMath.h index 909fd88ce..679d5c1fc 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMath.h +++ b/Eigen/src/Core/arch/RVV10/PacketMath.h @@ -1665,11 +1665,20 @@ EIGEN_STRONG_INLINE Packet1Xd plset(const double& a) { template <> EIGEN_STRONG_INLINE void pbroadcast4(const double* a, Packet1Xd& a0, Packet1Xd& a1, Packet1Xd& a2, Packet1Xd& a3) { - vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4); - a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits::size); - a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits::size); - a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits::size); - a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits::size); + if (EIGEN_RISCV64_RVV_VL >= 256) { + vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4); + a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits::size); + a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits::size); + a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits::size); + a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits::size); + } else { + vfloat64m1_t aa0 = __riscv_vle64_v_f64m1(a + 0, 2); + vfloat64m1_t aa1 = __riscv_vle64_v_f64m1(a + 2, 2); + a0 = __riscv_vrgather_vx_f64m1(aa0, 0, unpacket_traits::size); + a1 = __riscv_vrgather_vx_f64m1(aa0, 1, unpacket_traits::size); + a2 = __riscv_vrgather_vx_f64m1(aa1, 0, unpacket_traits::size); + a3 = __riscv_vrgather_vx_f64m1(aa1, 1, unpacket_traits::size); + } } template <> diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index ec0e42b2a..2522efd99 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -164,11 +164,6 @@ EIGEN_STRONG_INLINE Packet1Xbf pabs(const Packet1Xbf& a) { unpacket_traits::size)); } -template <> -EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) { - return F32ToBf16(pabsdiff(Bf16ToF32(a), Bf16ToF32(b))); -} - template <> EIGEN_STRONG_INLINE Packet1Xbf pset1(const bfloat16& from) { return __riscv_vreinterpret_bf16m1( @@ -197,12 +192,23 @@ EIGEN_STRONG_INLINE void pbroadcast4(const bfloat16* a, Packet1Xbf& template <> EIGEN_STRONG_INLINE Packet1Xbf padd(const Packet1Xbf& a, const Packet1Xbf& b) { - return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); + // b + (1 * a) + return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(b), + numext::bit_cast<__bf16>(static_cast(0x3f80u)), a, + unpacket_traits::size)); } template <> EIGEN_STRONG_INLINE Packet1Xbf psub(const Packet1Xbf& a, const Packet1Xbf& b) { - return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); + // a + (-1 * b) + return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(a), + numext::bit_cast<__bf16>(static_cast(0xbf80u)), b, + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) { + return pabs(psub(a, b)); } template <> @@ -493,11 +499,6 @@ EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) { unpacket_traits::size)); } -template <> -EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) { - return F32ToBf16(pabsdiff(Bf16ToF32(a), Bf16ToF32(b))); -} - template <> EIGEN_STRONG_INLINE Packet2Xbf pset1(const bfloat16& from) { return __riscv_vreinterpret_bf16m2( @@ -526,12 +527,23 @@ EIGEN_STRONG_INLINE void pbroadcast4(const bfloat16* a, Packet2Xbf& template <> EIGEN_STRONG_INLINE Packet2Xbf padd(const Packet2Xbf& a, const Packet2Xbf& b) { - return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); + // b + (1 * a) + return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(b), + numext::bit_cast<__bf16>(static_cast(0x3f80u)), a, + unpacket_traits::size)); } template <> EIGEN_STRONG_INLINE Packet2Xbf psub(const Packet2Xbf& a, const Packet2Xbf& b) { - return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); + // a + (-1 * b) + return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(a), + numext::bit_cast<__bf16>(static_cast(0xbf80u)), b, + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) { + return pabs(psub(a, b)); } template <> diff --git a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h index 517b703e3..73118a31c 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h @@ -480,17 +480,21 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet1Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet1Xh& a) { const Eigen::half max = (std::numeric_limits::max)(); - return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), - unpacket_traits::size))); + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); + return (std::min)(static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1( + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size), + unpacket_traits::size))), + max); } template <> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet1Xh& a) { - const Eigen::half min = (std::numeric_limits::min)(); - return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size), - unpacket_traits::size))); + const Eigen::half min = -(std::numeric_limits::max)(); + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); + return (std::max)(static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1( + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size), + unpacket_traits::size))), + min); } template @@ -837,17 +841,23 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet2Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet2Xh& a) { const Eigen::half max = (std::numeric_limits::max)(); - return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size / 2), - unpacket_traits::size))); + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); + return (std::min)( + static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1( + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size / 2), + unpacket_traits::size))), + max); } template <> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet2Xh& a) { - const Eigen::half min = (std::numeric_limits::min)(); - return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size / 2), - unpacket_traits::size))); + const Eigen::half min = -(std::numeric_limits::max)(); + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); + return (std::max)( + static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1( + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size / 2), + unpacket_traits::size))), + min); } template