[compiler] fold mhlo.reduce with splat input
jianwenyyy committed Jun 24, 2024
1 parent d5be436 commit 7fe71ed
Showing 2 changed files with 171 additions and 0 deletions.
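This change folds an mhlo.reduce whose input and init value are both splat constants into a single splat constant. Writing x for the splat input value and N for the product of the reduced dimension sizes, the new pattern computes (assuming the init value is the reduction's identity, which the existing isRegularReduceOp check is expected to guarantee):

  add-reduce:      x * N
  mul-reduce:      x ^ N
  max/min-reduce:  x

A short sanity-check sketch of these identities follows the new ReduceConstFolder helper below.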
110 changes: 110 additions & 0 deletions compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
@@ -2182,6 +2182,111 @@ struct CanonicalizeBroadcastToBroadcastInDim
}
};

// Folds a regular mhlo.reduce over splat operands into a splat constant:
// an add-reduce becomes input * reduceCnt, a mul-reduce becomes
// input ^ reduceCnt, and a max/min-reduce of identical elements is the
// element itself.
template <typename Op, typename ElementType = Type, typename ValType,
          typename FuncType>
static Attribute ReduceConstFolder(mhlo::ReduceOp *op,
                                   ArrayRef<SplatElementsAttr> attrs,
                                   ValType reduceCnt) {
  if (!attrs[0] || !attrs[1])
    return {};
  auto splatInput = attrs[0];
  ShapedType type = cast<ShapedType>(op->getResults()[0].getType());
  Type etype = type.getElementType();
  auto signedInput = addSign(splatInput.getSplatValue<ValType>(), etype);
  auto signedReduceCnt = addSign(reduceCnt, etype);
  FailureOr<decltype(signedInput)> result;
  if (std::is_same_v<Op, mhlo::AddOp>) {
    result = FailureOr<decltype(signedInput)>(
        std::multiplies<FuncType>()(signedInput, signedReduceCnt));
  } else if (std::is_same_v<Op, mhlo::MulOp>) {
    result = FailureOr<decltype(signedInput)>(
        Pow<FuncType>()(signedInput, signedReduceCnt));
  } else if (std::is_same_v<Op, mhlo::MaxOp> ||
             std::is_same_v<Op, mhlo::MinOp>) {
    result = FailureOr<decltype(signedInput)>(signedInput);
  } else {
    return {};
  }
  return succeeded(result) ? SplatElementsAttr::get(type, *result)
                           : Attribute();
}
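As a quick sanity check on the three cases above, here is a minimal standalone scalar model of the fold — a sketch using hypothetical names (foldSplatReduce is not part of this patch) — whose outputs match the FileCheck expectations in the new test file:

#include <cmath>
#include <cstdio>

// Scalar model of folding a reduce over n copies of the splat value x.
double foldSplatReduce(char kind, double x, long n) {
  switch (kind) {
  case '+': return x * n;          // add-reduce: sum of n copies of x
  case '*': return std::pow(x, n); // mul-reduce: product of n copies of x
  case '<':                        // min-reduce: all elements equal x
  case '>': return x;              // max-reduce: likewise
  default:  return 0.0;            // unsupported reduction kind
  }
}

int main() {
  // Matches @fold_reduce_add_f: 5.0 summed over 1024 * 512 elements.
  std::printf("%g\n", foldSplatReduce('+', 5.0, 1024L * 512)); // 2.62144e+06
  // Matches @fold_reduce_mul_f: 2.0 multiplied over 2 * 4 elements.
  std::printf("%g\n", foldSplatReduce('*', 2.0, 8));           // 256
  return 0;
}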

template <typename RegionOp>
struct FoldReduceOp : public OpRewritePattern<mhlo::ReduceOp> {
  using OpRewritePattern<mhlo::ReduceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (isRegularReduceOp<RegionOp>(op)) {
      auto input = op.getInputs()[0].getDefiningOp<mhlo::ConstantOp>();
      auto initValue = op.getInitValues()[0].getDefiningOp<mhlo::ConstantOp>();
      if (!input || !initValue) {
        return failure();
      }
      // Only handle the case where both attributes are splats.
      SplatElementsAttr splatInput =
          dyn_cast<SplatElementsAttr>(input.getValue());
      SplatElementsAttr splatInitValue =
          dyn_cast<SplatElementsAttr>(initValue.getValue());
      auto type = cast<ShapedType>(op.getResults()[0].getType());
      if (!splatInput || !splatInitValue || !type || !type.hasStaticShape()) {
        return failure();
      }
      auto inputShape = cast<ShapedType>(splatInput.getType()).getShape();
      auto reduceDims =
          llvm::to_vector(op.getDimensions().getValues<int64_t>());
      if (reduceDims.empty()) {
        return failure();
      }
      // The number of elements folded into each output element is the
      // product of the reduced dimension sizes.
      int64_t reduceCntInt = 1;
      for (const auto &dim : reduceDims) {
        reduceCntInt *= inputShape[dim];
      }
      Attribute result;
      if (isa<FloatType>(type.getElementType())) {
        APFloat reduceCnt(static_cast<float>(reduceCntInt));
        result = ReduceConstFolder<RegionOp, FloatType, APFloat, APFloat>(
            &op, ArrayRef<SplatElementsAttr>{splatInput, splatInitValue},
            reduceCnt);
      } else if (isa<IntegerType>(type.getElementType())) {
        APInt reduceCnt(splatInput.getSplatValue<APInt>().getBitWidth(),
                        static_cast<uint64_t>(reduceCntInt));
        result = ReduceConstFolder<RegionOp, IntegerType, APInt, APSInt>(
            &op, ArrayRef<SplatElementsAttr>{splatInput, splatInitValue},
            reduceCnt);
      }
      if (!result) {
        return failure();
      }
      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, result);
      return success();
    }
    return failure();
  }
};

// A dot_general with an all-zero operand always produces an all-zero result,
// so replace it with a zero constant of the result type.
struct DotGeneralZero : public OpRewritePattern<mhlo::DotGeneralOp> {
  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.getLhs().getDefiningOp<mhlo::ConstantOp>();
    auto rhs = op.getRhs().getDefiningOp<mhlo::ConstantOp>();
    auto type = cast<ShapedType>(op.getType());
    if (lhs && isZeroAttribute(lhs.getValue())) {
      auto resizeSplat =
          cast<SplatElementsAttr>(lhs.getValue()).resizeSplat(type);
      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resizeSplat);
      return success();
    }
    if (rhs && isZeroAttribute(rhs.getValue())) {
      auto resizeSplat =
          cast<SplatElementsAttr>(rhs.getValue()).resizeSplat(type);
      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resizeSplat);
      return success();
    }
    return failure();
  }
};

} // namespace

void mlir::mhlo::populateFoldMultiplyZeroPattern(RewritePatternSet &patterns) {
@@ -2201,6 +2306,11 @@ void mlir::mhlo::populateFoldLargeBinaryOpPatterns(
  patterns.add<FoldLargeBinaryOp<mhlo::PowOp, Pow>>(ctx);
  patterns.add<FoldLargeCompareOp>(ctx);
  patterns.add<FoldClampOp>(ctx);
  patterns.add<FoldReduceOp<mhlo::AddOp>>(ctx);
  patterns.add<FoldReduceOp<mhlo::MulOp>>(ctx);
  patterns.add<FoldReduceOp<mhlo::MaxOp>>(ctx);
  patterns.add<FoldReduceOp<mhlo::MinOp>>(ctx);
  patterns.add<DotGeneralZero>(ctx);
}
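The new patterns are registered alongside the existing large-binary-op folds, and the new test file below exercises them through byteir-opt -canonicalize-ext.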

void mlir::mhlo::populateConvertOpPattern(RewritePatternSet &patterns,
61 changes: 61 additions & 0 deletions compiler/test/Transforms/CanonicalizeExt/reduce_const.mlir
@@ -0,0 +1,61 @@
// RUN: byteir-opt %s -canonicalize-ext | FileCheck %s

func.func private @fold_reduce_add_f() -> tensor<16xf32> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = mhlo.constant dense<5.000000e+00> : tensor<1024x16x512xf32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0, 2] : (tensor<1024x16x512xf32>, tensor<f32>) -> tensor<16xf32>
  return %2 : tensor<16xf32>
}
// CHECK-LABEL: fold_reduce_add_f
// CHECK: mhlo.constant dense<2.621440e+06>
// CHECK-NOT: mhlo.reduce

func.func private @fold_reduce_add_i() -> tensor<16xi32> {
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<5> : tensor<16x1024xi32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [1] : (tensor<16x1024xi32>, tensor<i32>) -> tensor<16xi32>
  return %2 : tensor<16xi32>
}
// CHECK-LABEL: fold_reduce_add_i
// CHECK: mhlo.constant dense<5120>
// CHECK-NOT: mhlo.reduce

func.func private @fold_reduce_mul_f() -> tensor<16xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %1 = mhlo.constant dense<2.000000e+00> : tensor<16x2x4xf32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.multiply across dimensions = [1, 2] : (tensor<16x2x4xf32>, tensor<f32>) -> tensor<16xf32>
  return %2 : tensor<16xf32>
}
// CHECK-LABEL: fold_reduce_mul_f
// CHECK: mhlo.constant dense<2.560000e+02>
// CHECK-NOT: mhlo.reduce

func.func private @fold_reduce_mul_i() -> tensor<16xi32> {
  %0 = mhlo.constant dense<1> : tensor<i32>
  %1 = mhlo.constant dense<2> : tensor<16x16xi32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.multiply across dimensions = [0] : (tensor<16x16xi32>, tensor<i32>) -> tensor<16xi32>
  return %2 : tensor<16xi32>
}
// CHECK-LABEL: fold_reduce_mul_i
// CHECK: mhlo.constant dense<65536>
// CHECK-NOT: mhlo.reduce

func.func private @fold_reduce_min_f() -> tensor<16xf32> {
  %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
  %1 = mhlo.constant dense<5.000000e+00> : tensor<1024x16xf32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.minimum across dimensions = [0] : (tensor<1024x16xf32>, tensor<f32>) -> tensor<16xf32>
  return %2 : tensor<16xf32>
}
// CHECK-LABEL: fold_reduce_min_f
// CHECK: mhlo.constant dense<5.000000e+00>
// CHECK-NOT: mhlo.reduce

func.func private @fold_reduce_max_i() -> tensor<16xi32> {
  %0 = mhlo.constant dense<-2147483648> : tensor<i32>
  %1 = mhlo.constant dense<5> : tensor<1024x512x16xi32>
  %2 = mhlo.reduce(%1 init: %0) applies mhlo.maximum across dimensions = [0, 1] : (tensor<1024x512x16xi32>, tensor<i32>) -> tensor<16xi32>
  return %2 : tensor<16xi32>
}
// CHECK-LABEL: fold_reduce_max_i
// CHECK: mhlo.constant dense<5>
// CHECK-NOT: mhlo.reduce
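For reference, the expected constants follow directly from the fold identities: 5 * (1024 * 512) = 2621440 (2.621440e+06), 5 * 1024 = 5120, 2^(2*4) = 256 (2.560000e+02), 2^16 = 65536, and the min/max cases simply return the splat value 5.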
