mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
@@ -1665,11 +1665,20 @@ EIGEN_STRONG_INLINE Packet1Xd plset<Packet1Xd>(const double& a) {
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pbroadcast4<Packet1Xd>(const double* a, Packet1Xd& a0, Packet1Xd& a1, Packet1Xd& a2,
|
||||
Packet1Xd& a3) {
|
||||
vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4);
|
||||
a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits<Packet1Xd>::size);
|
||||
a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits<Packet1Xd>::size);
|
||||
a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits<Packet1Xd>::size);
|
||||
a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits<Packet1Xd>::size);
|
||||
if (EIGEN_RISCV64_RVV_VL >= 256) {
|
||||
vfloat64m1_t aa = __riscv_vle64_v_f64m1(a, 4);
|
||||
a0 = __riscv_vrgather_vx_f64m1(aa, 0, unpacket_traits<Packet1Xd>::size);
|
||||
a1 = __riscv_vrgather_vx_f64m1(aa, 1, unpacket_traits<Packet1Xd>::size);
|
||||
a2 = __riscv_vrgather_vx_f64m1(aa, 2, unpacket_traits<Packet1Xd>::size);
|
||||
a3 = __riscv_vrgather_vx_f64m1(aa, 3, unpacket_traits<Packet1Xd>::size);
|
||||
} else {
|
||||
vfloat64m1_t aa0 = __riscv_vle64_v_f64m1(a + 0, 2);
|
||||
vfloat64m1_t aa1 = __riscv_vle64_v_f64m1(a + 2, 2);
|
||||
a0 = __riscv_vrgather_vx_f64m1(aa0, 0, unpacket_traits<Packet1Xd>::size);
|
||||
a1 = __riscv_vrgather_vx_f64m1(aa0, 1, unpacket_traits<Packet1Xd>::size);
|
||||
a2 = __riscv_vrgather_vx_f64m1(aa1, 0, unpacket_traits<Packet1Xd>::size);
|
||||
a3 = __riscv_vrgather_vx_f64m1(aa1, 1, unpacket_traits<Packet1Xd>::size);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -164,11 +164,6 @@ EIGEN_STRONG_INLINE Packet1Xbf pabs(const Packet1Xbf& a) {
|
||||
unpacket_traits<Packet1Xs>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) {
|
||||
return F32ToBf16(pabsdiff<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1Xbf pset1<Packet1Xbf>(const bfloat16& from) {
|
||||
return __riscv_vreinterpret_bf16m1(
|
||||
@@ -197,12 +192,23 @@ EIGEN_STRONG_INLINE void pbroadcast4<Packet1Xbf>(const bfloat16* a, Packet1Xbf&
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1Xbf padd<Packet1Xbf>(const Packet1Xbf& a, const Packet1Xbf& b) {
|
||||
return F32ToBf16(padd<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
// b + (1 * a)
|
||||
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(b),
|
||||
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0x3f80u)), a,
|
||||
unpacket_traits<Packet1Xbf>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1Xbf psub<Packet1Xbf>(const Packet1Xbf& a, const Packet1Xbf& b) {
|
||||
return F32ToBf16(psub<Packet2Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
// a + (-1 * b)
|
||||
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(a),
|
||||
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0xbf80u)), b,
|
||||
unpacket_traits<Packet1Xbf>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1Xbf pabsdiff(const Packet1Xbf& a, const Packet1Xbf& b) {
|
||||
return pabs<Packet1Xbf>(psub<Packet1Xbf>(a, b));
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -493,11 +499,6 @@ EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) {
|
||||
unpacket_traits<Packet2Xs>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) {
|
||||
return F32ToBf16(pabsdiff<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2Xbf pset1<Packet2Xbf>(const bfloat16& from) {
|
||||
return __riscv_vreinterpret_bf16m2(
|
||||
@@ -526,12 +527,23 @@ EIGEN_STRONG_INLINE void pbroadcast4<Packet2Xbf>(const bfloat16* a, Packet2Xbf&
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2Xbf padd<Packet2Xbf>(const Packet2Xbf& a, const Packet2Xbf& b) {
|
||||
return F32ToBf16(padd<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
// b + (1 * a)
|
||||
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(b),
|
||||
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0x3f80u)), a,
|
||||
unpacket_traits<Packet2Xbf>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2Xbf psub<Packet2Xbf>(const Packet2Xbf& a, const Packet2Xbf& b) {
|
||||
return F32ToBf16(psub<Packet4Xf>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
// a + (-1 * b)
|
||||
return F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(a),
|
||||
numext::bit_cast<__bf16>(static_cast<numext::int16_t>(0xbf80u)), b,
|
||||
unpacket_traits<Packet2Xbf>::size));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2Xbf pabsdiff(const Packet2Xbf& a, const Packet2Xbf& b) {
|
||||
return pabs<Packet2Xbf>(psub<Packet2Xbf>(a, b));
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -480,17 +480,21 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet1Xh>(const Packet1Xh& a) {
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet1Xh>(const Packet1Xh& a) {
|
||||
const Eigen::half max = (std::numeric_limits<Eigen::half>::max)();
|
||||
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits<Packet1Xh>::size),
|
||||
unpacket_traits<Packet1Xh>::size)));
|
||||
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
|
||||
return (std::min)(static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet1Xh>::size),
|
||||
unpacket_traits<Packet1Xh>::size))),
|
||||
max);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet1Xh>(const Packet1Xh& a) {
|
||||
const Eigen::half min = (std::numeric_limits<Eigen::half>::min)();
|
||||
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits<Packet1Xh>::size),
|
||||
unpacket_traits<Packet1Xh>::size)));
|
||||
const Eigen::half min = -(std::numeric_limits<Eigen::half>::max)();
|
||||
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
|
||||
return (std::max)(static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet1Xh>::size),
|
||||
unpacket_traits<Packet1Xh>::size))),
|
||||
min);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
@@ -837,17 +841,23 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet2Xh>(const Packet2Xh& a) {
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet2Xh>(const Packet2Xh& a) {
|
||||
const Eigen::half max = (std::numeric_limits<Eigen::half>::max)();
|
||||
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits<Packet2Xh>::size / 2),
|
||||
unpacket_traits<Packet2Xh>::size)));
|
||||
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
|
||||
return (std::min)(
|
||||
static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet2Xh>::size / 2),
|
||||
unpacket_traits<Packet2Xh>::size))),
|
||||
max);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet2Xh>(const Packet2Xh& a) {
|
||||
const Eigen::half min = (std::numeric_limits<Eigen::half>::min)();
|
||||
return static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits<Packet2Xh>::size / 2),
|
||||
unpacket_traits<Packet2Xh>::size)));
|
||||
const Eigen::half min = -(std::numeric_limits<Eigen::half>::max)();
|
||||
const Eigen::half nan = (std::numeric_limits<Eigen::half>::quiet_NaN)();
|
||||
return (std::max)(
|
||||
static_cast<Eigen::half>(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1(
|
||||
a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits<Packet2Xh>::size / 2),
|
||||
unpacket_traits<Packet2Xh>::size))),
|
||||
min);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
|
||||
Reference in New Issue
Block a user