[Codegen][GPU] Add pass to annotate memory spaces on allocations (iree-org#18251)

Trying to infer the memory space of an allocation from within the
bufferization alloc callback function is too late. This adds a
rudimentary pass that annotates the memory space in the obvious
situations and then disallows any bufferization allocation without an
already pre-determined memory space (for the LLVMGPUTileAndFuse
pipeline). This gives us correctness guarantees that were somewhat
hand-wavy before.

Every allocation that is not explicitly marked as shared (and cannot be
obviously inferred as shared) is now treated as thread local. Any
previous lowering that violates this invariant is a bug (most likely
caused by a failure to tile an operation).
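As a rough before/after sketch of the new behavior (mirroring the lit tests
added in this change), an unannotated alloc_tensor that is only written inside
a distributed region gets the private address space, while one used as the
shared destination of a thread- or warp-mapped scf.forall gets the workgroup
address space:

  // Thread-local case:
  %alloc = bufferization.alloc_tensor() : tensor<2x3xi32>
  // becomes
  %alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>} : tensor<2x3xi32>

  // Shared case (%dest is the shared_outs destination of an scf.forall with
  // {mapping = [#gpu.warp<x>]}):
  %dest = bufferization.alloc_tensor() : tensor<4x3xi32>
  // becomes
  %dest = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<4x3xi32>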
qedawkins authored Aug 20, 2024
1 parent 5beb9ad commit 137e365
Showing 13 changed files with 238 additions and 41 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -58,6 +58,7 @@ iree_compiler_cc_library(
"GPUDistributeSharedMemoryCopy.cpp",
"GPUDistributionPatterns.cpp",
"GPUGeneralizeNamedOps.cpp",
"GPUInferMemorySpace.cpp",
"GPULowerToUKernels.cpp",
"GPUMultiBuffering.cpp",
"GPUNestedLayoutDistributionPatterns.cpp",
compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -56,6 +56,7 @@ iree_cc_library(
"GPUDistributeSharedMemoryCopy.cpp"
"GPUDistributionPatterns.cpp"
"GPUGeneralizeNamedOps.cpp"
"GPUInferMemorySpace.cpp"
"GPULowerToUKernels.cpp"
"GPUMultiBuffering.cpp"
"GPUNestedLayoutDistributionPatterns.cpp"
106 changes: 106 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp
@@ -0,0 +1,106 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUINFERMEMORYSPACEPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {

/// Pass to infer the memory spaces of unmarked `bufferization.alloc_tensor`
/// ops. Inferring the memory space during bufferization (in the allocation
/// function) is infeasible because it requires analysis of the surrounding
/// loop structure. After this pass, any allocation that still lacks a memory
/// space is treated as a compiler failure, indicating that something went
/// wrong during earlier lowering.
struct GPUInferMemorySpacePass final
: impl::GPUInferMemorySpacePassBase<GPUInferMemorySpacePass> {

void runOnOperation() override;
};

bool isDefinitelyShared(bufferization::AllocTensorOp alloc) {
// An allocation can be inferred as shared if it is only ever used as the
// destination of a thread or warp distributed `scf.forall` op. All other
// shared allocations are expected to be properly indicated in advance.
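//
// For example (illustrative, adapted from the lit tests for this pass), the
// following alloc_tensor is inferred as shared:
//   %dest = bufferization.alloc_tensor() : tensor<4x3xi32>
//   %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
//     ...
//   } {mapping = [#gpu.warp<x>]}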
for (auto user : alloc->getUsers()) {
auto forallOp = dyn_cast<scf::ForallOp>(user);
if (!forallOp ||
!forallOpHasMappingType<gpu::GPUThreadMappingAttr,
gpu::GPUWarpMappingAttr>(forallOp)) {
return false;
}
}
return true;
}

void GPUInferMemorySpacePass::runOnOperation() {
MLIRContext *context = &getContext();
FunctionOpInterface funcOp = getOperation();

gpu::AddressSpaceAttr privateAddressSpace = gpu::AddressSpaceAttr::get(
context, gpu::GPUDialect::getPrivateAddressSpace());
gpu::AddressSpaceAttr sharedAddressSpace = gpu::AddressSpaceAttr::get(
context, gpu::GPUDialect::getWorkgroupAddressSpace());

WalkResult res = funcOp.walk([&](bufferization::AllocTensorOp alloc) {
// Continue if the allocation already has a valid memory space.
std::optional<Attribute> currentMemSpace = alloc.getMemorySpace();
if (currentMemSpace.has_value()) {
if (currentMemSpace.value() == privateAddressSpace ||
currentMemSpace.value() == sharedAddressSpace) {
return WalkResult::advance();
}
alloc.emitOpError(
"unexpected gpu memory space must be private or workgroup.");
return WalkResult::interrupt();
}

/// Determining the GPU memory space must be trivial by the time of this
/// pass. Because this pass runs immediately before bufferization, the input
/// IR is expected to mix (thread) distributed and shared contexts. Because
/// distributed loops (scf.forall ops) are expected to be inlined as-is after
/// bufferization, with no further tiling occurring, all tensors at this
/// point in the IR are assumed to be thread-local unless explicitly marked
/// as shared. This gives the following invariants:
///
/// 1. If the alloc_tensor is already annotated with
/// `#gpu.address_space<workgroup>`, or if it is used as the immediate
/// destination of a thread or warp distributed `scf.forall` op, then the
/// allocation must be shared memory.
/// 2. All other allocations are thread local.
///
/// Any allocation that should be shared memory but is not explicitly marked
/// as such indicates a bug in earlier passes/lowerings.
if (isDefinitelyShared(alloc)) {
alloc.setMemorySpaceAttr(sharedAddressSpace);
} else {
alloc.setMemorySpaceAttr(privateAddressSpace);
}
return WalkResult::advance();
});

if (res.wasInterrupted()) {
funcOp->emitOpError("failed to set the gpu memory space for all "
"`bufferization.alloc_tensor` ops");
return signalPassFailure();
}
}

} // namespace

} // namespace mlir::iree_compiler
compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -19,28 +20,6 @@ namespace mlir::iree_compiler {

namespace {

template <typename... Type>
bool forallOpHasMappingType(scf::ForallOp forallOp) {
std::optional<ArrayAttr> mapping = forallOp.getMapping();
if (!mapping || mapping.value().empty()) {
return false;
}

return isa<Type...>(*mapping.value().begin());
}

template <typename... Type>
bool operationHasParentForallOfMappingType(Operation *op) {
auto parentForallOp = op->getParentOfType<scf::ForallOp>();
while (parentForallOp) {
if (forallOpHasMappingType<Type...>(parentForallOp)) {
return true;
}
parentForallOp = parentForallOp->getParentOfType<scf::ForallOp>();
}
return false;
}

/// Pass to verify that writes only happen in distributed contexts. Code in
/// shared contexts is executed uniformly across all threads after resolution
/// of distributed contexts (i.e. scf.forall), thus operations with write
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -59,6 +59,14 @@ def GPUGeneralizeNamedOpsPass :
let summary = "Convert named Linalg ops to linalg.generic ops";
}

def GPUInferMemorySpacePass :
InterfacePass<"iree-codegen-gpu-infer-memory-space", "mlir::FunctionOpInterface"> {
let summary = "Pass to infer and set the memory space for all alloc_tensor ops.";
let dependentDialects = [
"::mlir::gpu::GPUDialect"
];
}

def GPULowerToUKernelsPass :
Pass<"iree-codegen-gpu-lower-to-ukernels", ""> {
let summary = "Separate out parts of the IR that lower to a micro-kernel";
compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -25,6 +25,7 @@ iree_lit_test_suite(
"gpu_distribute_scf_for.mlir",
"gpu_distribute_shared_memory.mlir",
"gpu_generalize_named_ops.mlir",
"gpu_infer_memory_space.mlir",
"gpu_lower_to_ukernels.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
"gpu_nested_layout_vector_distribution.mlir",
compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -21,6 +21,7 @@ iree_lit_test_suite(
"gpu_distribute_scf_for.mlir"
"gpu_distribute_shared_memory.mlir"
"gpu_generalize_named_ops.mlir"
"gpu_infer_memory_space.mlir"
"gpu_lower_to_ukernels.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
54 changes: 54 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir
@@ -0,0 +1,54 @@
// RUN: iree-opt %s --split-input-file --verify-diagnostics \
// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-infer-memory-space))" | FileCheck %s

func.func @write_in_lane_forall(%dest : tensor<4x3xi32>) -> tensor<4x3xi32> {
%alloc = bufferization.alloc_tensor() : tensor<2x3xi32>
%cst = arith.constant dense<0> : vector<2x3xi32>
%c0 = arith.constant 0 : index
%res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
%w = vector.transfer_write %cst, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<2x3xi32>, tensor<2x3xi32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32>
}
} {mapping = [#iree_gpu.lane_id<0>]}
return %res : tensor<4x3xi32>
}

// CHECK: func @write_in_lane_forall
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>}
// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]]

// -----

func.func @forall_shared_dest(%w : tensor<2x3xi32>) -> tensor<4x3xi32> {
%dest = bufferization.alloc_tensor() : tensor<4x3xi32>
%res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
scf.forall.in_parallel {
tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32>
}
} {mapping = [#gpu.warp<x>]}
return %res : tensor<4x3xi32>
}

// CHECK: func @forall_shared_dest
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>}
// CHECK: scf.forall {{.*}} shared_outs(%{{.*}} = %[[ALLOC]])

// -----

func.func @already_annotated_alloc() -> tensor<2x3xi32> {
%alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>} : tensor<2x3xi32>
return %alloc : tensor<2x3xi32>
}

// CHECK: func @already_annotated_alloc
// CHECK: bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>}

// -----

// expected-error@+1 {{failed to set the gpu memory space for all `bufferization.alloc_tensor` ops}}
func.func @unknown_memory_space() -> tensor<2x3xi32> {
// expected-error@+1 {{unexpected gpu memory space must be private or workgroup.}}
%alloc = bufferization.alloc_tensor() {memory_space = "bad"} : tensor<2x3xi32>
return %alloc : tensor<2x3xi32>
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -166,6 +166,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:ComplexToLLVM",
"@llvm-project//mlir:ComplexToStandard",
"@llvm-project//mlir:ControlFlowToLLVM",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -116,6 +116,7 @@ iree_cc_library(
MLIRArithToLLVM
MLIRArithTransforms
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRComplexToLLVM
MLIRComplexToStandard
MLIRControlFlowToLLVM
55 changes: 36 additions & 19 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -143,20 +144,6 @@ static FailureOr<Value> gpuAllocationFn(OpBuilder &builder, Location loc,
.getResult();
}

static FailureOr<Value> gpuWorkgroupAllocationFn(OpBuilder &builder,
Location loc,
MemRefType memRefType,
ValueRange dynamicSizes,
unsigned alignment) {
gpu::AddressSpaceAttr addressSpace = gpu::AddressSpaceAttr::get(
builder.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
MemRefType allocType =
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
AffineMap(), addressSpace);
return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes)
.getResult();
}

// Barriers are only needed when copying to/from workgroup memory. The only
// other kind of memory that can be allocated is function memory, which is local
// to a thread.
@@ -211,10 +198,8 @@ static ReorderWorkgroupsStrategy getReorderWorkgroupsStrategy(
// Common Pass Recipes
//===----------------------------------------------------------------------===//

static void addBufferizePasses(OpPassManager &funcPassManager,
bool allowPrivateAllocations = true) {
BufferizationOptions::AllocationFn allocationFn =
allowPrivateAllocations ? gpuAllocationFn : gpuWorkgroupAllocationFn;
static void addBufferizePasses(OpPassManager &funcPassManager) {
BufferizationOptions::AllocationFn allocationFn = gpuAllocationFn;
BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn;
addIREEComprehensiveBufferizePasses(funcPassManager, allocationFn, memcpyFn);
funcPassManager.addPass(createCanonicalizerPass());
@@ -305,6 +290,34 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
// Tile and Fuse
//===---------------------------------------------------------------------===//

static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
Location loc,
MemRefType memRefType,
ValueRange dynamicSizes,
unsigned alignment) {
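// This allocation function is only installed after the GPU memory space
// inference pass has run (see addGPUBufferizePasses below), so every
// allocation reaching it is expected to already carry a
// `#gpu.address_space<private>` or `#gpu.address_space<workgroup>` attribute;
// hitting the failure path below indicates a bug in an earlier pass.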
// Bail out if the memref type does not specify a GPU address space.
if (!isa_and_nonnull<gpu::AddressSpaceAttr>(memRefType.getMemorySpace())) {
return failure();
}
return builder.create<memref::AllocOp>(loc, memRefType, dynamicSizes)
.getResult();
}

static void addGPUBufferizePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createEliminateEmptyTensorsPass());
funcPassManager.addPass(bufferization::createEmptyTensorToAllocTensorPass());
funcPassManager.addPass(createGPUInferMemorySpacePass());
BufferizationOptions::AllocationFn allocationFn =
gpuRequireMemSpaceAllocationFn;
BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn;
funcPassManager.addPass(
createIREEComprehensiveBufferizePass(allocationFn, memcpyFn));
addIREEPostBufferizationPasses(funcPassManager);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager,
/*useWARForCooperativeMatrixCodegen=*/false,
@@ -371,6 +384,10 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
/*normalizeForall=*/true}));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// TODO: This LICM instance is load bearing due to brittleness of the
// hoisting and fusion pass, as well as a lack of a fallback distribution
// pass.
funcPassManager.addPass(createLoopInvariantCodeMotionPass());

// Step 5. Greedily fuse parallel loops and hoist from serial loops.
@@ -385,7 +402,7 @@
funcPassManager.addPass(createCleanupBufferAllocViewPass());

// Step 7. Bufferize.
addBufferizePasses(funcPassManager, /*allowPrivateAllocations=*/true);
addGPUBufferizePasses(funcPassManager);

// Step 8. Resolve remaining parallel loops.
funcPassManager.addPass(createGPUVerifyDistributionPass());
@@ -558,6 +558,7 @@ hal.executable public @main {
// CHECK-LABEL: func @conv_nchw_fused
// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
// CHECK: linalg.conv_2d_nchw_fchw
// CHECK-SAME: outs(%{{.*}} : memref<1x1x1x1xf32, #gpu.address_space<private>>)
// CHECK: arith.addf
// CHECK: arith.cmpf
// CHECK: arith.select
26 changes: 26 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -39,6 +39,32 @@ llvm::SmallVector<linalg::ProcInfo, 2>
getSubgroupIdsAndCounts(OpBuilder &builder, Location loc, unsigned warpSize,
unsigned numDims, llvm::ArrayRef<int64_t> numSubgroups);

// Indicates whether the given `scf.forall` op has a processor ID mapping of
// the template type(s).
template <typename... Type>
bool forallOpHasMappingType(scf::ForallOp forallOp) {
std::optional<ArrayAttr> mapping = forallOp.getMapping();
if (!mapping || mapping.value().empty()) {
return false;
}

return isa<Type...>(*mapping.value().begin());
}
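// Example usage of forallOpHasMappingType (illustrative):
//   bool isThreadOrWarpDistributed =
//       forallOpHasMappingType<gpu::GPUThreadMappingAttr,
//                              gpu::GPUWarpMappingAttr>(forallOp);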

// Indicates whether an operation is within a distributed context with the
// specified mapping type(s).
template <typename... Type>
bool operationHasParentForallOfMappingType(Operation *op) {
auto parentForallOp = op->getParentOfType<scf::ForallOp>();
while (parentForallOp) {
if (forallOpHasMappingType<Type...>(parentForallOp)) {
return true;
}
parentForallOp = parentForallOp->getParentOfType<scf::ForallOp>();
}
return false;
}

//===----------------------------------------------------------------------===//
// GPU vectorization
//===----------------------------------------------------------------------===//
