From 246edee03074f6f8126aa3f6ae7d815f32f245d9 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Mon, 12 Feb 2024 14:51:03 -0800 Subject: [PATCH] Fix stream sink op folder with users captured in nested regions (#16363) Fix the sink op folder to check users in regions and not sink below those users. Fix #16320 --- .../Dialect/Stream/IR/StreamOpFolders.cpp | 17 +++++-- .../Dialect/Stream/IR/test/async_folding.mlir | 44 +++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 81b9be9e2c40..4cb1080149ce 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -234,7 +234,12 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { // that if `targetOp` is not a terminator, then we can prune the set of // sinkable ops that might fight with `toBeSunkOp` more aggressively by using // use-def chains. - bool allowUseDefPruning = !targetOp->hasTrait(); + // The use-def chains check below doesn't detect implicit captures (which can + // be heavy to check) so we also ignore `targetOp` with regions. This can be + // relaxed if needed. + bool allowUseDefPruning = + !targetOp->hasTrait() && + targetOp->getNumRegions() == 0; // If the sinking operation would be a no-op, then we need to prevent // the sinking operation, to avoid infinite pattern applications. @@ -1406,9 +1411,13 @@ struct SinkAllocaLikeOpToConsumers : public OpRewritePattern { // can sink down to it. 
Operation *firstUserInDominator = commonDominator->getTerminator(); for (auto user : users) { - if (user->getBlock() == commonDominator) { - if (user->isBeforeInBlock(firstUserInDominator)) { - firstUserInDominator = user; + for (auto ancestor = user; ancestor != commonDominator->getParentOp(); + ancestor = ancestor->getParentOp()) { + if (ancestor->getBlock() == commonDominator) { + if (ancestor->isBeforeInBlock(firstUserInDominator)) { + firstUserInDominator = ancestor; + } + break; } } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir index a5663b212c2b..de8a0f1c4c01 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir @@ -42,6 +42,50 @@ func.func @SinkSplatsToConsumers( // ----- +// CHECK-LABEL: @SinkSplatsToCommonAncestorOfConsumersInRegions +func.func @SinkSplatsToCommonAncestorOfConsumersInRegions(%arg0: i1) -> (!stream.resource<*>, !stream.resource<*>) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + %c2 = arith.constant 2 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + %c3 = arith.constant 3 : index + // CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index + %c100 = arith.constant 100 : index + // CHECK-DAG: %[[C123:.+]] = arith.constant 123 : i32 + %c123_i32 = arith.constant 123 : i32 + // CHECK-DAG: %[[C456:.+]] = arith.constant 456 : i32 + %c456_i32 = arith.constant 456 : i32 + // CHECK-DAG: %[[C789:.+]] = arith.constant 789 : i32 + %c789_i32 = arith.constant 789 : i32 + // CHECK-NOT: stream.async.splat %[[C123]] + // CHECK-NOT: stream.async.splat %[[C456]] + %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c100} + %1 = stream.async.splat 
%c456_i32 : i32 -> !stream.resource<*>{%c100} + // CHECK: %[[SPLAT3:.+]] = stream.async.splat %[[C789]] + %2 = stream.async.splat %c789_i32 : i32 -> !stream.resource<*>{%c100} + // CHECK: stream.async.dispatch @executable::@dispatch2[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT3]][%[[C0]] to %[[C100]] for %[[C100]]]) + %3 = stream.async.dispatch @executable::@dispatch2[%c1, %c2, %c3](%2[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100} + // CHECK-DAG: %[[SPLAT1:.+]] = stream.async.splat %[[C123]] + // CHECK-DAG: %[[SPLAT2:.+]] = stream.async.splat %[[C456]] + // CHECK-NEXT: scf.if + %4 = scf.if %arg0 -> (!stream.resource<*>) { + // CHECK: stream.async.dispatch @executable::@dispatch0[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT1]][%[[C0]] to %[[C100]] for %[[C100]]], %[[SPLAT2]][%[[C0]] to %[[C100]] for %[[C100]]]) + %5 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c100}) -> !stream.resource<*>{%c100} + scf.yield %5 : !stream.resource<*> + // CHECK: else + } else { + // CHECK: stream.async.dispatch @executable::@dispatch1[%[[C1]], %[[C2]], %[[C3]]](%[[SPLAT1]][%[[C0]] to %[[C100]] for %[[C100]]], %[[SPLAT2]][%[[C0]] to %[[C100]] for %[[C100]]]) + %6 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c100}) -> !stream.resource<*>{%c100} + scf.yield %6 : !stream.resource<*> + } + return %4, %3 : !stream.resource<*>, !stream.resource<*> +} + +// ----- + // CHECK-LABEL: @SplatAlreadyAtSinkLocation func.func @SplatAlreadyAtSinkLocation( %arg0: i1, %arg1: i1,