Skip to content

Commit

Permalink
1. Rename cast_from_fixed to cast_from_fixed_using_rne.
Browse files Browse the repository at this point in the history
2. Simplify the code a bit.
3. Add a comment about some uncaptured overflow behaviour for VERY large numbers.

PiperOrigin-RevId: 551596324
  • Loading branch information
sandwichmaker authored and copybara-github committed Jul 27, 2023
1 parent fb6726b commit cae5fbb
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 57 deletions.
6 changes: 3 additions & 3 deletions docs_src/floating_point.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,11 @@ pub fn ldexp<EXP_SZ:u32, FRACTION_SZ:u32>(
`NaN` representations as input.


### `apfloat::cast_from_fixed`
### `apfloat::cast_from_fixed_using_rne`

```dslx-snippet
pub fn cast_from_fixed<EXP_SZ:u32, FRACTION_SZ:u32, NUM_SRC_BITS:u32>(
to_cast: sN[NUM_SRC_BITS])
pub fn cast_from_fixed_using_rne<EXP_SZ:u32, FRACTION_SZ:u32, NUM_SRC_BITS:u32>(
to_cast: sN[NUM_SRC_BITS])
-> APFloat<EXP_SZ, FRACTION_SZ> {
```

Expand Down
2 changes: 1 addition & 1 deletion third_party/xls_go_math/fpexp_32.x
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ pub fn fpexp_32(x: F32) -> F32 {

// Reduce
// TODO(jbaileyhandle): Cheaper to truncate fp_k directly?
let fp_truncated_k = float32::cast_from_fixed(k);
let fp_truncated_k = float32::cast_from_fixed_using_rne(k);
let hi = float32::mul(LN2HI, fp_truncated_k);
let hi = float32::sub(x, hi);
let lo = float32::mul( LN2LO, fp_truncated_k);
Expand Down
62 changes: 33 additions & 29 deletions xls/dslx/stdlib/apfloat.x
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,22 @@ pub fn unflatten<EXP_SZ:u32, FRACTION_SZ:u32,

// Casts the fixed point number to a floating point number using RNE
// (Round to Nearest Even) as the rounding mode.
pub fn cast_from_fixed<EXP_SZ:u32, FRACTION_SZ:u32, NUM_SRC_BITS:u32>(
to_cast: sN[NUM_SRC_BITS])
pub fn cast_from_fixed_using_rne<EXP_SZ:u32, FRACTION_SZ:u32, NUM_SRC_BITS:u32>(
to_cast: sN[NUM_SRC_BITS])
-> APFloat<EXP_SZ, FRACTION_SZ> {
const UEXP_SZ:u32 = EXP_SZ + u32:1;
const EXTENDED_FRACTION_SZ:u32 = FRACTION_SZ + NUM_SRC_BITS;

// Determine sign.
let sign = (to_cast as uN[NUM_SRC_BITS])[(NUM_SRC_BITS-u32:1) as s32 : NUM_SRC_BITS as s32];
let is_negative = to_cast < sN[NUM_SRC_BITS]:0;

// Determine exponent.
let abs_magnitude = (if sign == u1:0 { to_cast } else { -to_cast }) as uN[NUM_SRC_BITS];
let abs_magnitude = std::abs(to_cast) as uN[NUM_SRC_BITS];
let lz = clz(abs_magnitude);
let num_trailing_nonzeros = (NUM_SRC_BITS as uN[NUM_SRC_BITS]) - lz;

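// NOTE: the cast to uN[UEXP_SZ] below truncates if num_trailing_nonzeros
// does not fit in UEXP_SZ bits (only possible for VERY wide inputs). In that
// case exp wraps around and the is_inf check below does not catch the overflow.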
let exp = (num_trailing_nonzeros as uN[UEXP_SZ]) - uN[UEXP_SZ]:1;
let max_exp_exclusive = uN[UEXP_SZ]:1 << ((EXP_SZ as uN[UEXP_SZ]) - uN[UEXP_SZ]:1);
let is_inf = exp >= max_exp_exclusive;
Expand Down Expand Up @@ -345,28 +349,28 @@ pub fn cast_from_fixed<EXP_SZ:u32, FRACTION_SZ:u32, NUM_SRC_BITS:u32>(

let result =
APFloat<EXP_SZ, FRACTION_SZ>{
sign: sign,
sign: is_negative,
bexp: bexp,
fraction: fraction
};

let is_zero = abs_magnitude == uN[NUM_SRC_BITS]:0;
let result = if is_inf { inf<EXP_SZ, FRACTION_SZ>(sign) } else { result };
let result = if is_zero { zero<EXP_SZ, FRACTION_SZ>(sign) } else { result };
let result = if is_inf { inf<EXP_SZ, FRACTION_SZ>(is_negative) } else { result };
let result = if is_zero { zero<EXP_SZ, FRACTION_SZ>(is_negative) } else { result };
result
}

#[test]
fn cast_from_fixed_test() {
fn cast_from_fixed_using_rne_test() {
// Zero is a special case.
let zero_float = zero<u32:4, u32:4>(u1:0);
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:0), zero_float);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:0), zero_float);

// +/-1
let one_float = one<u32:4, u32:4>(u1:0);
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:1), one_float);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:1), one_float);
let none_float = one<u32:4, u32:4>(u1:1);
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:-1), none_float);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:-1), none_float);

// +/-4
let four_float =
Expand All @@ -375,14 +379,14 @@ fn cast_from_fixed_test() {
bexp: u4:9,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:4), four_float);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:4), four_float);
let nfour_float =
APFloat<u32:4, u32:4>{
sign: u1:1,
bexp: u4:9,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:-4), nfour_float);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:-4), nfour_float);

// Cast maximum representable exponent in target format.
let max_representable =
Expand All @@ -391,10 +395,10 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:128), max_representable);
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:128), max_representable);

// Cast minimum non-representable exponent in target format.
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:256),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:256),
inf<u32:4, u32:4>(u1:0));

// Test rounding - maximum truncated bits that will round down, even fraction.
Expand All @@ -404,7 +408,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:131),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:131),
truncate);

// Test rounding - maximum truncated bits that will round down, odd fraction.
Expand All @@ -414,7 +418,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:1
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:139),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:139),
truncate);

// Test rounding - halfway and already even, round down
Expand All @@ -424,7 +428,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:132),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:132),
truncate);

// Test rounding - halfway and odd, round up
Expand All @@ -434,7 +438,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:2
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:140),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:140),
round_up);

// Test rounding - over halfway and even, round up
Expand All @@ -444,7 +448,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:1
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:133),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:133),
round_up);

// Test rounding - over halfway and odd, round up
Expand All @@ -454,7 +458,7 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:2
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:141),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:141),
round_up);

// Test rounding - Rounding up increases exponent.
Expand All @@ -464,15 +468,15 @@ fn cast_from_fixed_test() {
bexp: u4:14,
fraction: u4:0
};
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:126),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:126),
round_inc_exponent);
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:127),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:127),
round_inc_exponent);

// Test rounding - Rounding up overflows to infinity.
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:252),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:252),
inf<u32:4, u32:4>(u1:0));
assert_eq(cast_from_fixed<u32:4, u32:4>(sN[32]:254),
assert_eq(cast_from_fixed_using_rne<u32:4, u32:4>(sN[32]:254),
inf<u32:4, u32:4>(u1:0));
()
}
Expand Down Expand Up @@ -731,21 +735,21 @@ fn cast_to_fixed_test() {
cast_to_fixed<u32:32>(n_one_point_five), s32:-1);

// Cast +/-4.0
let four = cast_from_fixed<u32:8, u32:23>(s32:4);
let neg_four = cast_from_fixed<u32:8, u32:23>(s32:-4);
let four = cast_from_fixed_using_rne<u32:8, u32:23>(s32:4);
let neg_four = cast_from_fixed_using_rne<u32:8, u32:23>(s32:-4);
assert_eq(
cast_to_fixed<u32:32>(four), s32:4);
assert_eq(
cast_to_fixed<u32:32>(neg_four), s32:-4);

// Cast 7
let seven = cast_from_fixed<u32:8, u32:23>(s32:7);
let seven = cast_from_fixed_using_rne<u32:8, u32:23>(s32:7);
assert_eq(
cast_to_fixed<u32:32>(seven), s32:7);

// Cast big number (more digits left of decimal than hidden bit + fraction).
let big_num = (u1:0 ++ std::mask_bits<u32:23>() ++ u8:0) as s32;
let fp_big_num = cast_from_fixed<u32:8, u32:23>(big_num);
let fp_big_num = cast_from_fixed_using_rne<u32:8, u32:23>(big_num);
assert_eq(
cast_to_fixed<u32:32>(fp_big_num), big_num);

Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/stdlib/bfloat16.x
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ pub fn bias(unbiased_exponent_in: s8) -> u8 {
pub fn flatten(f: BF16) -> u16 { apfloat::flatten<u32:8, u32:7>(f) }
pub fn unflatten(f: u16) -> BF16 { apfloat::unflatten<u32:8, u32:7>(f) }
pub fn ldexp(f: BF16, e : s32) -> BF16 {apfloat::ldexp(f, e)}
pub fn cast_from_fixed<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> BF16 {
apfloat::cast_from_fixed<u32:8, u32:7>(s)
pub fn cast_from_fixed_using_rne<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> BF16 {
apfloat::cast_from_fixed_using_rne<u32:8, u32:7>(s)
}
pub fn cast_to_fixed<NUM_DST_BITS:u32>(to_cast: BF16) -> sN[NUM_DST_BITS] {
apfloat::cast_to_fixed<NUM_DST_BITS, u32:8, u32:7>(to_cast)
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/stdlib/float32.x
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pub fn flatten(f: F32) -> u32 { apfloat::flatten<u32:8, u32:23>(f) }
pub fn unflatten(f: u32) -> F32 { apfloat::unflatten<u32:8, u32:23>(f) }
pub fn ldexp(f: F32, e : s32) -> F32 {apfloat::ldexp(f, e)}

pub fn cast_from_fixed<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> F32 {
apfloat::cast_from_fixed<u32:8, u32:23>(s)
pub fn cast_from_fixed_using_rne<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> F32 {
apfloat::cast_from_fixed_using_rne<u32:8, u32:23>(s)
}
pub fn cast_to_fixed<NUM_DST_BITS:u32>(to_cast: F32) -> sN[NUM_DST_BITS] {
apfloat::cast_to_fixed<NUM_DST_BITS, u32:8, u32:23>(to_cast)
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/stdlib/float64.x
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ pub fn bias(unbiased_exponent_in: s11) -> u11 {
pub fn flatten(f: F64) -> u64 { apfloat::flatten<u32:11, u32:52>(f) }
pub fn unflatten(f: u64) -> F64 { apfloat::unflatten<u32:11, u32:52>(f) }
pub fn ldexp(f: F64, e : s32) -> F64 {apfloat::ldexp(f, e)}
pub fn cast_from_fixed<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> F64 {
apfloat::cast_from_fixed<u32:11, u32:52>(s)
pub fn cast_from_fixed_using_rne<NUM_SRC_BITS:u32>(s: sN[NUM_SRC_BITS]) -> F64 {
apfloat::cast_from_fixed_using_rne<u32:11, u32:52>(s)
}
pub fn cast_to_fixed<NUM_DST_BITS:u32>(to_cast: F64) -> sN[NUM_DST_BITS] {
apfloat::cast_to_fixed<NUM_DST_BITS, u32:11, u32:52>(to_cast)
Expand Down
6 changes: 3 additions & 3 deletions xls/examples/dot_product.x
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ fn dot_product_fixed_test() {

#[test]
fn dot_product_float32_test() {
let a = map(s32[4]:[1, 2, 3, 4], float32::cast_from_fixed);
let b = map(s32[4]:[5, 6, 7, 8], float32::cast_from_fixed);
let a = map(s32[4]:[1, 2, 3, 4], float32::cast_from_fixed_using_rne);
let b = map(s32[4]:[5, 6, 7, 8], float32::cast_from_fixed_using_rne);
let result = dot_product_float32(a, b);
assert_eq(result, float32::cast_from_fixed(s32:70));
assert_eq(result, float32::cast_from_fixed_using_rne(s32:70));
()

}
Expand Down
6 changes: 3 additions & 3 deletions xls/examples/fir_filter.x
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ fn fir_filter_fixed_test() {

#[test]
fn fir_filter_float32_test() {
let samples = map(s32[6]:[1, 2, 3, 4, 5, 6], float32::cast_from_fixed);
let coefficients= map(s32[4]:[10, 11, -12, -13], float32::cast_from_fixed);
let samples = map(s32[6]:[1, 2, 3, 4, 5, 6], float32::cast_from_fixed_using_rne);
let coefficients= map(s32[4]:[10, 11, -12, -13], float32::cast_from_fixed_using_rne);
let result = fir_filter_float32(samples, coefficients);
let expected = map(s32[3]:[36, 32, 28], float32::cast_from_fixed);
let expected = map(s32[3]:[36, 32, 28], float32::cast_from_fixed_using_rne);
assert_eq(result, expected);
()
}
24 changes: 12 additions & 12 deletions xls/examples/sobel_filter.x
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const Y_STENCIL = s32[3][3]:[s32[3]:[-1, -2, -1],
// TODO(jbaileyhandle): Do we have a way to reshape multidimensional
// arrays or to apply map to multidimensional arrays?
fn convert_triplet(input: s32[3]) -> F32[3] {
map(input, float32::cast_from_fixed)
map(input, float32::cast_from_fixed_using_rne)
}
const X_STENCIL_F32 = map(X_STENCIL, convert_triplet);
const Y_STENCIL_F32 = map(Y_STENCIL, convert_triplet);
Expand Down Expand Up @@ -70,25 +70,25 @@ fn apply_stencil_float32_test() {
let img1 = map(s32[16]:[1, 1, 1, 1,
1, -1, -2, 1,
1, -4, -3, 1,
1, 1, 1, 1], float32::cast_from_fixed);
1, 1, 1, 1], float32::cast_from_fixed_using_rne);

assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:0, u32:0, X_STENCIL_F32),
float32::cast_from_fixed(s32:-10));
float32::cast_from_fixed_using_rne(s32:-10));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:0, u32:1, X_STENCIL_F32),
float32::cast_from_fixed(s32:9));
float32::cast_from_fixed_using_rne(s32:9));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:1, u32:0, X_STENCIL_F32),
float32::cast_from_fixed(s32:-11));
float32::cast_from_fixed_using_rne(s32:-11));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:1, u32:1, X_STENCIL_F32),
float32::cast_from_fixed(s32:12));
float32::cast_from_fixed_using_rne(s32:12));

assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:0, u32:0, Y_STENCIL_F32),
float32::cast_from_fixed(s32:-14));
float32::cast_from_fixed_using_rne(s32:-14));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:0, u32:1, Y_STENCIL_F32),
float32::cast_from_fixed(s32:-13));
float32::cast_from_fixed_using_rne(s32:-13));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:1, u32:0, Y_STENCIL_F32),
float32::cast_from_fixed(s32:7));
float32::cast_from_fixed_using_rne(s32:7));
assert_eq(apply_stencil_float32<u32:4, u32:4>(img1, u32:1, u32:1, Y_STENCIL_F32),
float32::cast_from_fixed(s32:8));
float32::cast_from_fixed_using_rne(s32:8));

()
}
Expand Down Expand Up @@ -128,7 +128,7 @@ fn sobel_filter_float32_test() {
let img1 = map(s32[16]:[1, 1, 1, 1,
1, -1, -2, 1,
1, -4, -3, 1,
1, 1, 1, 1], float32::cast_from_fixed);
1, 1, 1, 1], float32::cast_from_fixed_using_rne);

let sobel_out = sobel_filter_float32<u32:4, u32:4>(img1);
// Truncate before comparison for simplicity.
Expand All @@ -143,7 +143,7 @@ fn sobel_filter_float32_test() {
// Try non-square image.
let img1 = map(s32[12]:[1, 1, 1, 1,
1, -1, -2, 1,
1, -4, -3, 1], float32::cast_from_fixed);
1, -4, -3, 1], float32::cast_from_fixed_using_rne);
let sobel_out = sobel_filter_float32<u32:3, u32:4>(img1);
assert_eq(float32::cast_to_fixed<u32:32>(sobel_out[u32:0]), s32:17);
assert_eq(float32::cast_to_fixed<u32:32>(sobel_out[u32:1]), s32:15);
Expand Down

0 comments on commit cae5fbb

Please sign in to comment.