diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index fd8db86e723cd6..84cef8abd91b23 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -588,6 +588,11 @@ struct TosaFoldConstantRSQRT auto floatVal = apFloatVal.convertToFloat(); auto sqrtVal = std::sqrt(floatVal); APFloat apSqrtVal(sqrtVal); + // We fold only float32 and bfloat16, so we do not expect any precision loss + // for float32 and the tosa spec explicitly allows to implement bfloat16 as + // float32, so any precision loss on the conversion back is fine. + bool losesInfo = false; + apSqrtVal.convert(apFloatVal.getSemantics(), tosaRoundingMode, &losesInfo); // Compute the reciprocal return computeReciprocal(apSqrtVal, floatTy); @@ -600,7 +605,7 @@ struct TosaFoldConstantRSQRT } bool isSupportedElementType(Type type) const { - return type.isBF16() || type.isF16() || type.isF32(); + return type.isBF16() || type.isF32(); } }; diff --git a/mlir/test/Dialect/Tosa/constant-rsqrt-opt.mlir b/mlir/test/Dialect/Tosa/constant-rsqrt-opt.mlir index 1088dd2541f169..f025ce6e3d1451 100644 --- a/mlir/test/Dialect/Tosa/constant-rsqrt-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-rsqrt-opt.mlir @@ -123,6 +123,16 @@ func.func @rsqrt_fold() -> tensor<4x6xf32> { return %1 : tensor<4x6xf32> } +// CHECK-LABEL: @rsqrt_fold_single_valued_bf16 +func.func @rsqrt_fold_single_valued_bf16() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.890630e-01{{.*}}tensor + // CHECK-NOT: tosa.rsqrt + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<12.0> : tensor} : () -> tensor + %1 = "tosa.rsqrt"(%0) : (tensor) -> tensor + return %1 : tensor +} + // CHECK-LABEL: @rsqrt_of_const_sparse // Sparse tensors are currently not supported func.func @rsqrt_of_const_sparse() -> tensor<32xbf16> {