Skip to content

Commit

Permalink
[compiler] do not fusion single reshape in aggressive fusion (#104)
Browse files Browse the repository at this point in the history
* leave single reshape on host so it could be lowered to `byre.alias`.
  • Loading branch information
qingyunqu authored Jan 25, 2024
1 parent b60d7a6 commit 0aa0c8a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
7 changes: 6 additions & 1 deletion compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ bool isFusibleTrigger(Operation *) { return true; }

bool isFusibleWith(Operation *, Operation *) { return true; }

bool isValidSingleOp(Operation *) { return true; }
bool isValidSingleOp(Operation *op) {
if (llvm::isa<mhlo::ReshapeOp>(op))
return false;
else
return true;
}

bool isValidFusionPattern(const MhloFusionPattern &) { return true; }

Expand Down
25 changes: 24 additions & 1 deletion compiler/test/Dialect/Mhlo/transforms/aggressiveFusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ func.func @mhlo_aggressive_fusion(%arg0 : tensor<32x32xf32>, %arg1 : tensor<32xi
%2 = "mhlo.add"(%1, %arg2) : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
return %2 : tensor<32x32xf32>
}

// CHECK-LABEL: func.func @mhlo_aggressive_fusion
// CHECK-NEXT: mhlo.fusion
// CHECK-NEXT: mhlo.add
Expand All @@ -15,3 +14,27 @@ func.func @mhlo_aggressive_fusion(%arg0 : tensor<32x32xf32>, %arg1 : tensor<32xi
// CHECK-NEXT: mhlo.return
// CHECK: {__byteir_hlo_aggressive_fusion__}
// CHECK: return


func.func @reshape_add(%arg0: tensor<2xf32>, %arg1: tensor<2x1xf32>) -> (tensor<2x1xf32>) {
%0 = mhlo.reshape %arg0 : (tensor<2xf32>) -> tensor<2x1xf32>
%1 = mhlo.add %0, %arg1 : tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
// CHECK-LABEL: func.func @reshape_add
// CHECK-NEXT: mhlo.fusion
// CHECK-NEXT: mhlo.reshape
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
// CHECK: {__byteir_hlo_aggressive_fusion__}
// CHECK: return


func.func @single_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
%0 = mhlo.reshape %arg0 : (tensor<2xf32>) -> tensor<2x1xf32>
return %0 : tensor<2x1xf32>
}
// CHECK-LABEL: func.func @single_reshape
// CHECK-NOT: mhlo.fusion
// CHECK: mhlo.reshape
// CHECK: return

0 comments on commit 0aa0c8a

Please sign in to comment.