Skip to content

Commit

Permalink
fix: 負の正規化数の場合に意図しない数値になっていたバグを修正 #247 (#248)
Browse files (browse the repository at this point in the history)
  • Loading branch information
yutotnh authored Jul 26, 2023
1 parent bc37fdc commit 7dc7c30
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 13 deletions.
14 changes: 5 additions & 9 deletions src/spirit/src/bfloat16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@ float bfloat16_to_float32(const uint16_t bfloat16)

uint16_t float32_to_bfloat16(const float value)
{
uint32_t sign = 0;
uint32_t exponent = 0;
uint32_t mantissa = 0;
const uint32_t sign = std::signbit(value) ? 1 : 0;
uint32_t exponent = 0;
uint32_t mantissa = 0;

if (value < 0.0F) {
sign = 1;
}

if (value == 0.0F) {
if (std::fpclassify(value) == FP_ZERO) {
exponent = 0;
mantissa = 0;
} else if (std::isinf(value)) {
Expand All @@ -51,7 +47,7 @@ uint16_t float32_to_bfloat16(const float value)
mantissa = 1;
} else {
int32_t exponent_int32_t = 0;
const float mantissa_float = std::frexp(value, reinterpret_cast<int *>(&exponent_int32_t)) - 0.50F;
const float mantissa_float = std::fabs(std::frexp(value, reinterpret_cast<int *>(&exponent_int32_t))) - 0.50F;
exponent = exponent_int32_t + 126;
mantissa = static_cast<uint32_t>(std::ldexp(mantissa_float, 8));
}
Expand Down
58 changes: 54 additions & 4 deletions tests/test_bfloat16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include "bfloat16.h"

/**
* @brief bfloat16をfloat32に変換するテスト
*/
TEST(BFloat16, Bfloat16ToFloat32)
{
auto test = [](const uint16_t bfloat16, const float value) {
Expand All @@ -15,24 +18,38 @@ TEST(BFloat16, Bfloat16ToFloat32)
} else {
// bfloat16は仮数部が7bitなので、float32に変換したときに誤差が出る
// そのため、数値が許容範囲内かどうかをチェックする
float allowable_error_margin = value / 127.0F;
float allowable_error_margin = std::fabs(value) / 127.0F;
EXPECT_NEAR(result, value, allowable_error_margin) << "bfloat16: " << result << ", value: " << value;
}
};

/// @test Zero
test(0x0000, 0.0F);
test(0x8000, -0.0F);

/// @test NaN
test(0x7F81, std::numeric_limits<float>::quiet_NaN());

/// @test Infinity
test(0x7F80, std::numeric_limits<float>::infinity());

/// @test 適当な値
test(0x3F80, 1.0F);
test(0xBF80, -1.0F);

test(0x3DCC, 0.1F);
test(0xBDCC, -0.1F);

test(0x3E4C, 0.2F);
test(0xBE4C, -0.2F);

test(0x42F6, 123.456F);
test(0xC2F6, -123.456F);
}

/**
* @brief float32をbfloat16に変換するテスト
*/
TEST(BFloat16, Float32ToBFloat16)
{
auto test = [](const float value, const uint16_t bfloat16) {
Expand All @@ -42,17 +59,32 @@ TEST(BFloat16, Float32ToBFloat16)

/// @test Zero
test(0.0F, 0x0000);
test(-0.0F, 0x8000);

/// @test NaN
test(std::numeric_limits<float>::quiet_NaN(), 0x7F81);

/// @test Infinity
test(std::numeric_limits<float>::infinity(), 0x7F80);

/// @test 適当な値
test(1.0F, 0x3F80);
test(-1.0F, 0xBF80);

test(0.1F, 0x3DCC);
test(-0.1F, 0xBDCC);

test(0.2F, 0x3E4C);
test(-0.2F, 0xBE4C);

test(123.456F, 0x42F6);
test(-123.456F, 0xC2F6);
}

/**
* @brief float32をbfloat16に変換して、再度float32に変換するテスト
* @details 今までのテストでも十分だが、念のため
*/
TEST(BFloat16, BFloat16ToFloat32AndFloat32ToBFloat16)
{
auto test = [](const float value) {
Expand All @@ -66,20 +98,38 @@ TEST(BFloat16, BFloat16ToFloat32AndFloat32ToBFloat16)
} else {
// bfloat16は仮数部が7bitなので、float32に変換したときに誤差が出る
// そのため、数値が許容範囲内かどうかをチェックする
float allowable_error_margin = value / 127.0F;
float allowable_error_margin = std::fabs(value) / 127.0F;
EXPECT_NEAR(result, value, allowable_error_margin) << "bfloat16: " << result << ", value: " << value;
}
};

/// @test Zero
test(0.0F);
test(-0.0F);

/// @test NaN
test(std::numeric_limits<float>::quiet_NaN());
test(std::numeric_limits<float>::signaling_NaN());

/// @test Infinity
test(std::numeric_limits<float>::infinity());

/// @test 適当な値
test(1.0F);
test(-1.0F);

test(0.1F);
test(0.2F);
test(123.456F);
test(-0.1F);

test(0.013F);
test(-0.013F);

test(0.77F);
test(-0.77F);

test(9.3e+10F);
test(-9.3e+10F);

test(2.52e-10F);
test(-2.52e-10F);
}

0 comments on commit 7dc7c30

Please sign in to comment.