Skip to content

Commit

Permalink
[MLIR][MathDialect] fix fp32 promotion crash when encounters scf.if (l…
Browse files Browse the repository at this point in the history
…lvm#104451)

1. Expand legal op list in `legalizeToF32`
2. add legalization support for `math::rsqrtOp` in `mathToLibm`.
  • Loading branch information
crazydemo authored Aug 21, 2024
1 parent 5ec73b7 commit b96f18b
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
1 change: 1 addition & 0 deletions mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ void mlir::math::populateLegalizeToF32TypeConverter(

void mlir::math::populateLegalizeToF32ConversionTarget(
ConversionTarget &target, TypeConverter &typeConverter) {
target.addDynamicallyLegalDialect<MathDialect>(
[&typeConverter](Operation *op) -> bool {
return typeConverter.isLegal(op);
});
target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
if (isa<MathDialect>(op->getDialect()))
return typeConverter.isLegal(op);
return true;
});
target.addLegalOp<FmaOp>();
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
}
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s

// CHECK-DAG: @acos(f64) -> f64 attributes {llvm.readnone}
Expand Down Expand Up @@ -58,6 +59,8 @@
// CHECK-DAG: @ceilf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @sqrt(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @sqrtf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @rsqrt(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @rsqrtf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @pow(f64, f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @powf(f32, f32) -> f32 attributes {llvm.readnone}

Expand Down Expand Up @@ -999,6 +1002,43 @@ func.func @sqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
// CHECK: }

// CHECK-LABEL: func @rsqrt_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func.func @rsqrt_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @rsqrtf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.rsqrt %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @rsqrt(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.rsqrt %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}

func.func @rsqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.rsqrt %float : vector<2xf32>
%double_result = math.rsqrt %double : vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
// CHECK-LABEL: func @rsqrt_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @rsqrtf(%[[IN0_F32]]) : (f32) -> f32
// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @rsqrtf(%[[IN1_F32]]) : (f32) -> f32
// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @rsqrt(%[[IN0_F64]]) : (f64) -> f64
// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @rsqrt(%[[IN1_F64]]) : (f64) -> f64
// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
// CHECK: }

// CHECK-LABEL: func @powf_caller(
// CHECK-SAME: %[[FLOATA:.*]]: f32, %[[FLOATB:.*]]: f32
// CHECK-SAME: %[[DOUBLEA:.*]]: f64, %[[DOUBLEB:.*]]: f64
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Math/legalize-to-f32.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,17 @@ func.func @sequences(%arg0: f16) -> f16 {
%1 = math.sin %0 : f16
return %1 : f16
}

// CHECK-LABEL: @promote_in_if_block
func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
// CHECK: [[EXTF0:%.+]] = arith.extf
// CHECK-NEXT: %[[RES:.*]] = scf.if
%0 = scf.if %arg2 -> bf16 {
%1 = math.absf %arg0 : bf16
// CHECK: [[TRUNCF0:%.+]] = arith.truncf
scf.yield %1 : bf16
} else {
scf.yield %arg1 : bf16
}
return %0 : bf16
}

0 comments on commit b96f18b

Please sign in to comment.