Skip to content

Commit

Permalink
Add conversions for FP8 types (F8E5M2 and F8E4M3) (iree-org#16374)
Browse files Browse the repository at this point in the history
This PR adds very little code because the existing conversion code was
already essentially generic, so the F8E5M2 type comes almost for free.
F8E4M3 is a bit trickier: it has no infinities and reclaims that encoding
space to represent additional large finite values.
  • Loading branch information
bjacob authored Feb 12, 2024
1 parent 4a49e37 commit 9aabcb3
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 81 deletions.
186 changes: 107 additions & 79 deletions runtime/src/iree/base/internal/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) {
}

//==============================================================================
// FP16 and BFloat16 support
// FP16, BFloat16 and FP8 support
//==============================================================================

// NOTE: We used to have code here using built-in _Float16 type support.
Expand All @@ -273,91 +273,115 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) {
// in slow generic fallbacks or test code, and we weren't able to use
// a builtin for bf16 anyway.

#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, bits, ebits) \
// Define some helper constants for working with a floating-point format with
// the given number of {exponent,mantissa} bits.
#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits) \
const int prefix##exp_bits IREE_ATTRIBUTE_UNUSED = ebits; \
const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = \
bits - 1 - prefix##exp_bits; \
const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = bits - 1; \
const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = mbits; \
const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = ebits + mbits; \
const int prefix##exp_shift IREE_ATTRIBUTE_UNUSED = prefix##mantissa_bits; \
const int prefix##sign_mask IREE_ATTRIBUTE_UNUSED = 1u \
<< prefix##sign_shift; \
const int prefix##mantissa_mask IREE_ATTRIBUTE_UNUSED = \
(1u << prefix##exp_shift) - 1; \
const int prefix##exp_mask IREE_ATTRIBUTE_UNUSED = \
(1u << prefix##sign_shift) - (1u << prefix##exp_shift);

static inline float iree_math_generic_fp16_to_f32(uint16_t f16_value,
int exp_bits) {
IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
const uint32_t f16_sign = f16_value & f16_sign_mask;
const uint32_t f32_sign = f16_sign << (f32_sign_shift - f16_sign_shift);
const uint32_t f16_exp = f16_value & f16_exp_mask;
const uint32_t f16_mantissa = f16_value & f16_mantissa_mask;
(1u << prefix##sign_shift) - (1u << prefix##exp_shift); \
const int prefix##exp_bias IREE_ATTRIBUTE_UNUSED = \
(1u << (prefix##exp_bits - 1)) - 1;

// Generic conversion from any less-than-32-bit floating-point format to f32.
// The `src` value is typed as a uint32_t for genericity but occupies only the
// bottom (1 + exp_bits + mantissa_bits) bits. The upper bits of `src` are
// unused.
static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits,
int mantissa_bits,
bool have_infinity) {
IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits)
IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
const uint32_t src_sign = src & src_sign_mask;
const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift);
const uint32_t src_exp = src & src_exp_mask;
const uint32_t src_mantissa = src & src_mantissa_mask;
uint32_t f32_exp = 0;
uint32_t f32_mantissa = 0;
if (f16_exp == f16_exp_mask) {
if (src_exp == src_exp_mask) {
// No infinities => more large finite values.
if (!have_infinity && src_mantissa != src_mantissa_mask) {
float sign = (src & src_sign_mask) ? -1.0f : 1.0f;
return sign * 2 * (1u << src_exp_bits) *
((1u << src_mantissa_bits) + src_mantissa);
}
// NaN or Inf case.
f32_exp = f32_exp_mask;
if (f16_mantissa) {
if (src_mantissa) {
// NaN. Generate a quiet NaN.
f32_mantissa = f32_mantissa_mask;
} else {
// Inf. Leave zero mantissa.
}
} else if (f16_exp == 0) {
} else if (src_exp == 0) {
// Zero or subnormal. Generate zero. Leave zero mantissa.
} else {
// Normal finite value.
int arithmetic_f16_exp = f16_exp >> f16_exp_shift;
int arithmetic_f32_exp = arithmetic_f16_exp + (1 << (f32_exp_bits - 1)) -
(1 << (f16_exp_bits - 1));
int arithmetic_src_exp = src_exp >> src_exp_shift;
int arithmetic_f32_exp = arithmetic_src_exp + (1 << (f32_exp_bits - 1)) -
(1 << (src_exp_bits - 1));
f32_exp = arithmetic_f32_exp << f32_exp_shift;
f32_mantissa = f16_mantissa << (f32_mantissa_bits - f16_mantissa_bits);
f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits);
}
const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa;
float f32_value;
memcpy(&f32_value, &u32_value, sizeof f32_value);
return f32_value;
}

static inline uint16_t iree_math_f32_to_generic_fp16(float value,
int exp_bits) {
IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
// Generic conversion from f32 to any less-than-32-bit floating-point format,
// rounding to nearest-even. The return value is typed as a uint32_t for
// genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits.
// The upper bits of the return value are unused.
static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even(
float value, int exp_bits, int mantissa_bits, bool have_infinity) {
IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits)
IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
uint32_t u32_value;
memcpy(&u32_value, &value, sizeof value);
const uint32_t f32_sign = u32_value & f32_sign_mask;
const uint32_t f16_sign = f32_sign >> (f32_sign_shift - f16_sign_shift);
const uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift);
const uint32_t f32_exp = u32_value & f32_exp_mask;
const uint32_t f32_mantissa = u32_value & f32_mantissa_mask;
uint32_t f16_exp = 0;
uint32_t f16_mantissa = 0;
if (f32_exp == f32_exp_mask) {
uint32_t dst_exp = 0;
uint32_t dst_mantissa = 0;
if (f32_exp >= f32_exp_mask) {
// NaN or Inf case.
f16_exp = f16_exp_mask;
if (f32_mantissa) {
dst_exp = dst_exp_mask;
if (f32_mantissa || !have_infinity) {
// NaN. Generate a quiet NaN.
f16_mantissa = f16_mantissa_mask;
dst_mantissa = dst_mantissa_mask;
} else {
// Inf. Leave zero mantissa.
}
} else if (f32_exp == 0) {
// Zero or subnormal. Generate zero. Leave zero mantissa.
} else {
// Normal finite value.
int arithmetic_exp = (f32_exp >> f32_exp_shift) - (1 << (f32_exp_bits - 1));
if (arithmetic_exp >= (1 << (f16_exp_bits - 1))) {
int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias;
// Test if the exponent is too large for the destination type. If
// the destination type does not have infinities, that frees up the
// max exponent value for additional finite values.
if (arithmetic_exp > (1 << (dst_exp_bits - 1)) - have_infinity) {
// Overflow. Generate Inf (zero mantissa), or NaN if there is no infinity.
f16_exp = f16_exp_mask;
} else if (arithmetic_exp < -(1 << (f16_exp_bits - 1))) {
dst_exp = dst_exp_mask;
if (!have_infinity) {
// Generate NaN.
dst_mantissa = dst_mantissa_mask;
}
} else if (arithmetic_exp < -(1 << (dst_exp_bits - 1))) {
// Underflow. Generate zero. Leave zero mantissa.
f16_exp = 0;
dst_exp = 0;
} else {
// Normal case.
// Implement round-to-nearest-even, by adding a bias before truncating.
// truncating.
int even_bit = 1u << (f32_mantissa_bits - f16_mantissa_bits);
int even_bit = 1u << (f32_mantissa_bits - dst_mantissa_bits);
int odd_bit = even_bit >> 1;
uint32_t biased_f32_mantissa =
f32_mantissa +
Expand All @@ -377,52 +401,56 @@ static inline uint16_t iree_math_f32_to_generic_fp16(float value,
biased_f32_mantissa = 0;
++arithmetic_exp;
}
// The exponent increment in the above if() branch may cause overflow.
// This is exercised by converting 65520.0f from f32 to f16. No special
// handling is needed for this case: the above if() branch already set
// biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
// needed for infinite values.
f16_exp = (arithmetic_exp + (1 << (f16_exp_bits - 1))) << f16_exp_shift;
f16_mantissa =
biased_f32_mantissa >> (f32_mantissa_bits - f16_mantissa_bits);
// In the !have_infinity case, arithmetic_exp might have been the top
// value already, so incrementing it may have overflowed it.
if (!have_infinity && arithmetic_exp > (1 << (dst_exp_bits - 1))) {
dst_exp = dst_exp_mask;
dst_mantissa = dst_mantissa_mask;
} else {
// The exponent increment in the above if() branch may cause overflow.
// This is exercised by converting 65520.0f from f32 to f16. No special
// handling is needed for this case: the above if() branch already set
// biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
// needed for infinite values.
dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift;
dst_mantissa =
biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits);
}
}
}
uint16_t f16_value = f16_sign | f16_exp | f16_mantissa;
return f16_value;
uint32_t dst_value = dst_sign | dst_exp | dst_mantissa;
return dst_value;
}

// Converts a fp16 value to a 32-bit C `float`.
static inline float iree_math_f16_to_f32(uint16_t f16_value) {
return iree_math_generic_fp16_to_f32(f16_value, 5);
}

// Converts a 32-bit C `float` value to a fp16 value, rounding to nearest
// even.
static inline uint16_t iree_math_f32_to_f16(float value) {
return iree_math_f32_to_generic_fp16(value, 5);
}
#define IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(NAME, INT_TYPE, EXP_BITS,          \
                                          MANTISSA_BITS, HAVE_INFINITY)      \
  /* Decodes the NAME bit pattern held in `src` into a 32-bit C `float`. */  \
  static inline float iree_math_##NAME##_to_f32(INT_TYPE src) {              \
    const uint32_t src_bits = src;                                           \
    return iree_math_make_f32_from_bits(src_bits, EXP_BITS, MANTISSA_BITS,   \
                                        HAVE_INFINITY);                      \
  }                                                                          \
  /* Encodes a 32-bit C `float` as NAME bits, rounding to nearest even. */   \
  static inline INT_TYPE iree_math_f32_to_##NAME(float value) {              \
    const uint32_t dst_bits =                                                \
        iree_math_truncate_f32_to_bits_rounding_to_nearest_even(             \
            value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY);                  \
    /* Only the low bits are populated; narrow to the storage type. */       \
    return (INT_TYPE)dst_bits;                                               \
  }                                                                          \
  /* Rounds `value` through NAME and back: f32 -> NAME -> f32. */            \
  static inline float iree_math_round_to_nearest_##NAME(float value) {       \
    const INT_TYPE narrowed = iree_math_f32_to_##NAME(value);                \
    return iree_math_##NAME##_to_f32(narrowed);                              \
  }

// Rounds a 32-bit C `float` value to the nearest 16-bit value and returns it
// as a 32-bit `float`.
static inline float iree_math_round_to_nearest_f16(float f32_value) {
return iree_math_f16_to_f32(iree_math_f32_to_f16(f32_value));
}
// IEEE half-precision a.k.a. float16,
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true)

// Converts a bfloat16 value to a 32-bit C `float`.
static inline float iree_math_bf16_to_f32(uint16_t bf16_value) {
return iree_math_generic_fp16_to_f32(bf16_value, 8);
}
// Bfloat16, https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true)

// Converts a 32-bit C `float` value to a bfloat16 value, rounding to nearest
// even.
static inline uint16_t iree_math_f32_to_bf16(float value) {
return iree_math_f32_to_generic_fp16(value, 8);
}
// F8E5M2 type, https://arxiv.org/abs/2209.05433
IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true)

// Rounds a 32-bit C `float` value to the nearest bfloat16 value and returns it
// as a 32-bit `float`.
static inline float iree_math_round_to_nearest_bf16(float f32_value) {
return iree_math_bf16_to_f32(iree_math_f32_to_bf16(f32_value));
}
// F8E4M3 type, https://arxiv.org/abs/2209.05433.
IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3, uint8_t, 4, 3,
/*have_infinity=*/false)

#endif // IREE_BASE_INTERNAL_MATH_H_
Loading

0 comments on commit 9aabcb3

Please sign in to comment.