Add missing functions for Packet8bf in Altivec architecture.

Including new tests for bfloat16 Packets.
Fix prsqrt on GenericPacketMath.
This commit is contained in:
Pedro Caldeira
2020-08-21 17:52:34 -05:00
parent 85428a3440
commit 35d149e34c
4 changed files with 141 additions and 5 deletions

View File

@@ -612,7 +612,8 @@ Packet psqrt(const Packet& a) { EIGEN_USING_STD_MATH(sqrt); return sqrt(a); }
/** \internal \returns the reciprocal square-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet prsqrt(const Packet& a) {
return pdiv(pset1<Packet>(1), psqrt(a));
typedef typename internal::unpacket_traits<Packet>::type Scalar;
return pdiv(pset1<Packet>(Scalar(1)), psqrt(a));
}
/** \internal \returns the rounded value of \a a (coeff-wise) */

View File

@@ -646,6 +646,11 @@ template<> EIGEN_DEVICE_FUNC inline Packet8us pgather<unsigned short int, Packet
return pgather_size8<Packet8us>(from, stride);
}
template<> EIGEN_DEVICE_FUNC inline Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
{
return pgather_size8<Packet8bf>(from, stride);
}
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_size16(const __UNPACK_TYPE__(Packet)* from, Index stride)
{
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
@@ -724,6 +729,11 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<unsigned short int, Packet8us>
pscatter_size8<Packet8us>(to, from, stride);
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
{
pscatter_size8<Packet8bf>(to, from, stride);
}
template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size16(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
{
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
@@ -1285,7 +1295,30 @@ template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, con
template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(prsqrt<Packet4f>, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcos<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pcos_float, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf plog<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(plog_float, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pfloor<Packet4f>, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pceil<Packet4f>, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pround<Packet4f>, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
Packet4f a_even = Bf16ToF32Even(a);
Packet4f a_odd = Bf16ToF32Odd(a);
@@ -1325,6 +1358,12 @@ template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16*
return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
}
template<> EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
bfloat16 countdown[8] = { bfloat16(0), bfloat16(1), bfloat16(2), bfloat16(3),
bfloat16(4), bfloat16(5), bfloat16(6), bfloat16(7) };
return padd<Packet8bf>(pset1<Packet8bf>(a), pload<Packet8bf>(countdown));
}
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
return pfrexp_float(a,exponent);
}