mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Preliminary HIP bfloat16 GPU support.
This commit is contained in:
committed by
Antonio Sánchez
parent
40bbe8a4d0
commit
48e40b22bf
@@ -18,6 +18,18 @@ limitations under the License.
|
||||
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
#if defined(EIGEN_HAS_HIP_BF16)
|
||||
// When compiling with GPU support, the "hip_bfloat16" base class as well as
|
||||
// some other routines are defined in the GPU compiler header files
|
||||
// (hip_bfloat16.h), and they are not tagged constexpr
|
||||
// As a consequence, we get compile failures when compiling Eigen with
|
||||
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
|
||||
// Eigen with GPU support
|
||||
#pragma push_macro("EIGEN_CONSTEXPR")
|
||||
#undef EIGEN_CONSTEXPR
|
||||
#define EIGEN_CONSTEXPR
|
||||
#endif
|
||||
|
||||
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
|
||||
template <> \
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
|
||||
@@ -25,19 +37,46 @@ limitations under the License.
|
||||
return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
|
||||
}
|
||||
|
||||
// Only use HIP GPU bf16 in kernels
|
||||
#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
#define EIGEN_USE_HIP_BF16
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
struct bfloat16;
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 numext::bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src);
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t numext::bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src);
|
||||
|
||||
namespace bfloat16_impl {
|
||||
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
|
||||
struct __bfloat16_raw : public hip_bfloat16 {
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : hip_bfloat16(raw) {}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
// Make our own __bfloat16_raw definition.
|
||||
struct __bfloat16_raw {
|
||||
#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
|
||||
#else
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
|
||||
#endif
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
|
||||
unsigned short value;
|
||||
};
|
||||
|
||||
#endif // defined(EIGEN_USE_HIP_BF16)
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
|
||||
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
|
||||
@@ -150,7 +189,7 @@ namespace bfloat16_impl {
|
||||
// We need to provide emulated *host-side* BF16 operators for clang.
|
||||
#pragma push_macro("EIGEN_DEVICE_FUNC")
|
||||
#undef EIGEN_DEVICE_FUNC
|
||||
#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
|
||||
#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
|
||||
#define EIGEN_DEVICE_FUNC __host__
|
||||
#else // both host and device need emulated ops.
|
||||
#define EIGEN_DEVICE_FUNC __host__ __device__
|
||||
@@ -179,9 +218,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, co
|
||||
return bfloat16(float(a) / float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
|
||||
bfloat16 result;
|
||||
result.value = a.value ^ 0x8000;
|
||||
return result;
|
||||
numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
|
||||
return numext::bit_cast<bfloat16>(x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
|
||||
a = bfloat16(float(a) + float(b));
|
||||
@@ -248,33 +286,47 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, In
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
|
||||
#else
|
||||
__bfloat16_raw output;
|
||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
||||
if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
||||
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
|
||||
return output;
|
||||
}
|
||||
output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
|
||||
return output;
|
||||
#endif
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
__bfloat16_raw bf;
|
||||
bf.data = value;
|
||||
return bf;
|
||||
#else
|
||||
return __bfloat16_raw(value);
|
||||
#endif
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return bf.data;
|
||||
#else
|
||||
return bf.value;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float_to_bfloat16_rtne template specialization that does not make any
|
||||
// assumption about the value of its function argument (ff).
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
|
||||
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
||||
// Nothing to do here
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
|
||||
#else
|
||||
__bfloat16_raw output;
|
||||
|
||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
|
||||
if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
|
||||
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
|
||||
// this makes sure after truncation we don't end up with an inf.
|
||||
//
|
||||
@@ -443,8 +495,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<fals
|
||||
// type to bfloat16.
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
|
||||
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
||||
// Nothing to do here
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
|
||||
#else
|
||||
numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
|
||||
__bfloat16_raw output;
|
||||
@@ -459,29 +511,41 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return static_cast<float>(h);
|
||||
#else
|
||||
return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- standard functions ---
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
|
||||
EIGEN_USING_STD(isinf);
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return (isinf)(a); // Uses HIP hip_bfloat16 isinf operator
|
||||
#else
|
||||
return (isinf)(float(a));
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
|
||||
EIGEN_USING_STD(isnan);
|
||||
#if defined(EIGEN_USE_HIP_BF16)
|
||||
return (isnan)(a); // Uses HIP hip_bfloat16 isnan operator
|
||||
#else
|
||||
return (isnan)(float(a));
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
|
||||
return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
|
||||
bfloat16 result;
|
||||
result.value = a.value & 0x7FFF;
|
||||
return result;
|
||||
numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
|
||||
return numext::bit_cast<bfloat16>(x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
|
||||
return bfloat16(::expf(float(a)));
|
||||
return bfloat16(::expf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
|
||||
return bfloat16(numext::expm1(float(a)));
|
||||
@@ -499,7 +563,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
|
||||
return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
|
||||
return bfloat16(::sqrtf(float(a)));
|
||||
return bfloat16(::sqrtf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(::powf(float(a), float(b)));
|
||||
@@ -563,6 +627,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bf
|
||||
const float f2 = static_cast<float>(b);
|
||||
return f2 < f1 ? b : a;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
@@ -574,6 +639,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfl
|
||||
const float f2 = static_cast<float>(b);
|
||||
return bfloat16(::fminf(f1, f2));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
@@ -623,7 +689,6 @@ template<> struct NumTraits<Eigen::bfloat16>
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
|
||||
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
|
||||
@@ -641,6 +706,11 @@ template<> struct NumTraits<Eigen::bfloat16>
|
||||
|
||||
} // namespace Eigen
|
||||
|
||||
|
||||
#if defined(EIGEN_HAS_HIP_BF16)
|
||||
#pragma pop_macro("EIGEN_CONSTEXPR")
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
namespace numext {
|
||||
|
||||
@@ -664,7 +734,7 @@ bool (isfinite)(const Eigen::bfloat16& h) {
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
|
||||
return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
|
||||
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -686,5 +756,49 @@ struct hash<Eigen::bfloat16> {
|
||||
} // namespace std
|
||||
#endif
|
||||
|
||||
// Add the missing shfl* intrinsics.
|
||||
// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
|
||||
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
|
||||
//
|
||||
// HIP and CUDA prior to SDK 9.0 define
|
||||
// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
|
||||
// CUDA since 9.0 deprecates those and instead defines
|
||||
// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
|
||||
// with native support for __half and __nv_bfloat16
|
||||
//
|
||||
// Note that the following are __device__ - only functions.
|
||||
#if defined(EIGEN_HIPCC)
|
||||
|
||||
#if defined(EIGEN_HAS_HIP_BF16)
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var, int srcLane, int width=warpSize) {
|
||||
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
|
||||
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var, unsigned int delta, int width=warpSize) {
|
||||
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
|
||||
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var, unsigned int delta, int width=warpSize) {
|
||||
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
|
||||
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var, int laneMask, int width=warpSize) {
|
||||
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
|
||||
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
|
||||
}
|
||||
|
||||
#endif // HIP
|
||||
|
||||
#endif // __shfl*
|
||||
|
||||
#if defined(EIGEN_HIPCC)
|
||||
EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(const Eigen::bfloat16* ptr) {
|
||||
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(__ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
|
||||
}
|
||||
#endif // __ldg
|
||||
|
||||
#endif // EIGEN_BFLOAT16_H
|
||||
|
||||
@@ -468,6 +468,8 @@
|
||||
#include <hip/hip_vector_types.h>
|
||||
#define EIGEN_HAS_HIP_FP16
|
||||
#include <hip/hip_fp16.h>
|
||||
#define EIGEN_HAS_HIP_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user