From 7dc7c30ab125729880160fce5ece87a46785cf90 Mon Sep 17 00:00:00 2001 From: yutotnh <57719497+yutotnh@users.noreply.github.com> Date: Thu, 27 Jul 2023 00:16:11 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=B2=A0=E3=81=AE=E6=AD=A3=E8=A6=8F?= =?UTF-8?q?=E5=8C=96=E6=95=B0=E3=81=AE=E5=A0=B4=E5=90=88=E3=81=AB=E6=84=8F?= =?UTF-8?q?=E5=9B=B3=E3=81=97=E3=81=AA=E3=81=84=E6=95=B0=E5=80=A4=E3=81=AB?= =?UTF-8?q?=E3=81=AA=E3=81=A3=E3=81=A6=E3=81=84=E3=81=9F=E3=83=90=E3=82=B0?= =?UTF-8?q?=E3=82=92=E4=BF=AE=E6=AD=A3=20#247=20(#248)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/spirit/src/bfloat16.cpp | 14 ++++----- tests/test_bfloat16.cpp | 58 ++++++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/src/spirit/src/bfloat16.cpp b/src/spirit/src/bfloat16.cpp index 971119f..d025cc1 100644 --- a/src/spirit/src/bfloat16.cpp +++ b/src/spirit/src/bfloat16.cpp @@ -32,15 +32,11 @@ float bfloat16_to_float32(const uint16_t bfloat16) uint16_t float32_to_bfloat16(const float value) { - uint32_t sign = 0; - uint32_t exponent = 0; - uint32_t mantissa = 0; + const uint32_t sign = std::signbit(value) ? 1 : 0; + uint32_t exponent = 0; + uint32_t mantissa = 0; - if (value < 0.0F) { - sign = 1; - } - - if (value == 0.0F) { + if (std::fpclassify(value) == FP_ZERO) { exponent = 0; mantissa = 0; } else if (std::isinf(value)) { @@ -51,7 +47,7 @@ uint16_t float32_to_bfloat16(const float value) mantissa = 1; } else { int32_t exponent_int32_t = 0; - const float mantissa_float = std::frexp(value, reinterpret_cast(&exponent_int32_t)) - 0.50F; + const float mantissa_float = std::fabs(std::frexp(value, reinterpret_cast(&exponent_int32_t))) - 0.50F; exponent = exponent_int32_t + 126; mantissa = static_cast(std::ldexp(mantissa_float, 8)); } diff --git a/tests/test_bfloat16.cpp b/tests/test_bfloat16.cpp index 6a18d1c..dbb8733 100644 --- a/tests/test_bfloat16.cpp +++ b/tests/test_bfloat16.cpp @@ -4,6 +4,9 @@ #include "bfloat16.h" +/** + * @brief bfloat16をfloat32に変換するテスト + */ TEST(BFloat16, Bfloat16ToFloat32) { auto test = [](const uint16_t bfloat16, const float value) { @@ -15,24 +18,38 @@ TEST(BFloat16, Bfloat16ToFloat32) } else { // bfloat16は仮数部が7bitなので、float32に変換したときに誤差が出る // そのため、数値が許容範囲内かどうかをチェックする - float allowable_error_margin = value / 127.0F; + float allowable_error_margin = std::fabs(value) / 127.0F; EXPECT_NEAR(result, value, allowable_error_margin) << "bfloat16: " << result << ", value: " << value; } }; /// @test Zero test(0x0000, 0.0F); + test(0x8000, -0.0F); + /// @test NaN test(0x7F81, std::numeric_limits::quiet_NaN()); + /// @test Infinity test(0x7F80, std::numeric_limits::infinity()); + /// @test 適当な値 test(0x3F80, 1.0F); + test(0xBF80, -1.0F); + test(0x3DCC, 0.1F); + test(0xBDCC, -0.1F); + test(0x3E4C, 0.2F); + test(0xBE4C, -0.2F); + test(0x42F6, 123.456F); + test(0xC2F6, -123.456F); } +/** + * @brief float32をbfloat16に変換するテスト + */ TEST(BFloat16, Float32ToBFloat16) { auto test = [](const float value, const uint16_t bfloat16) { @@ -42,17 +59,32 @@ TEST(BFloat16, Float32ToBFloat16) /// @test Zero test(0.0F, 0x0000); + test(-0.0F, 0x8000); + /// @test NaN test(std::numeric_limits::quiet_NaN(), 0x7F81); + /// @test Infinity test(std::numeric_limits::infinity(), 0x7F80); + /// @test 適当な値 test(1.0F, 0x3F80); + test(-1.0F, 0xBF80); + test(0.1F, 0x3DCC); + test(-0.1F, 0xBDCC); + test(0.2F, 0x3E4C); + test(-0.2F, 0xBE4C); + test(123.456F, 0x42F6); + test(-123.456F, 0xC2F6); } +/** + * @brief float32をbfloat16に変換して、再度float32に変換するテスト + * @details 今までのテストでも十分だが、念のため + */ TEST(BFloat16, BFloat16ToFloat32AndFloat32ToBFloat16) { auto test = [](const float value) { @@ -66,20 +98,38 @@ TEST(BFloat16, BFloat16ToFloat32AndFloat32ToBFloat16) } else { // bfloat16は仮数部が7bitなので、float32に変換したときに誤差が出る // そのため、数値が許容範囲内かどうかをチェックする - float allowable_error_margin = value / 127.0F; + float allowable_error_margin = std::fabs(value) / 127.0F; EXPECT_NEAR(result, value, allowable_error_margin) << "bfloat16: " << result << ", value: " << value; } }; /// @test Zero test(0.0F); + test(-0.0F); + /// @test NaN test(std::numeric_limits::quiet_NaN()); + test(std::numeric_limits::signaling_NaN()); + /// @test Infinity test(std::numeric_limits::infinity()); + /// @test 適当な値 test(1.0F); + test(-1.0F); + test(0.1F); - test(0.2F); - test(123.456F); + test(-0.1F); + + test(0.013F); + test(-0.013F); + + test(0.77F); + test(-0.77F); + + test(9.3e+10F); + test(-9.3e+10F); + + test(2.52e-10F); + test(-2.52e-10F); }