diff --git a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
index a54442885..1342b25f5 100644
--- a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
+++ b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
@@ -2182,6 +2182,121 @@ struct CanonicalizeBroadcastToBroadcastInDim
   }
 };
 
+template <typename Op, typename ValType>
+static Attribute ReduceConstFolder(mhlo::ReduceOp *op,
+                                   ArrayRef<SplatElementsAttr> attrs,
+                                   ValType reduceCnt) {
+  if (!attrs[0] || !attrs[1])
+    return {};
+  auto splatInput = attrs[0];
+  ShapedType type = cast<ShapedType>(op->getResults()[0].getType());
+  Type etype = type.getElementType();
+  auto signedInput = addSign(splatInput.getSplatValue<ValType>(), etype);
+  auto signedReduceCnt = addSign(reduceCnt, etype);
+  FailureOr<decltype(signedInput)> result;
+  if (std::is_same_v<Op, mhlo::AddOp>) {
+    result = FailureOr<decltype(signedInput)>(
+        std::multiplies<decltype(signedInput)>()(signedInput, signedReduceCnt));
+  } else if (std::is_same_v<Op, mhlo::MulOp>) {
+    result = FailureOr<decltype(signedInput)>(
+        Pow<decltype(signedInput)>()(signedInput, signedReduceCnt));
+  } else if (std::is_same_v<Op, mhlo::MaxOp> ||
+             std::is_same_v<Op, mhlo::MinOp>) {
+    result = FailureOr<decltype(signedInput)>(signedInput);
+  } else {
+    return {};
+  }
+  return succeeded(result) ? SplatElementsAttr::get(type, *result)
+                           : Attribute();
+}
+
+template <typename ReduceOpType>
+struct FoldReduceOp : public OpRewritePattern<mhlo::ReduceOp> {
+  using OpRewritePattern<mhlo::ReduceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (isRegularReduceOp<ReduceOpType>(op)) {
+      auto input = op.getInputs()[0].getDefiningOp<mhlo::ConstantOp>();
+      auto initValue =
+          op.getInitValues()[0].getDefiningOp<mhlo::ConstantOp>();
+      if (!input || !initValue) {
+        return failure();
+      }
+      // Only covers the case of both attrs being splats
+      SplatElementsAttr splatInput =
+          dyn_cast<SplatElementsAttr>(input.getValue());
+      SplatElementsAttr splatInitValue =
+          dyn_cast<SplatElementsAttr>(initValue.getValue());
+      auto type = dyn_cast<ShapedType>(op.getResults()[0].getType());
+      if (!splatInput || !splatInitValue || !type || !type.hasStaticShape()) {
+        return failure();
+      }
+      auto inputShape = cast<ShapedType>(splatInput.getType()).getShape();
+      auto reduceDims =
+          llvm::to_vector(op.getDimensions().getValues<int64_t>());
+      if (reduceDims.empty()) {
+        return failure();
+      }
+      int64_t reduceCntInt = 1;
+      for (const auto &dim : reduceDims) {
+        reduceCntInt *= inputShape[dim];
+      }
+      Attribute result;
+      if (isa<FloatType>(type.getElementType())) {
+        APFloat reduceCnt(static_cast<double>(reduceCntInt));
+        bool losesInfo;
+        auto status = reduceCnt.convert(
+            dyn_cast<FloatType>(type.getElementType()).getFloatSemantics(),
+            APFloat::rmNearestTiesToEven, &losesInfo);
+        if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
+          op->emitWarning() << "Could not convert reduceCnt to target fp "
+                               "type: opStatus = "
+                            << static_cast<int>(status);
+          return failure();
+        }
+        result = ReduceConstFolder<ReduceOpType, APFloat>(
+            &op, ArrayRef<SplatElementsAttr>{splatInput, splatInitValue},
+            reduceCnt);
+      } else if (isa<IntegerType>(type.getElementType())) {
+        APInt reduceCnt(splatInput.getSplatValue<APInt>().getBitWidth(),
+                        static_cast<uint64_t>(reduceCntInt));
+        result = ReduceConstFolder<ReduceOpType, APInt>(
+            &op, ArrayRef<SplatElementsAttr>{splatInput, splatInitValue},
+            reduceCnt);
+      }
+      if (!result) {
+        return failure();
+      }
+      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, result);
+      return success();
+    }
+    return failure();
+  }
+};
+
+struct DotGeneralZero : public OpRewritePattern<mhlo::DotGeneralOp> {
+  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhs = op.getLhs().getDefiningOp<mhlo::ConstantOp>();
+    auto rhs = op.getRhs().getDefiningOp<mhlo::ConstantOp>();
+    auto type = cast<ShapedType>(op.getType());
+    if (lhs && isZeroAttribute(lhs.getValue())) {
+      auto resizeSplat =
+          cast<SplatElementsAttr>(lhs.getValue()).resizeSplat(type);
+      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resizeSplat);
+      return success();
+    }
+    if (rhs && isZeroAttribute(rhs.getValue())) {
+      auto resizeSplat =
+          cast<SplatElementsAttr>(rhs.getValue()).resizeSplat(type);
+      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resizeSplat);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void mlir::mhlo::populateFoldMultiplyZeroPattern(RewritePatternSet &patterns) {
@@ -2201,6 +2316,11 @@ void mlir::mhlo::populateFoldLargeBinaryOpPatterns(
   patterns.add<FoldLargeBinaryOp<mhlo::PowOp, Pow>>(ctx);
   patterns.add<FoldLargeCompareOp>(ctx);
   patterns.add<FoldClampOp>(ctx);
+  patterns.add<FoldReduceOp<mhlo::AddOp>>(ctx);
+  patterns.add<FoldReduceOp<mhlo::MulOp>>(ctx);
+  patterns.add<FoldReduceOp<mhlo::MinOp>>(ctx);
+  patterns.add<FoldReduceOp<mhlo::MaxOp>>(ctx);
+  patterns.add<DotGeneralZero>(ctx);
 }
 
 void mlir::mhlo::populateConvertOpPattern(RewritePatternSet &patterns,
diff --git a/compiler/test/Transforms/CanonicalizeExt/reduce_const.mlir b/compiler/test/Transforms/CanonicalizeExt/reduce_const.mlir
new file mode 100644
index 000000000..8c60ccd8a
--- /dev/null
+++ b/compiler/test/Transforms/CanonicalizeExt/reduce_const.mlir
@@ -0,0 +1,71 @@
+// RUN: byteir-opt %s -canonicalize-ext | FileCheck %s
+
+func.func private @fold_reduce_add_f() -> tensor<16xf32> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = mhlo.constant dense<5.000000e+00> : tensor<1024x16x512xf32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0, 2] : (tensor<1024x16x512xf32>, tensor<f32>) -> tensor<16xf32>
+  return %2 : tensor<16xf32>
+}
+// CHECK-LABEL: fold_reduce_add_f
+// CHECK: mhlo.constant dense<2.621440e+06>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_add_d() -> tensor<16xf64> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f64>
+  %1 = mhlo.constant dense<5.000000e+00> : tensor<1024x16x512xf64>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0, 2] : (tensor<1024x16x512xf64>, tensor<f64>) -> tensor<16xf64>
+  return %2 : tensor<16xf64>
+}
+// CHECK-LABEL: fold_reduce_add_d
+// CHECK: mhlo.constant dense<2.621440e+06>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_add_i() -> tensor<16xi32> {
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  %1 = mhlo.constant dense<5> : tensor<16x1024xi32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [1] : (tensor<16x1024xi32>, tensor<i32>) -> tensor<16xi32>
+  return %2 : tensor<16xi32>
+}
+// CHECK-LABEL: fold_reduce_add_i
+// CHECK: mhlo.constant dense<5120>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_mul_f() -> tensor<16xf32> {
+  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  %1 = mhlo.constant dense<2.000000e+00> : tensor<16x2x4xf32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.multiply across dimensions = [1, 2] : (tensor<16x2x4xf32>, tensor<f32>) -> tensor<16xf32>
+  return %2 : tensor<16xf32>
+}
+// CHECK-LABEL: fold_reduce_mul_f
+// CHECK: mhlo.constant dense<2.560000e+02>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_mul_i() -> tensor<16xi32> {
+  %0 = mhlo.constant dense<1> : tensor<i32>
+  %1 = mhlo.constant dense<2> : tensor<16x16xi32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.multiply across dimensions = [0] : (tensor<16x16xi32>, tensor<i32>) -> tensor<16xi32>
+  return %2 : tensor<16xi32>
+}
+// CHECK-LABEL: fold_reduce_mul_i
+// CHECK: mhlo.constant dense<65536>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_min_f() -> tensor<16xf32> {
+  %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
+  %1 = mhlo.constant dense<5.000000e+00> : tensor<1024x16xf32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.minimum across dimensions = [0] : (tensor<1024x16xf32>, tensor<f32>) -> tensor<16xf32>
+  return %2 : tensor<16xf32>
+}
+// CHECK-LABEL: fold_reduce_min_f
+// CHECK: mhlo.constant dense<5.000000e+00>
+// CHECK-NOT: mhlo.reduce
+
+func.func private @fold_reduce_max_i() -> tensor<16xi32> {
+  %0 = mhlo.constant dense<-2147483648> : tensor<i32>
+  %1 = mhlo.constant dense<5> : tensor<1024x512x16xi32>
+  %2 = mhlo.reduce(%1 init: %0) applies mhlo.maximum across dimensions = [0, 1] : (tensor<1024x512x16xi32>, tensor<i32>) -> tensor<16xi32>
+  return %2 : tensor<16xi32>
+}
+// CHECK-LABEL: fold_reduce_max_i
+// CHECK: mhlo.constant dense<5>
+// CHECK-NOT: mhlo.reduce
\ No newline at end of file
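
Note: the new test file exercises the four FoldReduceOp instantiations but not DotGeneralZero. The sketch below illustrates the rewrite that pattern performs; it is not part of the patch, and the function name, shapes, and dimension numbers are invented for the illustration. Since DotGeneralZero matches a mhlo.dot_general whose lhs or rhs is an all-zero splat constant, canonicalize-ext should replace the whole op with a zero splat resized to the result type:

// Hypothetical test, not included in this change.
func.func private @fold_dot_general_zero(%arg0: tensor<4x8xf32>) -> tensor<4x16xf32> {
  // The rhs is an all-zero splat constant, so the product is all zeros.
  %0 = mhlo.constant dense<0.000000e+00> : tensor<8x16xf32>
  %1 = "mhlo.dot_general"(%arg0, %0) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>
  } : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<4x16xf32>
  return %1 : tensor<4x16xf32>
}
// Expected after -canonicalize-ext: the dot_general folds away entirely.
// CHECK-LABEL: fold_dot_general_zero
// CHECK: mhlo.constant dense<0.000000e+00> : tensor<4x16xf32>
// CHECK-NOT: mhlo.dot_general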