Fix more packetmath issues for RVV

libeigen/eigen!2105
This commit is contained in:
Chip Kerchner
2026-01-09 12:16:28 -05:00
parent 5d9beb81ab
commit 7aea350ba1
3 changed files with 64 additions and 33 deletions

View File

@@ -1665,11 +1665,20 @@ EIGEN_STRONG_INLINE Packet1Xd plset<Packet1Xd>(const double& a) {
template <>
EIGEN_STRONG_INLINE void pbroadcast4<Packet1Xd>(const double* a, Packet1Xd& a0, Packet1Xd& a1, Packet1Xd& a2,
Packet1Xd& a3) {
vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4);
a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits<Packet1Xd>::size);
a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits<Packet1Xd>::size);
a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits<Packet1Xd>::size);
a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits<Packet1Xd>::size);
if (EIGEN_RISCV64_RVV_VL >= 256) {
vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4);
a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits<Packet1Xd>::size);
a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits<Packet1Xd>::size);
a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits<Packet1Xd>::size);
a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits<Packet1Xd>::size);
} else {
vfloat64m1_t aa0 = __riscv_vle64_v_f64m1(a + 0, 2);
vfloat64m1_t aa1 = __riscv_vle64_v_f64m1(a + 2, 2);
a0 = __riscv_vrgather_vx_f64m1(aa0, 0, unpacket_traits<Packet1Xd>::size);
a1 = __riscv_vrgather_vx_f64m1(aa0, 1, unpacket_traits<Packet1Xd>::size);
a2 = __riscv_vrgather_vx_f64m1(aa1, 0, unpacket_traits<Packet1Xd>::size);
a3 = __riscv_vrgather_vx_f64m1(aa1, 1, unpacket_traits<Packet1Xd>::size);
}
}
template <>

View File

@@ -164,11 +164,6 @@ EIGEN_STRONG_INLINE Packet1Xbf pabs(const Packet1Xbf& a) {
unpacket_traits<Packet1Xs>::size));
}
template <>
EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) {
return F32ToBf16(pabsdiff<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet1Xbf pset1<Packet1Xbf>(const bfloat16& from) {
return __riscv_vreinterpret_bf16m1(
@@ -197,12 +192,23 @@ EIGEN_STRONG_INLINE void pbroadcast4<Packet1Xbf>(const bfloat16* a, Packet1Xbf&
template <>
EIGEN_STRONG_INLINE Packet1Xbf padd<Packet1Xbf>(const Packet1Xbf& a, const Packet1Xbf& b) {
return F32ToBf16(padd<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
// b + (1 * a)
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(b),
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0x3f80u)), a,
unpacket_traits<Packet1Xbf>::size));
}
template <>
EIGEN_STRONG_INLINE Packet1Xbf psub<Packet1Xbf>(const Packet1Xbf& a, const Packet1Xbf& b) {
return F32ToBf16(psub<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
// a + (-1 * b)
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(a),
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0xbf80u)), b,
unpacket_traits<Packet1Xbf>::size));
}
template <>
EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) {
return pabs<Packet1Xbf>(psub<Packet1Xbf>(a, b));
}
template <>
@@ -493,11 +499,6 @@ EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) {
unpacket_traits<Packet2Xs>::size));
}
template <>
EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) {
return F32ToBf16(pabsdiff<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet2Xbf pset1<Packet2Xbf>(const bfloat16& from) {
return __riscv_vreinterpret_bf16m2(
@@ -526,12 +527,23 @@ EIGEN_STRONG_INLINE void pbroadcast4<Packet2Xbf>(const bfloat16* a, Packet2Xbf&
template <>
EIGEN_STRONG_INLINE Packet2Xbf padd<Packet2Xbf>(const Packet2Xbf& a, const Packet2Xbf& b) {
return F32ToBf16(padd<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
// b + (1 * a)
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(b),
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0x3f80u)), a,
unpacket_traits<Packet2Xbf>::size));
}
template <>
EIGEN_STRONG_INLINE Packet2Xbf psub<Packet2Xbf>(const Packet2Xbf& a, const Packet2Xbf& b) {
return F32ToBf16(psub<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
// a + (-1 * b)
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(a),
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0xbf80u)), b,
unpacket_traits<Packet2Xbf>::size));
}
template <>
EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) {
return pabs<Packet2Xbf>(psub<Packet2Xbf>(a, b));
}
template <>

View File

@@ -480,17 +480,21 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet1Xh>(const Packet1Xh& a) {
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet1Xh>(const Packet1Xh& a) {
const Eigen::half max = (std::numeric_limits<Eigen::half>::max)();
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits<Packet1Xh>::size),
unpacket_traits<Packet1Xh>::size)));
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
return (std::min)(static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet1Xh>::size),
unpacket_traits<Packet1Xh>::size))),
max);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet1Xh>(const Packet1Xh& a) {
const Eigen::half min = (std::numeric_limits<Eigen::half>::min)();
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits<Packet1Xh>::size),
unpacket_traits<Packet1Xh>::size)));
const Eigen::half min = -(std::numeric_limits<Eigen::half>::max)();
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
return (std::max)(static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet1Xh>::size),
unpacket_traits<Packet1Xh>::size))),
min);
}
template <int N>
@@ -837,17 +841,23 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet2Xh>(const Packet2Xh& a) {
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet2Xh>(const Packet2Xh& a) {
const Eigen::half max = (std::numeric_limits<Eigen::half>::max)();
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits<Packet2Xh>::size / 2),
unpacket_traits<Packet2Xh>::size)));
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
return (std::min)(
static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet2Xh>::size / 2),
unpacket_traits<Packet2Xh>::size))),
max);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet2Xh>(const Packet2Xh& a) {
const Eigen::half min = (std::numeric_limits<Eigen::half>::min)();
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits<Packet2Xh>::size / 2),
unpacket_traits<Packet2Xh>::size)));
const Eigen::half min = -(std::numeric_limits<Eigen::half>::max)();
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
return (std::max)(
static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1(
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet2Xh>::size / 2),
unpacket_traits<Packet2Xh>::size))),
min);
}
template <int N>