Skip to content

Commit

Permalink
Adding a flag to force indirect command buffers on in non-reusable ca…
Browse files Browse the repository at this point in the history
…ses.
  • Loading branch information
benvanik committed Oct 30, 2024
1 parent 26ba4fd commit 4e1aadf
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 17 deletions.
14 changes: 12 additions & 2 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ struct ROCMOptions {
}
};

// Returns the ABI or an empty string if unspecified.
static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) {
if (targetAttr) {
if (auto config = targetAttr.getConfiguration()) {
auto abiAttr = targetAttr.getConfiguration().getAs<StringAttr>("abi");
return abiAttr ? abiAttr.getValue() : "";
}
}
return "";
}

static void dumpModuleToPath(StringRef path, StringRef baseName,
StringRef suffix, StringRef extension,
llvm::Module &module) {
Expand Down Expand Up @@ -585,8 +596,7 @@ class ROCMTargetBackend final : public TargetBackend {

// Wrap the HSACO ELF binary in a Flatbuffers container.
FailureOr<DenseIntElementsAttr> binaryContainer;
if (targetAttr.getConfiguration() &&
targetAttr.getConfiguration().getAs<StringAttr>("abi") == "amdgpu") {
if (getABI(targetAttr) == "amdgpu") {
binaryContainer = serializeAMDGPUBinaryContainer(
serializationOptions, variantOp, exportOps, targetHSACO);
} else {
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,8 +1100,12 @@ Error *Invocation::outputHALExecutable(Output &output) {
return new Error("not a valid HAL executable");
}
auto binaryOp = binaryOps.front();
auto rawData = binaryOp.getData().getRawData();
output.outputStream->write(rawData.data(), rawData.size());
if (failed(cast<IREE::Util::SerializableAttrInterface>(binaryOp.getData())
.serializeToStream(binaryOp.getLoc(), llvm::endianness::little,
*output.outputStream))) {
return new Error(
"data attribute failed to serialize: unsupported format or encoding");
}
output.outputStream->flush();
return output.getWriteError();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ static llvm::cl::opt<bool> clIndirectCommandBuffers{
llvm::cl::init(true),
};

// TODO(benvanik): remove when we support capturing dynamic values for reuse.
static llvm::cl::opt<bool> clForceIndirectCommandBuffers{
"iree-hal-force-indirect-command-buffers",
llvm::cl::desc("Forces indirect command buffers when they would otherwise "
"not be chosen due to the values they capture. They may not "
"be reusable but will still be outlined."),
llvm::cl::init(false),
};

struct ContextResolveOpPattern
: public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
using StreamConversionPattern::StreamConversionPattern;
Expand Down Expand Up @@ -1002,7 +1011,9 @@ struct CmdExecuteOpPattern
// changes dispatches to use them for any dispatch we can - note that there
// may still be some that slip through due to custom executables.
const bool capturesDynamicUniformValues =
regionCapturesDynamicUniformValues(executeOp);
clForceIndirectCommandBuffers
? false
: regionCapturesDynamicUniformValues(executeOp);

// Calculate the indirect buffer references used within the command buffer
// by analyzing captured resources. This analysis will be used by subsequent
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ def HAL_Ordinal : TypeAlias<Index>;
def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "Array of index ordinal attributes">;

def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;

def HAL_ElementType : TypeAlias<I32>;
def HAL_ElementTypeAttr : SignlessIntegerAttrBase<
I32, "element type attribute">;
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2604,7 +2604,7 @@ def HAL_ExecutableBinaryOp : HAL_Op<"executable.binary", [
OptionalAttr<StrAttr>:$sym_visibility,
SymbolNameAttr:$sym_name,
StrAttr:$format,
HAL_ExecutableDataAttr:$data,
Util_AnySerializableAttr:$data,
OptionalAttr<StrAttr>:$mime_type
// TODO(benvanik): add compatibility and versioning attributes.
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Utils/StringUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand All @@ -24,6 +25,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"

#define DEBUG_TYPE "iree-hal-outline-memoize-regions"

namespace mlir::iree_compiler::IREE::HAL {

#define GEN_PASS_DEF_OUTLINEMEMOIZEREGIONSPASS
Expand Down Expand Up @@ -153,6 +156,8 @@ static IREE::Util::FuncOp outlineMemoizeRegionBody(
name, funcType);
moduleSymbolTable.insert(funcOp);
funcOp.setVisibility(SymbolTable::Visibility::Private);
funcOp.setInliningPolicyAttr(
moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
auto funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock());

// Remap any captured operands that have corresponding function arguments.
Expand Down Expand Up @@ -521,8 +526,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp,
// If we can't memoize the resources at initialization time then we need
// to do it on-demand.
if (!memoizeAnalysis.canRunAtInitializationTime()) {
memoizeOp.emitWarning(
"memoization failed: dynamic values captured at the call site");
LLVM_DEBUG({
llvm::dbgs()
<< "memoization failed: dynamic values captured at the call site\n";
memoizeOp.dump();
});
replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
return;
}
Expand All @@ -532,8 +540,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp,
auto deviceGlobals =
deviceAnalysis.lookupDeviceGlobals(memoizeOp.getDevice());
if (!deviceGlobals) {
memoizeOp.emitWarning("memoization failed: unable to analyze devices "
"that may be used with memoized region");
LLVM_DEBUG({
llvm::dbgs() << "memoization failed: unable to analyze devices that may "
"be used with memoized region\n";
memoizeOp.dump();
});
replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class DropUnusedCallsPass
// Note that we want to remove entire chains of unused calls and run this
// as a pattern application.
RewritePatternSet patterns(&getContext());
// patterns
patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
&getContext(), noSideEffectsSymbols);
Expand Down
8 changes: 4 additions & 4 deletions tests/compiler_driver/streams.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ stream.executable private @executable_0 {
}
}
}
// CHECK: vm.func private @simple_mul
// CHECK: vm.func private @__simple_mul_memoize_apply
// CHECK: vm.call.variadic @hal.command_buffer.dispatch
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index
// CHECK: vm.call.variadic @hal.command_buffer.dispatch
%ret0 = flow.dispatch @executable_0::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %ret0 : tensor<4xf32>
}
Expand Down Expand Up @@ -98,10 +98,10 @@ stream.executable private @executable_1 {
}
}
}
// CHECK: vm.func private @simple_mul_inplace
// CHECK: vm.func private @__simple_mul_inplace_memoize_apply
// CHECK: vm.call.variadic @hal.command_buffer.dispatch
func.func @simple_mul_inplace(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index
// CHECK: vm.call.variadic @hal.command_buffer.dispatch
%ret0 = flow.dispatch @executable_1::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> %arg0
return %ret0 : tensor<4xf32>
}
Expand Down

0 comments on commit 4e1aadf

Please sign in to comment.