Skip to content

Commit

Permalink
Adding --iree-vulkan-experimental-indirect-bindings=true flag. (#14977
Browse files Browse the repository at this point in the history
)

This makes all descriptor set layouts have the new `Indirect` bit set
and plumbs it all the way through to the runtime
`IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT` bit. SPIR-V codegen can
inspect the pipeline layout attr of exports to discover which descriptor
sets are indirect and lower via `VK_KHR_buffer_device_address` and for
the runtime to specially handle the indirect descriptor sets by
producing device address buffers. The flag is currently experimental as
interop with non-indirect dispatches (custom/produced by other higher
layers like IREE input dialects/plugins) and multi-versioning (producing
both direct and indirect) are TBD. It should be sufficient for users
targeting specific Vulkan devices where they know the support is
present, though.

Note that while this is just the plumbing for the flag and the
IR/runtime bits nothing is either lowering differently or setting up the
appropriate runtime structures but it should allow codegen to start
experimenting with alternative lowerings.

Progress on #13945.
  • Loading branch information
benvanik authored Sep 20, 2023
1 parent 0f4dd73 commit fb9e1b6
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 52 deletions.
18 changes: 12 additions & 6 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,11 @@ def HAL_DescriptorFlagsAttr :
}

def HAL_DescriptorSetLayoutFlags_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_DescriptorSetLayoutFlags_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>;
def HAL_DescriptorSetLayoutFlags_Indirect : I32BitEnumAttrCase<"Indirect", 0x0001>;
def HAL_DescriptorSetLayoutFlagsAttr :
I32BitEnumAttr<"DescriptorSetLayoutFlags", "valid DescriptorSetLayout flags", [
HAL_DescriptorSetLayoutFlags_None,
HAL_DescriptorSetLayoutFlags_Reserved, // to make tblgen happy
HAL_DescriptorSetLayoutFlags_Indirect,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
Expand Down Expand Up @@ -614,12 +614,14 @@ def HAL_DescriptorSetLayoutAttr :
}];
let parameters = (ins
AttrParameter<"int64_t", "">:$ordinal,
ArrayRefParameter<"DescriptorSetBindingAttr", "">:$bindings
ArrayRefParameter<"DescriptorSetBindingAttr", "">:$bindings,
OptionalParameter<"std::optional<DescriptorSetLayoutFlags>">:$flags
);
let assemblyFormat = [{
`<`
$ordinal `,`
`bindings` `=` `[` $bindings `]`
(`,` `flags` `=` $flags^)?
`>`
}];
}
Expand Down Expand Up @@ -714,7 +716,7 @@ def HAL_DeviceTargetAttr :
bool hasConfigurationAttr(StringRef name);

// Returns zero or more executable targets that this device supports.
SmallVector<ExecutableTargetAttr, 4> getExecutableTargets();
SmallVector<IREE::HAL::ExecutableTargetAttr, 4> getExecutableTargets();

// Returns a list of target devices that may be active for the given
// operation. This will recursively walk parent operations until one with
Expand Down Expand Up @@ -752,7 +754,7 @@ def HAL_DeviceTargetAttr :

// Returns a list of all target executable configurations that may be
// required for the given operation.
static SmallVector<ExecutableTargetAttr, 4>
static SmallVector<IREE::HAL::ExecutableTargetAttr, 4>
lookupExecutableTargets(Operation *op);
}];
let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -807,6 +809,10 @@ def HAL_ExecutableTargetAttr :
// device that can load an executable of this target.
Attribute getMatchExpression();

// Returns true if there's an attribute with the given name in the
// configuration dictionary.
bool hasConfigurationAttr(StringRef name);

// Returns true if this attribute is a generic version of |specificAttr|.
// A more generic version will match with many specific versions.
bool isGenericOf(IREE::HAL::ExecutableTargetAttr specificAttr);
Expand All @@ -815,7 +821,7 @@ def HAL_ExecutableTargetAttr :
// This will recursively walk parent operations until one with the
// `hal.executable.target` attribute is found or a `hal.executable.variant`
// specifies a value. Returns nullptr if no target specification can be found.
static ExecutableTargetAttr lookup(Operation *op);
static IREE::HAL::ExecutableTargetAttr lookup(Operation *op);
}];

let hasCustomAssemblyFormat = 1;
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,11 @@ Attribute ExecutableTargetAttr::getMatchExpression() {
return DeviceMatchExecutableFormatAttr::get(getContext(), getFormat());
}

bool ExecutableTargetAttr::hasConfigurationAttr(StringRef name) {
auto configAttr = getConfiguration();
return configAttr && configAttr.get(name);
}

// For now this is very simple: if there are any specified fields that are
// present in this attribute they must match. We could allow target backends
// to customize this via attribute interfaces in the future if we needed.
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,31 @@ operator<<(AsmPrinter &printer,
return printer;
}

template <>
struct FieldParser<
std::optional<mlir::iree_compiler::IREE::HAL::DescriptorSetLayoutFlags>> {
static FailureOr<mlir::iree_compiler::IREE::HAL::DescriptorSetLayoutFlags>
parse(AsmParser &parser) {
std::string value;
if (parser.parseKeywordOrString(&value))
return failure();
auto result = mlir::iree_compiler::IREE::HAL::symbolizeEnum<
mlir::iree_compiler::IREE::HAL::DescriptorSetLayoutFlags>(value);
if (!result.has_value())
return failure();
return result.value();
}
};
static inline AsmPrinter &operator<<(
AsmPrinter &printer,
std::optional<mlir::iree_compiler::IREE::HAL::DescriptorSetLayoutFlags>
param) {
printer << (param.has_value()
? mlir::iree_compiler::IREE::HAL::stringifyEnum(param.value())
: StringRef{""});
return printer;
}

template <>
struct FieldParser<
std::optional<mlir::iree_compiler::IREE::HAL::DescriptorFlags>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,26 @@ VulkanSPIRVTargetOptions getVulkanSPIRVTargetOptionsFromFlags() {
// "IREE Vulkan/SPIR-V backend options");

static llvm::cl::opt<std::string> clVulkanTargetTriple(
"iree-vulkan-target-triple", llvm::cl::desc("Vulkan target triple"),
"iree-vulkan-target-triple",
llvm::cl::desc(
"Vulkan target triple controlling the SPIR-V environment."),
llvm::cl::init("unknown-unknown-unknown"));

static llvm::cl::opt<std::string> clVulkanTargetEnv(
"iree-vulkan-target-env",
llvm::cl::desc(
"Vulkan target environment as #vk.target_env attribute assembly"),
"Vulkan target environment as #vk.target_env attribute assembly."),
llvm::cl::init(""));

static llvm::cl::opt<bool> clVulkanIndirectBindings(
"iree-vulkan-experimental-indirect-bindings",
llvm::cl::desc("Force indirect bindings for all generated dispatches."),
llvm::cl::init(false));

VulkanSPIRVTargetOptions targetOptions;
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
targetOptions.vulkanTargetTriple = clVulkanTargetTriple;
targetOptions.targetEnv = clVulkanTargetEnv;
targetOptions.targetTriple = clVulkanTargetTriple;
targetOptions.indirectBindings = clVulkanIndirectBindings;

return targetOptions;
}
Expand Down Expand Up @@ -291,23 +299,30 @@ class VulkanSPIRVTargetBackend : public TargetBackend {
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
targetAttrs.push_back(getExecutableTarget(
context, getSPIRVTargetEnv(options_.vulkanTargetEnv,
options_.vulkanTargetTriple, context)));
context,
getSPIRVTargetEnv(options_.targetEnv, options_.targetTriple, context),
options_.indirectBindings));
return ArrayAttr::get(context, targetAttrs);
}

IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context,
spirv::TargetEnvAttr targetEnv) const {
getExecutableTarget(MLIRContext *context, spirv::TargetEnvAttr targetEnv,
bool indirectBindings) const {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr(spirv::getTargetEnvAttrName()),
targetEnv);
if (indirectBindings) {
configItems.emplace_back(b.getStringAttr("hal.bindings.indirect"),
UnitAttr::get(context));
}

auto configAttr = b.getDictionaryAttr(configItems);
return IREE::HAL::ExecutableTargetAttr::get(
context, b.getStringAttr("vulkan"), b.getStringAttr("vulkan-spirv-fb"),
context, b.getStringAttr("vulkan"),
indirectBindings ? b.getStringAttr("vulkan-spirv-fb-ptr")
: b.getStringAttr("vulkan-spirv-fb"),
configAttr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ namespace HAL {
// Options controlling the SPIR-V translation.
struct VulkanSPIRVTargetOptions {
// Vulkan target environment as #vk.target_env attribute assembly.
std::string vulkanTargetEnv;
std::string targetEnv;
// Vulkan target triple.
std::string vulkanTargetTriple;
std::string targetTriple;
// Whether to use indirect bindings for all generated dispatches.
bool indirectBindings = false;
};

// Returns a VulkanSPIRVTargetOptions struct initialized with Vulkan/SPIR-V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ static LogicalResult verifyEntryPointTypes(mlir::func::FuncOp entryFuncOp) {
// Creates an pipeline layout attr from the analysis results.
static IREE::HAL::PipelineLayoutAttr
makePipelineLayoutAttr(const PipelineLayout &pipelineLayout,
IREE::HAL::ExecutableTargetAttr targetAttr,
OpBuilder &builder) {
SmallVector<IREE::HAL::DescriptorSetLayoutAttr> setLayoutAttrs;
for (const auto &setLayout : pipelineLayout.setLayouts) {
Expand All @@ -181,8 +182,12 @@ makePipelineLayoutAttr(const PipelineLayout &pipelineLayout,
? binding.flags
: std::optional<IREE::HAL::DescriptorFlags>{}));
}
std::optional<IREE::HAL::DescriptorSetLayoutFlags> flags;
if (targetAttr.hasConfigurationAttr("hal.bindings.indirect")) {
flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
builder.getContext(), setLayout.ordinal, bindingAttrs));
builder.getContext(), setLayout.ordinal, bindingAttrs, flags));
}
return IREE::HAL::PipelineLayoutAttr::get(
builder.getContext(), pipelineLayout.pushConstantCount, setLayoutAttrs);
Expand Down Expand Up @@ -312,8 +317,9 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());

// Build a map of source function definitions to their version with the
// updated interface.
DenseMap<Operation *, Operation *> targetFuncOps;
// updated interface per variant.
DenseMap<Operation *, DenseMap<IREE::HAL::ExecutableVariantOp, Operation *>>
targetFuncOps;
int nextOrdinal = 0;
for (auto exportOp : sourceExecutableOp.getBody()
.getOps<IREE::Stream::ExecutableExportOp>()) {
Expand All @@ -325,7 +331,6 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
// Create the interface for this entry point based on the analysis of its
// usage within the program.
const auto &pipelineLayout = layoutAnalysis.getPipelineLayout(exportOp);
auto layoutAttr = makePipelineLayoutAttr(pipelineLayout, executableBuilder);

// Update all dispatch sites with the binding information required for
// conversion into the HAL dialect. By doing this here we ensure that the
Expand All @@ -338,7 +343,6 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
// Clone the updated function declaration into each variant.
int ordinal = nextOrdinal++;
for (auto variantOp : variantOps) {
// Declare the entry point on the target.
OpBuilder targetBuilder(variantOp.getInnerModule());
// Check if workgroup size is set externally.
ArrayAttr workgroupSize;
Expand All @@ -356,6 +360,10 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
break;
}
}

// Declare the entry point on the target.
auto layoutAttr = makePipelineLayoutAttr(
pipelineLayout, variantOp.getTargetAttr(), targetBuilder);
auto newExportOp = targetBuilder.create<IREE::HAL::ExecutableExportOp>(
exportOp.getLoc(),
targetBuilder.getStringAttr(exportOp.getFunctionRef()),
Expand All @@ -380,39 +388,37 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
newExportOp.getWorkgroupCount().insertArgument(0u, deviceType,
newExportOp.getLoc());
}
}

// Clone the source function and update it to use the new interface.
auto targetFuncOp =
cloneFuncWithInterface(sourceFuncOp, pipelineLayout, layoutAttr);
targetFuncOps[sourceFuncOp] = targetFuncOp;
// Clone the source function and update it to use the new interface.
auto variantFuncOp =
cloneFuncWithInterface(sourceFuncOp, pipelineLayout, layoutAttr);
targetFuncOps[sourceFuncOp][variantOp] = variantFuncOp;
}
}

// Clone all of the ops in the source module to each variant.
// We'll use the exported functions with the updated interfaces in place of
// the original versions and copy everything else verbatim.
// Note that we do this as a cleanup setup because there may be multiple
// functions and multiple exports (with an N:M mapping) and in this way we
// perform the variant construction in a single pass with deterministic
// ordering that preserves the unmodified ops.
for (auto variantOp : variantOps) {
auto targetBuilder = OpBuilder::atBlockBegin(
&variantOp.getInnerModule().getBodyRegion().front());
for (auto &op : sourceModuleOp.getOps()) {
auto targetFuncOp = targetFuncOps.find(&op);
if (targetFuncOp != targetFuncOps.end()) {
// Clone the updated function instead of the original.
targetBuilder.clone(*targetFuncOp->second);
auto targetVariantFuncOps = targetFuncOps.find(&op);
if (targetVariantFuncOps != targetFuncOps.end()) {
// Move the updated function into place.
auto variantFuncOp = targetVariantFuncOps->second[variantOp];
targetBuilder.insert(variantFuncOp);
} else {
// Regular op (globals, external function declarations, etc).
targetBuilder.clone(op);
}
}
}

// Drop the temporary target functions. We could avoid an additional clone if
// we only had one variant but this is relatively small in cost (once per
// variant).
for (auto it : targetFuncOps)
it.second->erase();
targetFuncOps.clear();

return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ class MaterializeResourceCachesPass
}

private:
IREE::Util::GlobalOp defineDescriptorSetLayoutOp(Location loc,
ArrayAttr bindingAttrs) {
auto existingIt = descriptorSetLayoutCache_.find(bindingAttrs);
IREE::Util::GlobalOp
defineDescriptorSetLayoutOp(Location loc, ArrayAttr bindingAttrs,
IREE::HAL::DescriptorSetLayoutFlags flags) {
std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags> key = {
bindingAttrs, flags};
auto existingIt = descriptorSetLayoutCache_.find(key);
if (existingIt != descriptorSetLayoutCache_.end()) {
return existingIt->second;
}
Expand All @@ -134,15 +137,14 @@ class MaterializeResourceCachesPass
loc, symbolName,
/*isMutable=*/false, layoutType);
globalOp.setPrivate();
descriptorSetLayoutCache_.try_emplace(bindingAttrs, globalOp);
descriptorSetLayoutCache_.try_emplace(key, globalOp);

auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
OpBuilder blockBuilder =
OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
auto layoutFlags = IREE::HAL::DescriptorSetLayoutFlags::None;
auto layoutValue = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
loc, layoutType, deviceValue, layoutFlags, bindingAttrs);
loc, layoutType, deviceValue, flags, bindingAttrs);
blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layoutValue,
globalOp.getName());
blockBuilder.create<IREE::Util::InitializerReturnOp>(loc);
Expand All @@ -167,7 +169,9 @@ class MaterializeResourceCachesPass
bindingAttrs.push_back(bindingAttr);
}
setLayoutGlobalOps.push_back(defineDescriptorSetLayoutOp(
loc, ArrayAttr::get(loc.getContext(), bindingAttrs)));
loc, ArrayAttr::get(loc.getContext(), bindingAttrs),
setLayoutAttr.getFlags().value_or(
IREE::HAL::DescriptorSetLayoutFlags::None)));
}

auto symbolName = (StringRef("_pipeline_layout_") +
Expand Down Expand Up @@ -319,8 +323,8 @@ class MaterializeResourceCachesPass
void
replaceDescriptorSetLayoutLookupOp(DescriptorSetLayoutLookupOp &lookupOp) {
OpBuilder builder(lookupOp);
auto globalOp =
defineDescriptorSetLayoutOp(lookupOp.getLoc(), lookupOp.getBindings());
auto globalOp = defineDescriptorSetLayoutOp(
lookupOp.getLoc(), lookupOp.getBindings(), lookupOp.getFlags());
auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
lookupOp.getLoc(), DescriptorSetLayoutType::get(lookupOp.getContext()),
globalOp.getSymName());
Expand Down Expand Up @@ -355,7 +359,9 @@ class MaterializeResourceCachesPass
TargetOptions targetOptions_;

OpBuilder moduleBuilder{static_cast<MLIRContext *>(nullptr)};
DenseMap<Attribute, IREE::Util::GlobalOp> descriptorSetLayoutCache_;
DenseMap<std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags>,
IREE::Util::GlobalOp>
descriptorSetLayoutCache_;
DenseMap<Attribute, IREE::Util::GlobalOp> pipelineLayoutCache_;
DenseMap<StringRef, IREE::Util::GlobalOp> executableCache_;

Expand Down
Loading

0 comments on commit fb9e1b6

Please sign in to comment.