From 697e162b3c58fe0ebc257ebb87d6f054f746be97 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 25 Aug 2023 12:02:00 -0400 Subject: [PATCH 1/3] [LLVMGPU] Extract subgroup size from export op to use for vector distribution Previously subgroup size was hard coded to 32. This extracts the subgroup size from the `hal.executable.export` op associated with the target. --- .../TransformExtensions/LLVMGPUExtensions.cpp | 41 ++++++++++++++----- .../LLVMGPUExtensionsOps.td | 2 +- ...transform_dialect_vector_distribution.mlir | 2 +- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 83b38444f237..b54d15109859 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -574,11 +574,13 @@ static Value simpleWarpShuffleFunction(Location loc, OpBuilder &builder, static void populatePropagateVectorDistribution(Operation *target, RewritePatternSet &patterns, - PatternBenefit benefit) { - auto groupReductionFn = [](Location loc, OpBuilder &builder, Value input, - vector::CombiningKind kind, uint32_t size) { + PatternBenefit benefit, + unsigned subgroupSize) { + auto groupReductionFn = [subgroupSize]( + Location loc, OpBuilder &builder, Value input, + vector::CombiningKind kind, uint32_t size) { return mlir::iree_compiler::emitGPUGroupReduction(loc, builder, input, kind, - size, 32); + size, subgroupSize); }; assert(target->hasTrait()); vector::populatePropagateWarpVectorDistributionPatterns( @@ -604,14 +606,30 @@ static void populateWarpExecuteOnLane0ToScf( DiagnosedSilenceableFailure transform_dialect::VectorWarpDistributionOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, + transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - if (!target->hasTrait()) { - target->emitOpError( - "applies only to isolated-from-above targets because it " - "needs to apply " - "patterns greedily"); + if (!isa(state.getTopLevel())) { + state.getTopLevel()->emitOpError( + "requires HAL::ExecutableOp or HAL::ExecutableVariantOp " + "toplevel to extract subgroup size information"); + return emitDefaultDefiniteFailure(target); + } + + IREE::HAL::ExecutableExportOp exportOp; + state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) { + if (op.getSymName() == target.getName()) + exportOp = op; + }); + if (!exportOp) { + state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); + return emitDefaultDefiniteFailure(target); + } + + std::optional subgroupSize = exportOp.getSubgroupSize(); + if (!subgroupSize) { + state.getTopLevel()->emitOpError( + "could not extract subgroup size from IREE::HAL::ExecutableExportOp"); return emitDefaultDefiniteFailure(target); } @@ -645,7 +663,8 @@ transform_dialect::VectorWarpDistributionOp::applyToOne( populateVectorTransferWriteDistribution(target, patterns, /*benefit=*/2); populatePropagateVectorDistribution(target, patterns, - /*benefit=*/1); + /*benefit=*/1, + subgroupSize->getSExtValue()); if (failed( applyPatternsAndFoldGreedily(target, std::move(patterns), config))) { return mlir::emitDefiniteFailure( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 9c2f6ee94c56..9c0d0b485ef9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -325,7 +325,7 @@ def VectorWarpDistributionOp : Op Date: Fri, 25 Aug 2023 13:01:39 -0400 Subject: [PATCH 2/3] Use getEntryPoint --- .../TransformExtensions/LLVMGPUExtensions.cpp | 35 ++++--------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index b54d15109859..c26f6609e478 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -79,24 +80,12 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - if (!isa(state.getTopLevel())) { - state.getTopLevel()->emitOpError( - "requires HAL::ExecutableOp or HAL::ExecutableVariantOp " - "toplevel to " - "attach the workgroup size information to a nested " - "ExecutableExportOp"); - return emitDefaultDefiniteFailure(target); - } - - IREE::HAL::ExecutableExportOp exportOp; - state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) { - if (op.getSymName() == target.getName()) - exportOp = op; - }); - if (!exportOp) { + FailureOr maybeExportOp = getEntryPoint(target); + if (failed(maybeExportOp)) { state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return emitDefaultDefiniteFailure(target); } + IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp; auto transformOp = cast(getOperation()); @@ -609,22 +598,12 @@ transform_dialect::VectorWarpDistributionOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - if (!isa(state.getTopLevel())) { - state.getTopLevel()->emitOpError( - "requires HAL::ExecutableOp or HAL::ExecutableVariantOp " - "toplevel to extract subgroup size information"); - return emitDefaultDefiniteFailure(target); - } - - IREE::HAL::ExecutableExportOp exportOp; - state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) { - if (op.getSymName() == target.getName()) - exportOp = op; - }); - if (!exportOp) { + FailureOr maybeExportOp = getEntryPoint(target); + if (failed(maybeExportOp)) { state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return emitDefaultDefiniteFailure(target); } + IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp; std::optional subgroupSize = exportOp.getSubgroupSize(); if (!subgroupSize) { From 3a2493bff71af6ac1ff705b16334aee078dfa422 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 25 Aug 2023 13:03:39 -0400 Subject: [PATCH 3/3] clang format --- .../LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index c26f6609e478..49dd178f5fbc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -80,7 +80,8 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - FailureOr maybeExportOp = getEntryPoint(target); + FailureOr maybeExportOp = + getEntryPoint(target); if (failed(maybeExportOp)) { state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return emitDefaultDefiniteFailure(target); @@ -598,7 +599,8 @@ transform_dialect::VectorWarpDistributionOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - FailureOr maybeExportOp = getEntryPoint(target); + FailureOr maybeExportOp = + getEntryPoint(target); if (failed(maybeExportOp)) { state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return emitDefaultDefiniteFailure(target);