From 4ab32e2de2511746e2108563a43cbbeb1922fbf2 Mon Sep 17 00:00:00 2001 From: Niels Dekker Date: Sat, 11 Jul 2020 12:50:46 +0200 Subject: [PATCH] Allow implicit conversion from bfloat16 to float and double Conversion from `bfloat16` to `float` and `double` is lossless. It seems natural to allow the conversion to be implicit, as the C++ language also support implicit conversion from a smaller to a larger floating point type. Intel's OneDLL bfloat16 implementation also has an implicit `operator float()`: https://github.com/oneapi-src/oneDNN/blob/v1.5/src/common/bfloat16.hpp --- Eigen/src/Core/arch/Default/BFloat16.h | 4 ++-- test/bfloat16_float.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index 561304f80..abf2ac933 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -117,10 +117,10 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const { return static_cast(bfloat16_to_float(*this)); } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { + EIGEN_DEVICE_FUNC operator float() const { return bfloat16_impl::bfloat16_to_float(*this); } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { + EIGEN_DEVICE_FUNC operator double() const { return static_cast(bfloat16_impl::bfloat16_to_float(*this)); } template diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp index eb55f7d45..478aef3a3 100644 --- a/test/bfloat16_float.cpp +++ b/test/bfloat16_float.cpp @@ -53,9 +53,9 @@ void test_conversion() VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity. // Verify round-to-nearest-even behavior. - float val1 = static_cast(bfloat16(__bfloat16_raw(0x3c00))); - float val2 = static_cast(bfloat16(__bfloat16_raw(0x3c01))); - float val3 = static_cast(bfloat16(__bfloat16_raw(0x3c02))); + float val1 = bfloat16(__bfloat16_raw(0x3c00)); + float val2 = bfloat16(__bfloat16_raw(0x3c01)); + float val3 = bfloat16(__bfloat16_raw(0x3c02)); VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00); VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);