add nextafter for bfloat16

This commit is contained in:
Peter Gavin
2024-10-21 21:23:41 +00:00
parent 53b83cddf9
commit b15ebb1c2d
2 changed files with 60 additions and 0 deletions

View File

@@ -353,6 +353,40 @@ void test_product() {
VERIFY_IS_APPROX(Ch.noalias() += Ah * Bh, (Cf.noalias() += Af * Bf).cast<bfloat16>());
}
void test_nextafter() {
VERIFY((numext::isnan)(numext::nextafter(std::numeric_limits<bfloat16>::quiet_NaN(), bfloat16(1.0f))));
VERIFY((numext::isnan)(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::quiet_NaN())));
VERIFY(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)) == bfloat16(0.0f));
VERIFY(numext::nextafter(bfloat16(1.0f), bfloat16(1.0f)) == bfloat16(1.0f));
VERIFY(numext::nextafter(bfloat16(-1.0f), bfloat16(-1.0f)) == bfloat16(-1.0f));
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), std::numeric_limits<bfloat16>::infinity()) ==
std::numeric_limits<bfloat16>::infinity());
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
(std::numeric_limits<bfloat16>::max)());
VERIFY(numext::nextafter(-std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
-(std::numeric_limits<bfloat16>::max)());
VERIFY(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(1.0f) + std::numeric_limits<bfloat16>::epsilon());
VERIFY(numext::nextafter(bfloat16(1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(1.0f) - std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
VERIFY(numext::nextafter(bfloat16(-1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(-1.0f) - std::numeric_limits<bfloat16>::epsilon());
VERIFY(numext::nextafter(bfloat16(-1.0f), std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(-1.0f) + std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
VERIFY(numext::nextafter((std::numeric_limits<bfloat16>::max)(), std::numeric_limits<bfloat16>::infinity()) ==
std::numeric_limits<bfloat16>::infinity());
VERIFY(numext::nextafter(-(std::numeric_limits<bfloat16>::max)(), -std::numeric_limits<bfloat16>::infinity()) ==
-std::numeric_limits<bfloat16>::infinity());
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(1.0f)), 0x0001);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(1.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-1.0f)), 0x8000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-1.0f)), 0x8001);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-0.0f)), 0x8000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(0.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-0.0f)), 0x8000);
}
EIGEN_DECLARE_TEST(bfloat16_float) {
CALL_SUBTEST(test_numtraits());
for (int i = 0; i < g_repeat; i++) {
@@ -363,5 +397,6 @@ EIGEN_DECLARE_TEST(bfloat16_float) {
CALL_SUBTEST(test_trigonometric_functions());
CALL_SUBTEST(test_array());
CALL_SUBTEST(test_product());
CALL_SUBTEST(test_nextafter());
}
}