diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 64d0412bf76d..7a81a2086396 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -52,6 +52,7 @@ iree_compiler_cc_library(
         "AMDGPUDistributeContract.cpp",
         "GPUApplyTilingLevel.cpp",
         "GPUCheckResourceUsage.cpp",
+        "GPUCombineValueBarriers.cpp",
         "GPUCreateFastSlowPath.cpp",
         "GPUDistribute.cpp",
         "GPUDistributeScfFor.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 5a376fdcc5e3..8d0ad6c887e3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -50,6 +50,7 @@ iree_cc_library(
       "AMDGPUDistributeContract.cpp"
       "GPUApplyTilingLevel.cpp"
       "GPUCheckResourceUsage.cpp"
+      "GPUCombineValueBarriers.cpp"
      "GPUCreateFastSlowPath.cpp"
      "GPUDistribute.cpp"
      "GPUDistributeScfFor.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp
new file mode 100644
index 000000000000..6e9982eb14ff
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp
@@ -0,0 +1,282 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// This file implements the pass to combine multiple `iree_gpu.value_barrier`
+// ops.
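+//
+// For example (illustrative sketch), two independent barriers
+//
+//   %b0 = iree_gpu.value_barrier %v0 : vector<8xf16>
+//   %b1 = iree_gpu.value_barrier %v1 : vector<8xf16>
+//
+// are rewritten, once the operations between them have been hoisted or sunk
+// out of the way, into a single multi-value barrier:
+//
+//   %b:2 = iree_gpu.value_barrier %v0, %v1 : vector<8xf16>, vector<8xf16>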
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#define DEBUG_TYPE "iree-codegen-gpu-combine-value-barriers"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_GPUCOMBINEVALUEBARRIERSPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+
+/// Move the given backward slice before the given barrier.
+/// The backward slice should not contain any operation which is already
+/// before the barrier, nor the barrier itself.
+static void moveBackwardSliceBeforeBarrier(RewriterBase &rewriter,
+                                           llvm::SetVector<Operation *> &slice,
+                                           Operation *leadingBarrier) {
+  // Sort operations to be moved topologically.
+  slice = topologicalSort(slice);
+
+  // It is always valid (w.r.t. dominance) to move topologically sorted
+  // operations in a backward slice which come after the insertion point, to
+  // before the insertion point. This is because:
+  // - Since we are operating on a backward slice, producers of every operation
+  //   in the slice are already in the slice, and will be moved before the
+  //   insertion point as well.
+  // - Any consumers will still remain after the operation, as we are only
+  //   moving the operation backward.
+  for (Operation *sliceOp : slice) {
+    rewriter.moveOpBefore(sliceOp, leadingBarrier);
+  }
+}
+
+/// Move the given forward slice after the given barrier.
+/// The forward slice should not contain any operation which is already after
+/// the barrier, nor the barrier itself.
+static void moveForwardSliceAfterBarrier(RewriterBase &rewriter,
+                                         llvm::SetVector<Operation *> &slice,
+                                         Operation *trailingBarrier) {
+  // Sort operations to be moved topologically.
+  slice = topologicalSort(slice);
+
+  // It is always valid (w.r.t. dominance) to move topologically sorted
+  // operations in a forward slice which come before the insertion point, to
+  // after the insertion point. This is because:
+  // - Since we are operating on a forward slice, consumers of every operation
+  //   in the slice are already in the slice, and will be moved after the
+  //   insertion point as well.
+  // - Any producers will still remain before the operation, as we are only
+  //   moving the operation forward.
+  for (Operation *sliceOp : llvm::reverse(slice)) {
+    rewriter.moveOpAfter(sliceOp, trailingBarrier);
+  }
+}
+
+/// Combine all given value barriers into a single value barrier.
+static LogicalResult
+combineValueBarrierOps(RewriterBase &rewriter, Location loc,
+                       ArrayRef<IREE::GPU::ValueBarrierOp> valueBarriers) {
+  if (valueBarriers.size() <= 1) {
+    return success();
+  }
+  SmallVector<Value> barrierOperands;
+  for (auto barrierOp : valueBarriers) {
+    barrierOperands.append(barrierOp.getInputs().begin(),
+                           barrierOp.getInputs().end());
+  }
+  auto combinedBarrierOp =
+      rewriter.create<IREE::GPU::ValueBarrierOp>(loc, barrierOperands);
+
+  // Replace all uses of each previous barrier with the new barrier.
+  int resultNumber = 0;
+  for (auto barrierOp : valueBarriers) {
+    int numResults = barrierOp.getNumResults();
+    rewriter.replaceOp(barrierOp, combinedBarrierOp->getResults().slice(
+                                      resultNumber, numResults));
+    resultNumber += numResults;
+  }
+  return success();
+}
+
+/// Given two barriers, barrierA and barrierB, combine them into a single
+/// barrier.
+static FailureOr<IREE::GPU::ValueBarrierOp>
+combineValueBarrierPair(RewriterBase &rewriter,
+                        IREE::GPU::ValueBarrierOp barrierA,
+                        IREE::GPU::ValueBarrierOp barrierB) {
+  // Both barriers need to have either tensor semantics or vector semantics.
+  if (barrierA.hasTensorSemantics() && !barrierB.hasTensorSemantics()) {
+    return failure();
+  }
+  if (!barrierA.hasTensorSemantics() && barrierB.hasTensorSemantics()) {
+    return failure();
+  }
+
+  // We assume barrierA is always before barrierB.
+  if (barrierB->isBeforeInBlock(barrierA)) {
+    std::swap(barrierA, barrierB);
+  }
+
+  // barrierA and barrierB are in the same block.
+  assert(barrierA->getBlock() == barrierB->getBlock());
+  Block *block = barrierA->getBlock();
+
+  auto sliceFilterBackward = [&block, &barrierA](Operation *candidate) -> bool {
+    if (candidate->getBlock() != block) {
+      return false;
+    }
+    if (candidate == block->getTerminator()) {
+      // Do not move the terminator.
+      return false;
+    }
+    if (candidate->isBeforeInBlock(barrierA)) {
+      return false;
+    }
+    return true;
+  };
+
+  // Find the combined backward slice of barrierA and barrierB and try
+  // to move it before barrierA (before both the barriers).
+  BackwardSliceOptions bOptions;
+  bOptions.filter = sliceFilterBackward;
+  SetVector<Operation *> backwardSliceA;
+  SetVector<Operation *> backwardSliceB;
+  getBackwardSlice(barrierA, &backwardSliceA, bOptions);
+  getBackwardSlice(barrierB, &backwardSliceB, bOptions);
+  backwardSliceA.insert(backwardSliceB.begin(), backwardSliceB.end());
+  // If the first barrier is contained in the combined backward slice of both
+  // barriers, the barriers form a chain and cannot be combined.
+  if (backwardSliceA.contains(barrierA)) {
+    return failure();
+  }
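+  // For example (illustrative), the barriers form a chain when the first
+  // barrier feeds the second one through some intermediate operation:
+  //
+  //   %b0 = iree_gpu.value_barrier %t0 : tensor<8xf16>
+  //   %t1 = vector.transfer_write %v, %b0[%c0] : vector<8xf16>, tensor<8xf16>
+  //   %b1 = iree_gpu.value_barrier %t1 : tensor<8xf16>
+  //
+  // Here %b0 is in the backward slice of %b1, so the pair is rejected.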
+  // Move the backward slice before barrierA.
+  moveBackwardSliceBeforeBarrier(rewriter, backwardSliceA, barrierA);
+
+  auto sliceFilterForward = [&block, &barrierB](Operation *candidate) -> bool {
+    if (candidate->getBlock() != block) {
+      return false;
+    }
+    if (candidate == block->getTerminator()) {
+      // Do not move the terminator.
+      return false;
+    }
+    if (barrierB->isBeforeInBlock(candidate)) {
+      return false;
+    }
+    return true;
+  };
+
+  // Find the combined forward slice of barrierA and barrierB and try to
+  // move it after barrierB (after both the barriers).
+  ForwardSliceOptions fOptions;
+  fOptions.filter = sliceFilterForward;
+  SetVector<Operation *> forwardSliceA;
+  SetVector<Operation *> forwardSliceB;
+  getForwardSlice(barrierA, &forwardSliceA, fOptions);
+  getForwardSlice(barrierB, &forwardSliceB, fOptions);
+  forwardSliceA.insert(forwardSliceB.begin(), forwardSliceB.end());
+  // If the second barrier is contained in the combined forward slice of both
+  // barriers, the barriers form a chain and cannot be combined.
+  if (forwardSliceA.contains(barrierB)) {
+    return failure();
+  }
+  // Move the forward slice after barrierB.
+  moveForwardSliceAfterBarrier(rewriter, forwardSliceA, barrierB);
+
+  // We add the new barrier after both the barriers (it is always better
+  // to sink barriers).
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(barrierB);
+
+  SmallVector<Value> barrierOperands;
+  barrierOperands.append(barrierA.getOperands().begin(),
+                         barrierA.getOperands().end());
+  barrierOperands.append(barrierB.getOperands().begin(),
+                         barrierB.getOperands().end());
+
+  auto combinedBarrierOp = rewriter.create<IREE::GPU::ValueBarrierOp>(
+      barrierB.getLoc(), barrierOperands);
+
+  int numOperandsA = barrierA.getNumOperands();
+  int numOperandsB = barrierB.getNumOperands();
+  rewriter.replaceOp(barrierA,
+                     combinedBarrierOp->getResults().slice(0, numOperandsA));
+  rewriter.replaceOp(barrierB, combinedBarrierOp->getResults().slice(
+                                   numOperandsA, numOperandsB));
+
+  return combinedBarrierOp;
+}
+
+static void combineValueBarriersInBlock(RewriterBase &rewriter, Block *block) {
+  SmallVector<IREE::GPU::ValueBarrierOp> barriers;
+  for (Operation &op : block->getOperations()) {
+    if (auto barrier = dyn_cast<IREE::GPU::ValueBarrierOp>(op)) {
+      barriers.push_back(barrier);
+    }
+  }
+
+  // We iterate over all pairs. This could be optimized to O(n) to take
+  // into account deletions, but we do the simplest thing for now.
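+  // E.g., for barriers [%b0, %b1, %b2]: if %b0 and %b1 combine, the combined
+  // barrier takes slot 0, slot 1 is cleared, and the combined barrier is then
+  // tried against %b2.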
+  int numBarriers = barriers.size();
+  for (int i = 0; i < numBarriers; ++i) {
+    if (!barriers[i]) {
+      continue;
+    }
+
+    for (int j = i + 1; j < numBarriers; ++j) {
+      if (!barriers[j]) {
+        continue;
+      }
+
+      FailureOr<IREE::GPU::ValueBarrierOp> combined =
+          combineValueBarrierPair(rewriter, barriers[i], barriers[j]);
+      if (succeeded(combined)) {
+        barriers[i] = combined.value();
+        barriers[j] = nullptr;
+      }
+    }
+  }
+}
+
+struct GPUCombineValueBarriersPass final
+    : impl::GPUCombineValueBarriersPassBase<GPUCombineValueBarriersPass> {
+
+  void runOnOperation() override {
+    // Walk the operation to get all blocks that have value barriers. We
+    // restrict ourselves to blocks, because the order of operations in a
+    // block is easy to determine.
+    SmallVector<Block *> blocks;
+    getOperation()->walk([&blocks](Block *block) {
+      if (llvm::any_of(block->getOperations(),
+                       llvm::IsaPred<IREE::GPU::ValueBarrierOp>)) {
+        blocks.push_back(block);
+      }
+    });
+
+    IRRewriter rewriter(&getContext());
+    for (auto *block : blocks) {
+      combineValueBarriersInBlock(rewriter, block);
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index f02205aae0ac..f08455138f76 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -19,6 +19,12 @@ def GPUCheckResourceUsagePass :
   let constructor = "mlir::iree_compiler::createGPUCheckResourceUsagePass()";
 }
 
+def GPUCombineValueBarriersPass :
+    Pass<"iree-codegen-gpu-combine-value-barriers", ""> {
+  let summary = "Combines `iree_gpu.value_barrier` ops";
+  let dependentDialects = ["::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"];
+}
+
 def GPUCreateFastSlowPathPass :
     InterfacePass<"iree-codegen-gpu-create-fast-slow-path", "mlir::FunctionOpInterface"> {
   let summary = "Create separate fast and slow paths to handle padding";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 257cbe82db80..063541534ec3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -27,6 +27,7 @@ iree_lit_test_suite(
            "gpu_generalize_named_ops.mlir",
            "gpu_infer_memory_space.mlir",
            "gpu_lower_to_ukernels.mlir",
+            "gpu_combine_value_barriers.mlir",
            "gpu_nested_layout_contract_amdgpu.mlir",
            "gpu_nested_layout_vector_distribution.mlir",
            "gpu_pipeline.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index a67de538082d..98ec5d346417 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
   SRCS
     "gpu_apply_tiling_level.mlir"
     "gpu_check_resource_usage.mlir"
+    "gpu_combine_value_barriers.mlir"
    "gpu_create_fast_slow_path.mlir"
    "gpu_distribute.mlir"
    "gpu_distribute_scf_for.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_combine_value_barriers.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_combine_value_barriers.mlir
new file mode 100644
index 000000000000..624e6b7ccbb1
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_combine_value_barriers.mlir
@@ -0,0 +1,134 @@
+// RUN: iree-opt --iree-codegen-gpu-combine-value-barriers %s --split-input-file | FileCheck %s
+
+// Since the pass only rearranges the order of operations, we only check the
+// number of value_barriers.
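+//
+// A combined barrier takes the operands of both original barriers and yields
+// one result per operand, e.g. (illustrative sketch):
+//
+//   %r:2 = iree_gpu.value_barrier %a, %b : vector<8xf16>, vector<8xf16>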
+ +func.func @tensor_barrier(%write: vector<8xf16>, %input: tensor<8xf16>, %input2 : tensor<16xf16>) -> (vector<8xf16>, vector<8xf16>) { + %c0 = arith.constant 0 : index + %cv0 = arith.constant 0.0 : f16 + + %wait1 = vector.transfer_write %write, %input[%c0] : vector<8xf16>, tensor<8xf16> + %synced1 = iree_gpu.value_barrier %wait1 : tensor<8xf16> + %out = vector.transfer_read %synced1[%c0], %cv0 : tensor<8xf16>, vector<8xf16> + + %wait2 = vector.transfer_write %write, %input2[%c0] : vector<8xf16>, tensor<16xf16> + %synced2 = iree_gpu.value_barrier %wait2 : tensor<16xf16> + %out2 = vector.transfer_read %synced2[%c0], %cv0 : tensor<16xf16>, vector<8xf16> + + return %out, %out2 : vector<8xf16>, vector<8xf16> +} + +// There should be only 1 value_barrier left + +// CHECK-LABEL: func.func @tensor_barrier +// CHECK: value_barrier +// CHECK-NOT: value_barrier + +// ----- + +func.func @vector_barrier(%write: vector<8xf16>, %write2: vector<8xf16>) -> vector<8xf16> { + %synced = iree_gpu.value_barrier %write : vector<8xf16> + %synced2 = iree_gpu.value_barrier %write2 : vector<8xf16> + %add = arith.addf %synced, %synced2 : vector<8xf16> + return %add : vector<8xf16> +} + +// There should be only 1 value_barrier left + +// CHECK-LABEL: func.func @vector_barrier +// CHECK: value_barrier +// CHECK-NOT: value_barrier + +// ----- + +func.func @tensor_and_vector_barrier(%write: vector<8xf16>, %input: tensor<8xf16>) -> (vector<8xf16>, vector<8xf16>) { + %c0 = arith.constant 0 : index + %cv0 = arith.constant 0.0 : f16 + + %wait1 = vector.transfer_write %write, %input[%c0] : vector<8xf16>, tensor<8xf16> + %synced1 = iree_gpu.value_barrier %wait1 : tensor<8xf16> + %out = vector.transfer_read %synced1[%c0], %cv0 : tensor<8xf16>, vector<8xf16> + + %synced2 = iree_gpu.value_barrier %write : vector<8xf16> + + return %out, %synced2 : vector<8xf16>, vector<8xf16> +} + +// tensor and vector barriers cannot be combined, so both should remain + +// CHECK-LABEL: func.func @tensor_and_vector_barrier +// CHECK: value_barrier +// CHECK: value_barrier + +// ----- + +func.func @barriers_with_users(%write: vector<8xf16>, %input: tensor<8xf16>, %input2 : tensor<16xf16>, %input3 : tensor<16xf16>) -> (vector<8xf16>) { + %c0 = arith.constant 0 : index + %cv0 = arith.constant 0.0 : f16 + + %wait1 = vector.transfer_write %write, %input[%c0] : vector<8xf16>, tensor<8xf16> + %synced1 = iree_gpu.value_barrier %wait1 : tensor<8xf16> + %out = vector.transfer_read %synced1[%c0], %cv0 : tensor<8xf16>, vector<8xf16> + + %wait2 = vector.transfer_write %write, %input2[%c0] : vector<8xf16>, tensor<16xf16> + %synced2 = iree_gpu.value_barrier %wait2 : tensor<16xf16> + %out2 = vector.transfer_read %synced2[%c0], %cv0 : tensor<16xf16>, vector<8xf16> + + %add1 = arith.addf %out, %out2 : vector<8xf16> + + %wait3 = vector.transfer_write %write, %input3[%c0] : vector<8xf16>, tensor<16xf16> + %synced3 = iree_gpu.value_barrier %wait3 : tensor<16xf16> + %out3 = vector.transfer_read %synced3[%c0], %cv0 : tensor<16xf16>, vector<8xf16> + + %add2 = arith.addf %add1, %out3 : vector<8xf16> + + return %add2 : vector<8xf16> +} + +// There should be only 1 value_barrier left + +// CHECK-LABEL: func.func @barriers_with_users +// CHECK: value_barrier +// CHECK-NOT: value_barrier + +// ----- + +func.func @barrier_diamond_chain(%write: vector<8xf16>, %input: tensor<8xf16>) -> (tensor<8xf16>) { + %c0 = arith.constant 0 : index + %cv0 = arith.constant 0.0 : f16 + + %wait1 = vector.transfer_write %write, %input[%c0] : vector<8xf16>, tensor<8xf16> + %synced1 = 
iree_gpu.value_barrier %wait1 : tensor<8xf16> + + %wait2 = vector.transfer_write %write, %synced1[%c0] : vector<8xf16>, tensor<8xf16> + %synced2 = iree_gpu.value_barrier %wait2 : tensor<8xf16> + + %wait3 = vector.transfer_write %write, %synced1[%c0] : vector<8xf16>, tensor<8xf16> + %synced3 = iree_gpu.value_barrier %wait3 : tensor<8xf16> + + %d1 = vector.transfer_read %synced2[%c0], %cv0 : tensor<8xf16>, vector<8xf16> + %d2 = vector.transfer_read %synced3[%c0], %cv0 : tensor<8xf16>, vector<8xf16> + + %add = arith.addf %d1, %d2 : vector<8xf16> + + %synced4 = iree_gpu.value_barrier %add : vector<8xf16> + + %empty = tensor.empty() : tensor<8xf16> + %out = vector.transfer_write %synced4, %empty[%c0] : vector<8xf16>, tensor<8xf16> + + return %out : tensor<8xf16> +} + +// There should be 3 value_barriers left, since in a diamond chain, you can +// only combine the middle barriers. + +// CHECK-LABEL: func.func @barrier_diamond_chain +// CHECK: value_barrier +// CHECK: value_barrier +// CHECK: value_barrier +// CHECK-NOT: value_barrier