diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 1b9ab22706d297..c4a833f19edb08 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -1005,6 +1005,37 @@ struct TosaFoldConstantAdd : public TosaFoldConstantBinary { + using TosaFoldConstantBinary::TosaFoldConstantBinary; + + DenseElementsAttr computeInteger(DenseElementsAttr lhsValues, + DenseElementsAttr rhsValues, + PatternRewriter &rewriter, SubOp op) const { + bool overflowed = false; + auto newTensor = applyElementWise(lhsValues, rhsValues, + op.getType(), [&overflowed](const APInt &first, const APInt &second) { + bool didOverflow; + auto res = first.ssub_ov(second, didOverflow); + overflowed |= didOverflow; + return res; + }); + + if (overflowed) { + op->emitWarning("Subtraction did overflow. The results are unspecified."); + } + return newTensor; + } + + DenseElementsAttr computeFloat(DenseElementsAttr lhsValues, + DenseElementsAttr rhsValues, + PatternRewriter &rewriter, SubOp op) const { + return applyElementWise(lhsValues, rhsValues, + op.getType(), [](const APFloat &first, const APFloat &second) { + return first - second; + }); + } +}; + struct TosaFoldConstantGreater : public TosaFoldConstantBinary { using TosaFoldConstantBinary::TosaFoldConstantBinary; @@ -1382,6 +1413,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns( patterns.add(ctx, foldSplatOrSingleUseOnly); } patterns.add(ctx, foldSplatOrSingleUseOnly); + patterns.add(ctx, foldSplatOrSingleUseOnly); patterns.add(ctx, foldSplatOrSingleUseOnly); patterns.add(ctx, foldSplatOrSingleUseOnly); patterns.add(ctx, foldSplatOrSingleUseOnly); @@ -1394,4 +1426,5 @@ void mlir::tosa::populateTosaFoldConstantPatterns( patterns.add(ctx, foldSplatOrSingleUseOnly); patterns.add(ctx, foldSplatOrSingleUseOnly); patterns.add(ctx, foldSplatOrSingleUseOnly); -} \ No newline at end of file +} + diff --git a/mlir/test/Dialect/Tosa/constant-sub.mlir b/mlir/test/Dialect/Tosa/constant-sub.mlir new file mode 100644 index 00000000000000..c8a8a2560b326e --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-sub.mlir @@ -0,0 +1,150 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// Float subtractions + +// CHECK-LABEL: @sub_fold_float +func.func @sub_fold_float() -> tensor<4xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.152500e+02, 7.988280e+00, 0.000000e+00, -5.000000e+00 + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17.4978, 4.9882, 0.0, -0.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %1 = "tosa.const"() {value = + dense<[-132.7, -3.0, -0.0, 5.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %2 = "tosa.sub"(%0, %1) : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16> + return %2 : tensor<4xf16> +} + +// CHECK-LABEL: @sub_fold_float_infinity_nan +func.func @sub_fold_float_infinity_nan() -> tensor<6xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000 + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000]> : + tensor<6xf32> + } : () -> tensor<6xf32> + %1 = "tosa.const"() {value = + dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> : + tensor<6xf32> + } : () -> tensor<6xf32> + %2 = "tosa.sub"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32> + return %2 : tensor<6xf32> +} + +// CHECK-LABEL: @sub_fold_float_overflow +func.func @sub_fold_float_overflow() -> tensor<2xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000 + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[3.1e+38, -3.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %1 = "tosa.const"() {value = + dense<[-2.1e+38, 1.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %2 = "tosa.sub"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} + +// ----- +// Int subtraction + +// CHECK-LABEL: @sub_fold_int +func.func @sub_fold_int() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-149, 1, 0, 5 + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0, 0]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[132, 3, 0, -5]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.sub"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: @sub_fold_int_overflow +func.func @sub_fold_int_overflow() -> tensor<4xi32> { + // Don't expect any specific results for the overflowing subtraction, just + // expect that it is folded. + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[2147483647, 2147483640, -2147483648, -2147483640]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[-1, -10, 1, 30]> : + tensor<4xi32> + } : () -> tensor<4xi32> + // expected-warning@below {{Subtraction did overflow. The results are unspecified.}} + %2 = "tosa.sub"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// ----- +// self-subtraction + +// CHECK-LABEL: @sub_fold_equal_args +func.func @sub_fold_equal_args() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0> + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %2 = "tosa.sub"(%0, %0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// ----- +// Broadcasted subtractions + +// CHECK-LABEL: @sub_fold_int_broadcast_simple +func.func @sub_fold_int_broadcast_simple() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-29, -8, -12 + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %1 = "tosa.const"() {value = + dense<12> : + tensor<1xi32> + } : () -> tensor<1xi32> + %2 = "tosa.sub"(%0, %1) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// CHECK-LABEL: @sub_fold_int_broadcast_complex +func.func @sub_fold_int_broadcast_complex() -> tensor<3x3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[-29, -10, -13], + // CHECK-SAME{LITERAL}: [-11, 8, 5], + // CHECK-SAME{LITERAL}: [7, 26, 23]] + // CHECK-NOT: tosa.sub + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[[-17], [1], [19]]> : + tensor<3x1xi32> + } : () -> tensor<3x1xi32> + %1 = "tosa.const"() {value = + dense<[[12, -7, -4]]> : + tensor<1x3xi32> + } : () -> tensor<1x3xi32> + %2 = "tosa.sub"(%0, %1) : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32> + return %2 : tensor<3x3xi32> +}