From 4e1aadf5311ca5c732356904ec9b5e924ca2d923 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 29 Oct 2024 21:39:17 -0700 Subject: [PATCH] Adding a flag to force indirect command buffers on in non-reusable cases. --- compiler/plugins/target/ROCM/ROCMTarget.cpp | 14 ++++++++++++-- .../compiler/API/Internal/CompilerDriver.cpp | 8 ++++++-- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 13 ++++++++++++- .../iree/compiler/Dialect/HAL/IR/HALBase.td | 2 -- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 2 +- .../HAL/Transforms/OutlineMemoizeRegions.cpp | 19 +++++++++++++++---- .../Dialect/VM/Transforms/DropUnusedCalls.cpp | 1 - tests/compiler_driver/streams.mlir | 8 ++++---- 8 files changed, 50 insertions(+), 17 deletions(-) diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 05ab66779271..308c88097cd7 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -160,6 +160,17 @@ struct ROCMOptions { } }; +// Returns the ABI or an empty string if unspecified. +static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) { + if (targetAttr) { + if (auto config = targetAttr.getConfiguration()) { + auto abiAttr = targetAttr.getConfiguration().getAs("abi"); + return abiAttr ? abiAttr.getValue() : ""; + } + } + return ""; +} + static void dumpModuleToPath(StringRef path, StringRef baseName, StringRef suffix, StringRef extension, llvm::Module &module) { @@ -585,8 +596,7 @@ class ROCMTargetBackend final : public TargetBackend { // Wrap the HSACO ELF binary in a Flatbuffers container. FailureOr binaryContainer; - if (targetAttr.getConfiguration() && - targetAttr.getConfiguration().getAs("abi") == "amdgpu") { + if (getABI(targetAttr) == "amdgpu") { binaryContainer = serializeAMDGPUBinaryContainer( serializationOptions, variantOp, exportOps, targetHSACO); } else { diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index 7f83a5e3b3fe..d12dbc5c6d4b 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -1100,8 +1100,12 @@ Error *Invocation::outputHALExecutable(Output &output) { return new Error("not a valid HAL executable"); } auto binaryOp = binaryOps.front(); - auto rawData = binaryOp.getData().getRawData(); - output.outputStream->write(rawData.data(), rawData.size()); + if (failed(cast(binaryOp.getData()) + .serializeToStream(binaryOp.getLoc(), llvm::endianness::little, + *output.outputStream))) { + return new Error( + "data attribute failed to serialize: unsupported format or encoding"); + } output.outputStream->flush(); return output.getWriteError(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 4ca2fdbde223..491bd876ecb1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -32,6 +32,15 @@ static llvm::cl::opt clIndirectCommandBuffers{ llvm::cl::init(true), }; +// TODO(benvanik): remove when we support capturing dynamic values for reuse. +static llvm::cl::opt clForceIndirectCommandBuffers{ + "iree-hal-force-indirect-command-buffers", + llvm::cl::desc("Forces indirect command buffers when they would otherwise " + "not be chosen due to the values they capture. They may not " + "be reusable but will still be outlined."), + llvm::cl::init(false), +}; + struct ContextResolveOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -1002,7 +1011,9 @@ struct CmdExecuteOpPattern // changes dispatches to use them for any dispatch we can - note that there // may still be some that slip through due to custom executables. const bool capturesDynamicUniformValues = - regionCapturesDynamicUniformValues(executeOp); + clForceIndirectCommandBuffers + ? false + : regionCapturesDynamicUniformValues(executeOp); // Calculate the indirect buffer references used within the command buffer // by analyzing captured resources. This analysis will be used by subsequent diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td index 2b2f23c8dd17..3f1e8110b83e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -150,8 +150,6 @@ def HAL_Ordinal : TypeAlias; def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">; def HAL_OrdinalArrayAttr : TypedArrayAttrBase; -def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>; - def HAL_ElementType : TypeAlias; def HAL_ElementTypeAttr : SignlessIntegerAttrBase< I32, "element type attribute">; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 9e370a10c22b..2db507be002a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -2604,7 +2604,7 @@ def HAL_ExecutableBinaryOp : HAL_Op<"executable.binary", [ OptionalAttr:$sym_visibility, SymbolNameAttr:$sym_name, StrAttr:$format, - HAL_ExecutableDataAttr:$data, + Util_AnySerializableAttr:$data, OptionalAttr:$mime_type // TODO(benvanik): add compatibility and versioning attributes. ); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp index 30447862945f..19b8d490a5ab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Utils/StringUtils.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -24,6 +25,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/RegionUtils.h" +#define DEBUG_TYPE "iree-hal-outline-memoize-regions" + namespace mlir::iree_compiler::IREE::HAL { #define GEN_PASS_DEF_OUTLINEMEMOIZEREGIONSPASS @@ -153,6 +156,8 @@ static IREE::Util::FuncOp outlineMemoizeRegionBody( name, funcType); moduleSymbolTable.insert(funcOp); funcOp.setVisibility(SymbolTable::Visibility::Private); + funcOp.setInliningPolicyAttr( + moduleBuilder.getAttr()); auto funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); // Remap any captured operands that have corresponding function arguments. @@ -521,8 +526,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp, // If we can't memoize the resources at initialization time then we need // to do it on-demand. if (!memoizeAnalysis.canRunAtInitializationTime()) { - memoizeOp.emitWarning( - "memoization failed: dynamic values captured at the call site"); + LLVM_DEBUG({ + llvm::dbgs() + << "memoization failed: dynamic values captured at the call site\n"; + memoizeOp.dump(); + }); replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp); return; } @@ -532,8 +540,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp, auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(memoizeOp.getDevice()); if (!deviceGlobals) { - memoizeOp.emitWarning("memoization failed: unable to analyze devices " - "that may be used with memoized region"); + LLVM_DEBUG({ + llvm::dbgs() << "memoization failed: unable to analyze devices that may " + "be used with memoized region\n"; + memoizeOp.dump(); + }); replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp); return; } diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp index 9690bca8b5ad..a71011b2b45c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp @@ -84,7 +84,6 @@ class DropUnusedCallsPass // Note that we want to remove entire chains of unused calls and run this // as a pattern application. RewritePatternSet patterns(&getContext()); - // patterns patterns.insert, EraseUnusedCallOp>( &getContext(), noSideEffectsSymbols); diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir index 03ebbc331527..9e30a2e12f8f 100644 --- a/tests/compiler_driver/streams.mlir +++ b/tests/compiler_driver/streams.mlir @@ -51,10 +51,10 @@ stream.executable private @executable_0 { } } } -// CHECK: vm.func private @simple_mul +// CHECK: vm.func private @__simple_mul_memoize_apply +// CHECK: vm.call.variadic @hal.command_buffer.dispatch func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %c4 = arith.constant 4 : index - // CHECK: vm.call.variadic @hal.command_buffer.dispatch %ret0 = flow.dispatch @executable_0::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %ret0 : tensor<4xf32> } @@ -98,10 +98,10 @@ stream.executable private @executable_1 { } } } -// CHECK: vm.func private @simple_mul_inplace +// CHECK: vm.func private @__simple_mul_inplace_memoize_apply +// CHECK: vm.call.variadic @hal.command_buffer.dispatch func.func @simple_mul_inplace(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %c4 = arith.constant 4 : index - // CHECK: vm.call.variadic @hal.command_buffer.dispatch %ret0 = flow.dispatch @executable_1::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> %arg0 return %ret0 : tensor<4xf32> }