Skip to content

Commit

Permalink
Fix Tosa folders for supporting NaNs.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jul 28, 2023
1 parent 4f03561 commit f77e5e7
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,8 @@ struct TosaFoldConstantGreater : public TosaFoldConstantBinary<TosaFoldConstantG
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return APInt(1, false);
return APInt(1, first > second);
});
}
Expand Down Expand Up @@ -1170,6 +1172,8 @@ struct TosaFoldConstantGreaterEqual
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return APInt(1, false);
return APInt(1, first >= second);
});
}
Expand Down Expand Up @@ -1225,6 +1229,8 @@ struct TosaFoldConstantMinimum
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return first.isNaN() ? first : second;
return first < second ? first : second;
});
}
Expand Down Expand Up @@ -1253,6 +1259,8 @@ struct TosaFoldConstantMaximum
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return first.isNaN() ? first : second;
return first > second ? first : second;
});
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/constant-greater-equal.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func.func @greater_equal_fold_float() -> tensor<4xi1> {

// CHECK-LABEL: @greater_equal_fold_float_infinity_nan
func.func @greater_equal_fold_float_infinity_nan() -> tensor<6xi1> {
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[true, false, true, false, false, true]>
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[true, false, true, false, false, false]>
// CHECK-NOT: tosa.greater_equal
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000]> :
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0xFF800000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%1 = "tosa.const"() {value =
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> :
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0x7FC00000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%2 = "tosa.greater_equal"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xi1>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/constant-greater-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func.func @greater_fold_float() -> tensor<4xi1> {

// CHECK-LABEL: @greater_fold_float_infinity_nan
func.func @greater_fold_float_infinity_nan() -> tensor<6xi1> {
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[true, false, true, false, false, true]>
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[true, false, true, false, false, false]>
// CHECK-NOT: tosa.greater
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000]> :
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0xFF800000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%1 = "tosa.const"() {value =
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> :
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0x7FC00000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%2 = "tosa.greater"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xi1>
Expand Down
9 changes: 5 additions & 4 deletions mlir/test/Dialect/Tosa/constant-maximum.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ func.func @maximum_fold_float() -> tensor<4xf16> {

// CHECK-LABEL: @maximum_fold_float_infinity_nan
func.func @maximum_fold_float_infinity_nan() -> tensor<6xf32> {
// Any comparison with NAN results in NAN
// 0x7FC00000 is the value for NAN
// 0x7F800000 is the value for Inf
// 0xFF800000 is the value for -Inf
// 0x7FC00000 is the value for NAN
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[0x7F800000, -3.000000e+00, 0x7F800000, 3.000000e+00, 1.000000e+00, 0x7F800000]>
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[0x7F800000, -3.000000e+00, 0x7F800000, 3.000000e+00, 0x7FC00000, 0x7FC00000]>
// CHECK-NOT: tosa.maximum
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[0x7F800000, -3.000000e+00, 0x7F800000, 3.000000e+00, 1.000000e+00, 0x7F800000]> :
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0xFF800000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%1 = "tosa.const"() {value =
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> :
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0x7FC00000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%2 = "tosa.maximum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32>
Expand Down
9 changes: 5 additions & 4 deletions mlir/test/Dialect/Tosa/constant-minimum.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ func.func @minimum_fold_float() -> tensor<4xf16> {

// CHECK-LABEL: @minimum_fold_float_infinity_nan
func.func @minimum_fold_float_infinity_nan() -> tensor<6xf32> {
// Any comparison with NAN results in NAN
// 0x7FC00000 is the value for NAN
// 0x7F800000 is the value for Inf
// 0xFF800000 is the value for -Inf
// 0x7FC00000 is the value for NAN
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[3.000000e+00, 0xFF800000, -3.000000e+00, 0xFF800000, 1.000000e+00, 0xFF800000]>
// CHECK: [[RES:]] = "tosa.const"() <{value = dense<[3.000000e+00, 0xFF800000, -3.000000e+00, 0xFF800000, 0x7FC00000, 0x7FC00000]>
// CHECK-NOT: tosa.minimum
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000]> :
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0xFF800000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%1 = "tosa.const"() {value =
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> :
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0x7FC00000]> :
tensor<6xf32>
} : () -> tensor<6xf32>
%2 = "tosa.minimum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32>
Expand Down

0 comments on commit f77e5e7

Please sign in to comment.