From 137e36507aeeefa0aecf12bbd3fbaf2defbfe376 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 20 Aug 2024 13:23:08 -0400 Subject: [PATCH] [Codegen][GPU] Add pass to annotate memory spaces on allocations (#18251) Trying to infer the memory space of an allocation from within the bufferization alloc callback function is too late. This adds a rudimentary pass to annotate the memory space in obvious situations and then disallows all cases of a bufferization allocation without an already pre-determined memory space (for the LLVMGPUTileAndFuse pipeline). This gives us correctness guarantees that were somewhat hand wavy before. This makes all allocations that aren't marked explicitly as shared (or can be obviously inferred as shared) as thread local. Any previous lowerings that violate this invariant is a bug (most likely from a failure to tile an operation). --- .../compiler/Codegen/Common/GPU/BUILD.bazel | 1 + .../Codegen/Common/GPU/CMakeLists.txt | 1 + .../Common/GPU/GPUInferMemorySpace.cpp | 106 ++++++++++++++++++ .../Common/GPU/GPUVerifyDistribution.cpp | 23 +--- .../compiler/Codegen/Common/GPU/Passes.td | 8 ++ .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/CMakeLists.txt | 1 + .../GPU/test/gpu_infer_memory_space.mlir | 54 +++++++++ .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 1 + .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 55 +++++---- .../test/ROCDL/pipeline_tile_and_fuse.mlir | 1 + .../iree/compiler/Codegen/Utils/GPUUtils.h | 26 +++++ 13 files changed, 238 insertions(+), 41 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 4a0b879b94ee..64d0412bf76d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -58,6 +58,7 @@ iree_compiler_cc_library( "GPUDistributeSharedMemoryCopy.cpp", "GPUDistributionPatterns.cpp", "GPUGeneralizeNamedOps.cpp", + "GPUInferMemorySpace.cpp", "GPULowerToUKernels.cpp", "GPUMultiBuffering.cpp", "GPUNestedLayoutDistributionPatterns.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index eb51b3ec2408..5a376fdcc5e3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -56,6 +56,7 @@ iree_cc_library( "GPUDistributeSharedMemoryCopy.cpp" "GPUDistributionPatterns.cpp" "GPUGeneralizeNamedOps.cpp" + "GPUInferMemorySpace.cpp" "GPULowerToUKernels.cpp" "GPUMultiBuffering.cpp" "GPUNestedLayoutDistributionPatterns.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp new file mode 100644 index 000000000000..64fed3b038c9 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp @@ -0,0 +1,106 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/GPU/Passes.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_GPUINFERMEMORYSPACEPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +/// Pass to infer the memory spaces of unmarked `bufferization.alloc_tensor` +/// ops. Inferring the memory space during bufferization (in the allocation +/// function) is infeasible due to some limited analysis of surrounding loop +/// structures needed. After this pass, any unexpected allocations are then +/// treated as a compiler failure indicating something went wrong during +/// bufferization. +struct GPUInferMemorySpacePass final + : impl::GPUInferMemorySpacePassBase { + + void runOnOperation() override; +}; + +bool isDefinitelyShared(bufferization::AllocTensorOp alloc) { + // An allocation can be inferred as shared if it is the destination of a + // thread distributed `scf.forall` op. All other shared allocations are + // expected to be properly indicated in advance. + for (auto user : alloc->getUsers()) { + auto forallOp = dyn_cast(user); + if (!forallOp || + !forallOpHasMappingType(forallOp)) { + return false; + } + } + return true; +} + +void GPUInferMemorySpacePass::runOnOperation() { + MLIRContext *context = &getContext(); + FunctionOpInterface funcOp = getOperation(); + + gpu::AddressSpaceAttr privateAddressSpace = gpu::AddressSpaceAttr::get( + context, gpu::GPUDialect::getPrivateAddressSpace()); + gpu::AddressSpaceAttr sharedAddressSpace = gpu::AddressSpaceAttr::get( + context, gpu::GPUDialect::getWorkgroupAddressSpace()); + + WalkResult res = funcOp.walk([&](bufferization::AllocTensorOp alloc) { + // Continue if the allocation already has a valid memory space. + std::optional currentMemSpace = alloc.getMemorySpace(); + if (currentMemSpace.has_value()) { + if (currentMemSpace.value() == privateAddressSpace || + currentMemSpace.value() == sharedAddressSpace) { + return WalkResult::advance(); + } + alloc.emitOpError( + "unexpected gpu memory space must be private or workgroup."); + return WalkResult::interrupt(); + } + + /// Determining GPU memory spaces must be trivial by the time of this pass. + /// Because this pass runs immediately before bufferization, input IR is + /// expected to mix (thread) distributed and shared contexts. Because after + /// bufferization distributed loops (scf.forall) ops are expected to be + /// inlined as-is with no further tiling occurring, all tensors at this + /// point in the IR are assumed to be thread-local unless it is explicitly + /// marked as shared. This gives the following invariants: + /// + /// 1. If the alloc_tensor is annotated with `#gpu.address_space` + /// already, or if it is used as the immediate destination of a thread + /// or warp distributed `scf.forall` op, then the allocation must be + /// shared memory. + /// 2. All other allocations are thread local. + /// + /// Any allocation that is not explicitly marked as shared memory that is + /// supposed to be indicates a bug in earlier passes/lowerings. + if (isDefinitelyShared(alloc)) { + alloc.setMemorySpaceAttr(sharedAddressSpace); + } else { + alloc.setMemorySpaceAttr(privateAddressSpace); + } + return WalkResult::advance(); + }); + + if (res.wasInterrupted()) { + funcOp->emitOpError("failed to set the gpu memory space for all " + "`bufferization.alloc_tensor` ops"); + return signalPassFailure(); + } +} + +} // namespace + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp index 273cadfff5a8..fe4738034b3c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -19,28 +20,6 @@ namespace mlir::iree_compiler { namespace { -template -bool forallOpHasMappingType(scf::ForallOp forallOp) { - std::optional mapping = forallOp.getMapping(); - if (!mapping || mapping.value().empty()) { - return false; - } - - return isa(*mapping.value().begin()); -} - -template -bool operationHasParentForallOfMappingType(Operation *op) { - auto parentForallOp = op->getParentOfType(); - while (parentForallOp) { - if (forallOpHasMappingType(parentForallOp)) { - return true; - } - parentForallOp = parentForallOp->getParentOfType(); - } - return false; -} - /// Pass to verify that writes only happen in distributed contexts. Code in /// shared contexts are executed uniformly across all threads after resolution /// of distributed contexts (i.e. scf.forall), thus operations with write diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index cec8ba43a030..f02205aae0ac 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -59,6 +59,14 @@ def GPUGeneralizeNamedOpsPass : let summary = "Convert named Linalg ops to linalg.generic ops"; } +def GPUInferMemorySpacePass : + InterfacePass<"iree-codegen-gpu-infer-memory-space", "mlir::FunctionOpInterface"> { + let summary = "Pass to infer and set the memory space for all alloc_tensor ops."; + let dependentDialects = [ + "::mlir::gpu::GPUDialect" + ]; +} + def GPULowerToUKernelsPass : Pass<"iree-codegen-gpu-lower-to-ukernels", ""> { let summary = "Separate out parts of the IR that lower to a micro-kernel"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 5854bd5ca932..257cbe82db80 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -25,6 +25,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir", "gpu_distribute_shared_memory.mlir", "gpu_generalize_named_ops.mlir", + "gpu_infer_memory_space.mlir", "gpu_lower_to_ukernels.mlir", "gpu_nested_layout_contract_amdgpu.mlir", "gpu_nested_layout_vector_distribution.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index a61138be693a..a67de538082d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -21,6 +21,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir" "gpu_distribute_shared_memory.mlir" "gpu_generalize_named_ops.mlir" + "gpu_infer_memory_space.mlir" "gpu_lower_to_ukernels.mlir" "gpu_nested_layout_contract_amdgpu.mlir" "gpu_nested_layout_vector_distribution.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir new file mode 100644 index 000000000000..7d1533ab5b01 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir @@ -0,0 +1,54 @@ +// RUN: iree-opt %s --split-input-file --verify-diagnostics \ +// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-infer-memory-space))" | FileCheck %s + +func.func @write_in_lane_forall(%dest : tensor<4x3xi32>) -> tensor<4x3xi32> { + %alloc = bufferization.alloc_tensor() : tensor<2x3xi32> + %cst = arith.constant dense<0> : vector<2x3xi32> + %c0 = arith.constant 0 : index + %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> { + %w = vector.transfer_write %cst, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<2x3xi32>, tensor<2x3xi32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32> + } + } {mapping = [#iree_gpu.lane_id<0>]} + return %res : tensor<4x3xi32> +} + +// CHECK: func @write_in_lane_forall +// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space} +// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]] + +// ----- + +func.func @forall_shared_dest(%w : tensor<2x3xi32>) -> tensor<4x3xi32> { + %dest = bufferization.alloc_tensor() : tensor<4x3xi32> + %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> { + scf.forall.in_parallel { + tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32> + } + } {mapping = [#gpu.warp]} + return %res : tensor<4x3xi32> +} + +// CHECK: func @forall_shared_dest +// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space} +// CHECK: scf.forall {{.*}} shared_outs(%{{.*}} = %[[ALLOC]]) + +// ----- + +func.func @already_annotated_alloc() -> tensor<2x3xi32> { + %alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<2x3xi32> + return %alloc : tensor<2x3xi32> +} + +// CHECK: func @already_annotated_alloc +// CHECK: bufferization.alloc_tensor() {memory_space = #gpu.address_space} + +// ----- + +// expected-error@+1 {{failed to set the gpu memory space for all `bufferization.alloc_tensor` ops}} +func.func @unknown_memory_space() -> tensor<2x3xi32> { + // expected-error@+1 {{unexpected gpu memory space must be private or workgroup.}} + %alloc = bufferization.alloc_tensor() {memory_space = "bad"} : tensor<2x3xi32> + return %alloc : tensor<2x3xi32> +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index 9ef45c757d63..35bafc739f14 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -166,6 +166,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:ControlFlowToLLVM", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index a5d3b0844462..0f8b40b34bc5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -116,6 +116,7 @@ iree_cc_library( MLIRArithToLLVM MLIRArithTransforms MLIRBufferizationDialect + MLIRBufferizationTransforms MLIRComplexToLLVM MLIRComplexToStandard MLIRControlFlowToLLVM diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index f71f616e3e32..7c71f8675717 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" @@ -143,20 +144,6 @@ static FailureOr gpuAllocationFn(OpBuilder &builder, Location loc, .getResult(); } -static FailureOr gpuWorkgroupAllocationFn(OpBuilder &builder, - Location loc, - MemRefType memRefType, - ValueRange dynamicSizes, - unsigned alignment) { - gpu::AddressSpaceAttr addressSpace = gpu::AddressSpaceAttr::get( - builder.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); - MemRefType allocType = - MemRefType::get(memRefType.getShape(), memRefType.getElementType(), - AffineMap(), addressSpace); - return builder.create(loc, allocType, dynamicSizes) - .getResult(); -} - // Barriers are only needed when copying to/from workgroup memory. The only // other kind of memory that can be allocated is function memory, which is local // to a thread. @@ -211,10 +198,8 @@ static ReorderWorkgroupsStrategy getReorderWorkgroupsStrategy( // Common Pass Recipes //===----------------------------------------------------------------------===// -static void addBufferizePasses(OpPassManager &funcPassManager, - bool allowPrivateAllocations = true) { - BufferizationOptions::AllocationFn allocationFn = - allowPrivateAllocations ? gpuAllocationFn : gpuWorkgroupAllocationFn; +static void addBufferizePasses(OpPassManager &funcPassManager) { + BufferizationOptions::AllocationFn allocationFn = gpuAllocationFn; BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn; addIREEComprehensiveBufferizePasses(funcPassManager, allocationFn, memcpyFn); funcPassManager.addPass(createCanonicalizerPass()); @@ -305,6 +290,34 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) { // Tile and Fuse //===---------------------------------------------------------------------===// +static FailureOr gpuRequireMemSpaceAllocationFn(OpBuilder &builder, + Location loc, + MemRefType memRefType, + ValueRange dynamicSizes, + unsigned alignment) { + // Bail out if the memref type does not specify a memory space. + if (!isa(memRefType.getMemorySpace())) { + return failure(); + } + return builder.create(loc, memRefType, dynamicSizes) + .getResult(); +} + +static void addGPUBufferizePasses(OpPassManager &funcPassManager) { + funcPassManager.addPass(createEliminateEmptyTensorsPass()); + funcPassManager.addPass(bufferization::createEmptyTensorToAllocTensorPass()); + funcPassManager.addPass(createGPUInferMemorySpacePass()); + BufferizationOptions::AllocationFn allocationFn = + gpuRequireMemSpaceAllocationFn; + BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn; + funcPassManager.addPass( + createIREEComprehensiveBufferizePass(allocationFn, memcpyFn)); + addIREEPostBufferizationPasses(funcPassManager); + + funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createCSEPass()); +} + void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) { tileAndDistributeToWorkgroup(funcPassManager, /*useWARForCooperativeMatrixCodegen=*/false, @@ -371,6 +384,10 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) { /*normalizeForall=*/true})); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); + + // TODO: This LICM instance is load bearing due to brittleness of the + // hoisting and fusion pass, as well as a lack of a fallback distribution + // pass. funcPassManager.addPass(createLoopInvariantCodeMotionPass()); // Step 5. Greedily fuse parallel loops and hoist from serial loops. @@ -385,7 +402,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) { funcPassManager.addPass(createCleanupBufferAllocViewPass()); // Step 7. Bufferize. - addBufferizePasses(funcPassManager, /*allowPrivateAllocations=*/true); + addGPUBufferizePasses(funcPassManager); // Step 8. Resolve remaining parallel loops. funcPassManager.addPass(createGPUVerifyDistributionPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 4f5425a52b48..991450905b73 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -558,6 +558,7 @@ hal.executable public @main { // CHECK-LABEL: func @conv_nchw_fused // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 // CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: outs(%{{.*}} : memref<1x1x1x1xf32, #gpu.address_space>) // CHECK: arith.addf // CHECK: arith.cmpf // CHECK: arith.select diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index b34209ac9ba0..7cbddf4e79bc 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -39,6 +39,32 @@ llvm::SmallVector getSubgroupIdsAndCounts(OpBuilder &builder, Location loc, unsigned warpSize, unsigned numDims, llvm::ArrayRef numSubgroups); +// Indicates whether the given `scf.forall` op has a processor ID mapping of +// the template type(s). +template +bool forallOpHasMappingType(scf::ForallOp forallOp) { + std::optional mapping = forallOp.getMapping(); + if (!mapping || mapping.value().empty()) { + return false; + } + + return isa(*mapping.value().begin()); +} + +// Indicates whether an operation is within a distributed context with the +// specified mapping type(s). +template +bool operationHasParentForallOfMappingType(Operation *op) { + auto parentForallOp = op->getParentOfType(); + while (parentForallOp) { + if (forallOpHasMappingType(parentForallOp)) { + return true; + } + parentForallOp = parentForallOp->getParentOfType(); + } + return false; +} + //===----------------------------------------------------------------------===// // GPU vectorization //===----------------------------------------------------------------------===//