diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 29dbe4cab999db..f338f5718712bf 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -362,6 +362,13 @@ struct TosaFoldConstantUnaryElementwise : public TosaFoldConstantBase { "tensor has a single user"); } + TensorType opType = dyn_cast(op.getType()); + if (opType == nullptr || + !static_cast(this)->isSupportedElementType( + opType.getElementType())) { + return rewriter.notifyMatchFailure(op, "Type is not supported."); + } + DenseElementsAttr newTensor = static_cast(this)->compute( inputValues, rewriter, op); if (!newTensor) { @@ -394,6 +401,9 @@ struct TosaFoldConstantUnaryElementwise : public TosaFoldConstantBase { PatternRewriter &rewriter, TosaOp op) const { return {}; } + + /// Return true if the element type \p type is supported by the folder. + bool isSupportedElementType(Type type) const { return true; } }; template @@ -559,6 +569,10 @@ struct TosaFoldConstantRSQRT return applyElementWise( values, &computeRSQRT, cast(values.getElementType())); } + + bool isSupportedElementType(Type type) const { + return type.isBF16() || type.isF16() || type.isF32(); + } }; struct TosaFoldConstantLogicalNot @@ -618,6 +632,10 @@ struct TosaFoldConstantPow return applyElementWise(lhsValues, rhsValues, op.getType(), computePower); } + + bool isSupportedElementType(Type type) const { + return type.isBF16() || type.isF16() || type.isF32(); + } }; struct TosaFoldConstantMul @@ -1023,10 +1041,7 @@ struct TosaFoldConstantBitwiseNot DenseElementsAttr computeInteger(DenseElementsAttr values, PatternRewriter &rewriter, TosaOp op) const { return applyElementWise( - values, - [](const APInt &val, IntegerType) { - return APInt(val.getBitWidth(), ~val.getZExtValue()); - }, + values, [](const APInt &val, IntegerType) { return ~val; }, cast(values.getElementType())); } }; @@ -1041,12 +1056,8 @@ struct 
TosaFoldConstantCeil return applyElementWise( values, [](const APFloat &val, FloatType) { - // Compute ceil (APFloat unfortunately does not provide this function, - // such that we need to unpack here) - auto res = APFloat(std::ceil(val.convertToFloat())); - bool lostPrecision; - res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven, - &lostPrecision); + auto res = val; + res.roundToIntegral(llvm::RoundingMode::TowardPositive); return res; }, cast(values.getElementType())); @@ -1063,8 +1074,6 @@ struct TosaFoldConstantErf return applyElementWise( values, [](const APFloat &val, FloatType) { - // Compute ceil (APFloat unfortunately does not provide this function, - // such that we need to unpack here) auto res = APFloat(std::erf(val.convertToFloat())); bool lostPrecision; res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven, @@ -1073,6 +1082,12 @@ struct TosaFoldConstantErf }, cast(values.getElementType())); } + + bool isSupportedElementType(Type type) const { + // Note: For now, we only support BF16 and F32 as std::erf may + // have an impact on the accuracy of the returned value. 
+ return type.isBF16() || type.isF32(); + } }; } // namespace diff --git a/mlir/test/Dialect/Tosa/constant-ceil.mlir b/mlir/test/Dialect/Tosa/constant-ceil.mlir index e71815b91bddb4..0931d3f3b41ba2 100644 --- a/mlir/test/Dialect/Tosa/constant-ceil.mlir +++ b/mlir/test/Dialect/Tosa/constant-ceil.mlir @@ -30,6 +30,26 @@ func.func @ceil_fold_splat() -> tensor<12x7xf32> { return %1 : tensor<12x7xf32> } +// CHECK-LABEL: @ceil_fold_bf16 +func.func @ceil_fold_bf16() -> tensor<12x7xbf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}5.000000e+00 + // CHECK-NOT: tosa.ceil + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.2> : tensor<12x7xbf16>} : () -> tensor<12x7xbf16> + %1 = "tosa.ceil"(%0) : (tensor<12x7xbf16>) -> tensor<12x7xbf16> + return %1 : tensor<12x7xbf16> +} + +// CHECK-LABEL: @ceil_fold_f16 +func.func @ceil_fold_f16() -> tensor<12x7xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}5.000000e+00 + // CHECK-NOT: tosa.ceil + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.2> : tensor<12x7xf16>} : () -> tensor<12x7xf16> + %1 = "tosa.ceil"(%0) : (tensor<12x7xf16>) -> tensor<12x7xf16> + return %1 : tensor<12x7xf16> +} + // CHECK-LABEL: @ceil_nan func.func @ceil_nan() -> tensor { // 0x7FC00000 is the value for NAN diff --git a/mlir/test/Dialect/Tosa/constant-erf.mlir b/mlir/test/Dialect/Tosa/constant-erf.mlir index 9f53d56c08b135..e1a0ad94a96304 100644 --- a/mlir/test/Dialect/Tosa/constant-erf.mlir +++ b/mlir/test/Dialect/Tosa/constant-erf.mlir @@ -30,6 +30,16 @@ func.func @erf_fold_splat() -> tensor<12x7xf32> { return %1 : tensor<12x7xf32> } +// CHECK-LABEL: @erf_fold_bf16 +func.func @erf_fold_bf16() -> tensor<12x7xbf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}7.031250e-02 + // CHECK-NOT: tosa.erf + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0.0625> : tensor<12x7xbf16>} : () -> tensor<12x7xbf16> + %1 = "tosa.erf"(%0) : (tensor<12x7xbf16>) -> tensor<12x7xbf16> + return %1 : tensor<12x7xbf16> +} + // CHECK-LABEL: 
@erf_zero func.func @erf_zero() -> tensor { // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00 @@ -101,6 +111,16 @@ func.func @erf_no_fold(%arg0: tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @erf_no_fold_f16 +func.func @erf_no_fold_f16() -> tensor<12x7xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}6.250000e-02 + // CHECK: tosa.erf + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<6.250000e-02> : tensor<12x7xf16>} : () -> tensor<12x7xf16> + %1 = "tosa.erf"(%0) : (tensor<12x7xf16>) -> tensor<12x7xf16> + return %1 : tensor<12x7xf16> +} + // CHECK-LABEL: @erf_fold func.func @erf_fold() -> tensor<4x6xf32> { // CHECK: [[RES:]] ={{.*}}tosa.const