Skip to content

Commit

Permalink
Fix stream sink op folder with users captured in nested regions (iree…
Browse files Browse the repository at this point in the history
…-org#16363)

Fix the sink op folder to check users in regions and not sink below
those users.

Fix iree-org#16320
  • Loading branch information
Jerry Wu authored Feb 12, 2024
1 parent 14927d1 commit 246edee
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
17 changes: 13 additions & 4 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,12 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) {
// that if `targetOp` is not a terminator, then we can prune the set of
// sinkable ops that might fight with `toBeSunkOp` more aggressively by using
// use-def chains.
bool allowUseDefPruning = !targetOp->hasTrait<mlir::OpTrait::IsTerminator>();
// The use-def chains check below doesn't detect implicit captures (which can
// be heavy to check) so we also ignore `targetOp` with regions. This can be
// relexed if needed.
bool allowUseDefPruning =
!targetOp->hasTrait<mlir::OpTrait::IsTerminator>() &&
targetOp->getNumRegions() == 0;

// If the sinking operation would be a no-op, then we need to prevent
// the sinking operation, to avoid infinite pattern applications.
Expand Down Expand Up @@ -1406,9 +1411,13 @@ struct SinkAllocaLikeOpToConsumers : public OpRewritePattern<Op> {
// can sink down to it.
Operation *firstUserInDominator = commonDominator->getTerminator();
for (auto user : users) {
if (user->getBlock() == commonDominator) {
if (user->isBeforeInBlock(firstUserInDominator)) {
firstUserInDominator = user;
for (auto ancestor = user; ancestor != commonDominator->getParentOp();
ancestor = ancestor->getParentOp()) {
if (ancestor->getBlock() == commonDominator) {
if (ancestor->isBeforeInBlock(firstUserInDominator)) {
firstUserInDominator = ancestor;
}
break;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,50 @@ func.func @SinkSplatsToConsumers(

// -----

// CHECK-LABEL: @SinkSplatsToCommonAncestorOfConsumersInRegions
func.func @SinkSplatsToCommonAncestorOfConsumersInRegions(%arg0: i1) -> (!stream.resource<*>, !stream.resource<*>) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
%c2 = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
%c3 = arith.constant 3 : index
// CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index
%c100 = arith.constant 100 : index
// CHECK-DAG: %[[C123:.+]] = arith.constant 123 : i32
%c123_i32 = arith.constant 123 : i32
// CHECK-DAG: %[[C456:.+]] = arith.constant 456 : i32
%c456_i32 = arith.constant 456 : i32
// CHECK-DAG: %[[C789:.+]] = arith.constant 789 : i32
%c789_i32 = arith.constant 789 : i32
// CHECK-NOT: stream.async.splat %[[C123]]
// CHECK-NOT: stream.async.splat %[[C456]]
%0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c100}
%1 = stream.async.splat %c456_i32 : i32 -> !stream.resource<*>{%c100}
// CHECK: %[[SPLAT3:.+]] = stream.async.splat %[[C789]]
%2 = stream.async.splat %c789_i32 : i32 -> !stream.resource<*>{%c100}
// CHECK: stream.async.dispatch @executable::@dispatch2[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT3]][%[[C0]] to %[[C100]] for %[[C100]]])
%3 = stream.async.dispatch @executable::@dispatch2[%c1, %c2, %c3](%2[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
// CHECK-DAG: %[[SPLAT1:.+]] = stream.async.splat %[[C123]]
// CHECK-DAG: %[[SPLAT2:.+]] = stream.async.splat %[[C456]]
// CHECK-NEXT: scf.if
%4 = scf.if %arg0 -> (!stream.resource<*>) {
// CHECK: stream.async.dispatch @executable::@dispatch0[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT1]][%[[C0]] to %[[C100]] for %[[C100]]], %[[SPLAT2]][%[[C0]] to %[[C100]] for %[[C100]]])
%5 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
scf.yield %5 : !stream.resource<*>
// CHECK: else
} else {
// CHECK: stream.async.dispatch @executable::@dispatch1[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT1]][%[[C0]] to %[[C100]] for %[[C100]]], %[[SPLAT2]][%[[C0]] to %[[C100]] for %[[C100]]])
%6 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
scf.yield %6 : !stream.resource<*>
}
return %4, %3 : !stream.resource<*>, !stream.resource<*>
}

// -----

// CHECK-LABEL: @SplatAlreadyAtSinkLocation
func.func @SplatAlreadyAtSinkLocation(
%arg0: i1, %arg1: i1,
Expand Down

0 comments on commit 246edee

Please sign in to comment.