diff --git a/compiler/plugins/target/ROCM/BUILD.bazel b/compiler/plugins/target/ROCM/BUILD.bazel index 7962cf8e6073..9692d1aafd26 100644 --- a/compiler/plugins/target/ROCM/BUILD.bazel +++ b/compiler/plugins/target/ROCM/BUILD.bazel @@ -39,6 +39,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/Utils", + "//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs", "//runtime/src/iree/schemas:executable_debug_info_c_fbs", "//runtime/src/iree/schemas:hip_executable_def_c_fbs", "@llvm-project//llvm:AMDGPUCodeGen", diff --git a/compiler/plugins/target/ROCM/CMakeLists.txt b/compiler/plugins/target/ROCM/CMakeLists.txt index 9430dca4fc16..938261acd14e 100644 --- a/compiler/plugins/target/ROCM/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/CMakeLists.txt @@ -64,6 +64,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils iree::compiler::PluginAPI iree::compiler::Utils + iree::schemas::amdgpu_executable_def_c_fbs iree::schemas::executable_debug_info_c_fbs iree::schemas::hip_executable_def_c_fbs PUBLIC diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index e384dd79c405..c860b630fe77 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -23,6 +23,7 @@ #include "iree/compiler/PluginAPI/Client.h" #include "iree/compiler/Utils/FlatbufferUtils.h" #include "iree/compiler/Utils/ToolUtils.h" +#include "iree/schemas/amdgpu_executable_def_builder.h" #include "iree/schemas/hip_executable_def_builder.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -54,7 +55,9 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { -struct ROCmOptions { +// TODO(#18792): rename flags back to iree-rocm- as they are not HIP-specific. +// Only iree-hip-legacy-sync applies uniquely to HIP. +struct ROCMOptions { std::string target = ""; std::string targetFeatures = ""; std::string bitcodeDirectory = getDefaultBitcodeDirectory(); @@ -196,45 +199,9 @@ static std::string translateModuleToISA(llvm::Module &module, } } // namespace -class ROCMTargetDevice final : public TargetDevice { -public: - ROCMTargetDevice(const ROCmOptions &options) : options(options) {} - - IREE::HAL::DeviceTargetAttr - getDefaultDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - - SmallVector deviceConfigAttrs; - if (options.legacySync) { - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. - deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), - b.getUnitAttr()); - } - auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); - - SmallVector executableConfigAttrs; - auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - SmallVector executableTargetAttrs; - targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( - context, "rocm", executableConfigAttr, executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), - deviceConfigAttr, - executableTargetAttrs); - } - -private: - const ROCmOptions &options; -}; - class ROCMTargetBackend final : public TargetBackend { public: - ROCMTargetBackend(const ROCmOptions &options) : options(options) {} + ROCMTargetBackend(const ROCMOptions &options) : options(options) {} std::string getLegacyDefaultDeviceID() const override { return "hip"; } @@ -242,31 +209,43 @@ class ROCMTargetBackend final : public TargetBackend { MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr, SmallVectorImpl &executableTargetAttrs) const override { - if (auto target = getExecutableTarget(context)) + if (auto target = getExecutableTarget(deviceID, context)) { executableTargetAttrs.push_back(target); + } } IREE::HAL::ExecutableTargetAttr - getExecutableTarget(MLIRContext *context) const { + getExecutableTarget(StringRef deviceID, MLIRContext *context) const { Builder b(context); SmallVector configItems; auto addConfig = [&](StringRef name, Attribute value) { configItems.emplace_back(b.getStringAttr(name), value); }; - if (failed(options.verify(b))) + if (failed(options.verify(b))) { return nullptr; + } + + addConfig("abi", b.getStringAttr(deviceID)); + std::string format; + if (deviceID == "amdgpu") { + format = options.target; + } else { + format = "rocm-hsaco-fb"; // legacy HIP + } - if (auto target = GPU::getHIPTargetDetails(options.target, - options.targetFeatures, context)) + if (auto target = GPU::getHIPTargetDetails( + options.target, options.targetFeatures, context)) { addConfig("iree.gpu.target", target); + } addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels)); - if (options.wavesPerEu > 0) + if (options.wavesPerEu > 0) { addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu)); + } return b.getAttr( - b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"), + b.getStringAttr("rocm"), b.getStringAttr(format), b.getDictionaryAttr(configItems)); } @@ -356,9 +335,10 @@ class ROCMTargetBackend final : public TargetBackend { return success(); } - LogicalResult serializeExecutable(const SerializationOptions &serOptions, - IREE::HAL::ExecutableVariantOp variantOp, - OpBuilder &executableBuilder) override { + LogicalResult + serializeExecutable(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + OpBuilder &executableBuilder) override { ModuleOp innerModuleOp = variantOp.getInnerModule(); auto targetAttr = variantOp.getTargetAttr(); StringRef targetArch = options.target; @@ -552,18 +532,18 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".linked.ll", *llvmModule); } // Run LLVM optimization passes. optimizeModule(*llvmModule, *targetMachine, options.passPlugins, options.slpVectorization); - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".optimized.ll", *llvmModule); } @@ -572,7 +552,7 @@ class ROCMTargetBackend final : public TargetBackend { } // Dump the assembly output. - if (!serOptions.dumpIntermediatesPath.empty()) { + if (!serializationOptions.dumpIntermediatesPath.empty()) { auto moduleCopy = llvm::CloneModule(*llvmModule); if (!moduleCopy) { llvm::errs() << "Error: cloning LLVM IR failed\n"; @@ -580,9 +560,9 @@ class ROCMTargetBackend final : public TargetBackend { } std::string targetISA = translateModuleToISA(*moduleCopy.get(), *targetMachine); - dumpDataToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), ".rocmasm", - targetISA); + dumpDataToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".rocmasm", targetISA); } // Serialize hsaco kernel into the binary that we will embed in the @@ -593,23 +573,136 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpBinariesPath.empty()) { - dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName, - variantOp.getName(), ".hsaco", targetHSACO); + if (!serializationOptions.dumpBinariesPath.empty()) { + dumpDataToPath(serializationOptions.dumpBinariesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".hsaco", targetHSACO); + } + + // Wrap the HSACO ELF binary in a Flatbuffers container. + FailureOr binaryContainer; + if (targetAttr.getConfiguration() && + targetAttr.getConfiguration().getAs("abi") == "amdgpu") { + binaryContainer = serializeAMDGPUBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } else { + binaryContainer = serializeHIPBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } + if (failed(binaryContainer) || !binaryContainer.value()) { + return failure(); + } + + // Add the binary data to the target executable. + executableBuilder.create( + variantOp.getLoc(), variantOp.getSymName(), + variantOp.getTarget().getFormat(), binaryContainer.value()); + + return success(); + } + +protected: + FailureOr serializeAMDGPUBinaryContainer( + const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { + iree_compiler::FlatbufferBuilder builder; + iree_hal_amdgpu_ExecutableDef_start_as_root(builder); + + // Attach embedded source file contents. + auto sourceFilesRef = createSourceFilesVec( + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); + + // Only a single module today. + SmallVector moduleRefs; + { + auto hsacoImageRef = flatbuffers_string_create( + builder, hsacoModule.data(), hsacoModule.size()); + moduleRefs.push_back( + iree_hal_amdgpu_ModuleDef_create(builder, hsacoImageRef)); + } + auto modulesRef = builder.createOffsetVecDestructive(moduleRefs); + + // Generate optional per-export debug information. + // May be empty if no debug information was requested. + auto exportDebugInfos = + createExportDefs(serializationOptions.debugLevel, exportOps, builder); + + SmallVector exportRefs; + exportRefs.resize(exportOps.size(), 0); + for (auto exportOp : exportOps) { + auto ordinalAttr = exportOp.getOrdinalAttr(); + if (!ordinalAttr) { + return mlir::emitError(exportOp.getLoc()) + << "could not compile rocm binary: export op is missing ordinal"; + } + int64_t ordinal = ordinalAttr.getInt(); + + auto symbolNameRef = builder.createString(exportOp.getName()); + + iree_hal_amdgpu_Dims_t workgroupSize = {0}; + if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { + auto workgroupSizeDims = workgroupSizeAttr->getValue(); + workgroupSize.x = cast(workgroupSizeDims[0]).getInt(); + workgroupSize.y = cast(workgroupSizeDims[1]).getInt(); + workgroupSize.z = cast(workgroupSizeDims[2]).getInt(); + } + + auto layoutAttr = exportOp.getLayoutAttr(); + uint32_t constantCount = static_cast(layoutAttr.getConstants()); + SmallVector bindingFlags; + for (auto bindingAttr : layoutAttr.getBindings()) { + iree_hal_amdgpu_BindingBits_enum_t flags = 0; + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::ReadOnly)) { + flags |= iree_hal_amdgpu_BindingBits_READ_ONLY; + } + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::Indirect)) { + flags |= iree_hal_amdgpu_BindingBits_INDIRECT; + } + bindingFlags.push_back(flags); + } + auto bindingFlagsRef = iree_hal_amdgpu_BindingBits_vec_create( + builder, bindingFlags.data(), bindingFlags.size()); + + iree_hal_amdgpu_ExportDef_start(builder); + iree_hal_amdgpu_ExportDef_symbol_name_add(builder, symbolNameRef); + iree_hal_amdgpu_ExportDef_workgroup_size_add(builder, &workgroupSize); + iree_hal_amdgpu_ExportDef_constant_count_add(builder, constantCount); + iree_hal_amdgpu_ExportDef_binding_flags_add(builder, bindingFlagsRef); + iree_hal_amdgpu_ExportDef_debug_info_add(builder, + exportDebugInfos[ordinal]); + exportRefs[ordinal] = iree_hal_amdgpu_ExportDef_end(builder); } + auto exportsRef = builder.createOffsetVecDestructive(exportRefs); + + iree_hal_amdgpu_ExecutableDef_exports_add(builder, exportsRef); + iree_hal_amdgpu_ExecutableDef_modules_add(builder, modulesRef); + iree_hal_amdgpu_ExecutableDef_source_files_add(builder, sourceFilesRef); + iree_hal_amdgpu_ExecutableDef_end_as_root(builder); + return builder.getBufferAttr(variantOp.getContext()); + } + + FailureOr + serializeHIPBinaryContainer(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { iree_compiler::FlatbufferBuilder builder; iree_hal_hip_ExecutableDef_start_as_root(builder); // Attach embedded source file contents. auto sourceFilesRef = createSourceFilesVec( - serOptions.debugLevel, variantOp.getSourcesAttr(), builder); + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); // Only a single module today. SmallVector moduleRefs; { auto hsacoImageRef = flatbuffers_string_create( - builder, targetHSACO.c_str(), targetHSACO.size()); + builder, hsacoModule.data(), hsacoModule.size()); moduleRefs.push_back( iree_hal_hip_ModuleDef_create(builder, hsacoImageRef)); } @@ -618,7 +711,7 @@ class ROCMTargetBackend final : public TargetBackend { // Generate optional per-export debug information. // May be empty if no debug information was requested. auto exportDebugInfos = - createExportDefs(serOptions.debugLevel, exportOps, builder); + createExportDefs(serializationOptions.debugLevel, exportOps, builder); SmallVector exportRefs; exportRefs.resize(exportOps.size(), 0); @@ -682,27 +775,91 @@ class ROCMTargetBackend final : public TargetBackend { iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef); iree_hal_hip_ExecutableDef_end_as_root(builder); - // Add the binary data to the target executable. - executableBuilder.create( - variantOp.getLoc(), variantOp.getSymName(), - variantOp.getTarget().getFormat(), - builder.getBufferAttr(executableBuilder.getContext())); + return builder.getBufferAttr(variantOp.getContext()); + } - return success(); +private: + const ROCMOptions &options; +}; + +class AMDGPUTargetDevice final : public TargetDevice { +public: + AMDGPUTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "amdgpu", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("amdgpu"), + deviceConfigAttr, + executableTargetAttrs); + } + +private: + const ROCMOptions &options; +}; + +class HIPTargetDevice final : public TargetDevice { +public: + HIPTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + if (options.legacySync) { + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), + b.getUnitAttr()); + } + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "hip", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), + deviceConfigAttr, + executableTargetAttrs); } private: - const ROCmOptions &options; + const ROCMOptions &options; }; namespace { struct ROCMSession final - : PluginSession { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { + // #hal.device.target<"amdgpu", ... + targets.add("amdgpu", [&]() { + return std::make_shared(options); + }); // #hal.device.target<"hip", ... targets.add("hip", - [&]() { return std::make_shared(options); }); + [&]() { return std::make_shared(options); }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { // #hal.executable.target<"rocm", ... @@ -728,4 +885,4 @@ extern "C" bool iree_register_compiler_plugin_hal_target_rocm( return true; } -IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCmOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCMOptions); diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp index 7453af749b80..a1757afd75f1 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp @@ -35,7 +35,7 @@ loadIRModule(Location loc, const std::string &filename, diagnostic, *llvm_context)); if (!module) { - mlir::emitError(loc) << "error loading HIP LLVM module: " + mlir::emitError(loc) << "error loading ROCM LLVM module: " << diagnostic.getFilename().str() << ":" << diagnostic.getLineNo() << ":" << diagnostic.getColumnNo() << ": " @@ -90,7 +90,7 @@ static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker, auto setAlwaysInline = [&](llvm::Module &module) { if (targetMachine.getTargetCPU().contains("gfx10") || targetMachine.getTargetCPU().contains("gfx11")) { - // some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if + // Some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if // inlined. return; } diff --git a/runtime/src/iree/schemas/BUILD.bazel b/runtime/src/iree/schemas/BUILD.bazel index a8fbfcab8b12..e98a425424ee 100644 --- a/runtime/src/iree/schemas/BUILD.bazel +++ b/runtime/src/iree/schemas/BUILD.bazel @@ -20,6 +20,13 @@ FLATCC_ARGS = [ "--json", ] +iree_flatbuffer_c_library( + name = "amdgpu_executable_def_c_fbs", + srcs = ["amdgpu_executable_def.fbs"], + flatcc_args = FLATCC_ARGS, + includes = ["executable_debug_info.fbs"], +) + iree_flatbuffer_c_library( name = "bytecode_module_def_c_fbs", srcs = ["bytecode_module_def.fbs"], @@ -70,6 +77,7 @@ iree_flatbuffer_c_library( iree_build_test( name = "schema_build_test", targets = [ + ":amdgpu_executable_def_c_fbs", ":bytecode_module_def_c_fbs", ":cuda_executable_def_c_fbs", ":executable_debug_info_c_fbs", diff --git a/runtime/src/iree/schemas/CMakeLists.txt b/runtime/src/iree/schemas/CMakeLists.txt index 574b2cac4578..f30430df0789 100644 --- a/runtime/src/iree/schemas/CMakeLists.txt +++ b/runtime/src/iree/schemas/CMakeLists.txt @@ -10,6 +10,21 @@ iree_add_all_subdirs() +flatbuffer_c_library( + NAME + amdgpu_executable_def_c_fbs + SRCS + "amdgpu_executable_def.fbs" + FLATCC_ARGS + "--reader" + "--builder" + "--verifier" + "--json" + INCLUDES + "executable_debug_info.fbs" + PUBLIC +) + flatbuffer_c_library( NAME bytecode_module_def_c_fbs diff --git a/runtime/src/iree/schemas/amdgpu_executable_def.fbs b/runtime/src/iree/schemas/amdgpu_executable_def.fbs new file mode 100644 index 000000000000..43efdb0a34dc --- /dev/null +++ b/runtime/src/iree/schemas/amdgpu_executable_def.fbs @@ -0,0 +1,63 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include "iree/schemas/executable_debug_info.fbs"; + +namespace iree.hal.amdgpu; + +// 'AMDGPU v1 Executable'. +file_identifier "AMD1"; +file_extension "amd1"; + +// A struct for the kernel block size along each dimension. +struct Dims { + x:uint32; + y:uint32; + z:uint32; +} + +// Describes the behavior of each binding. +enum BindingBits:uint64 (bit_flags) { + READ_ONLY = 0, // 1u << 0 + INDIRECT = 1, // 1u << 1 +} + +// Information about an exported function on the executable. +table ExportDef { + // String name of the exported function symbol in the module. + symbol_name:string; + + // Workgroup size for the export. + workgroup_size:Dims; + + // Total number of 32-bit push constants used by the export. + constant_count:uint32; + + // Binding count and flags for each binding. + binding_flags:[BindingBits]; + + // Optional debug information related to the export. + debug_info:iree.hal.debug.ExportDef; +} + +// A library containing one or more exported functions. +table ModuleDef { + // AMD ELF image for loading an hsa_executable_t. + image:string; +} + +table ExecutableDef { + // Exported functions in canonical executable entry point order. + exports:[ExportDef]; + + // Modules containing executable code. + modules:[ModuleDef]; + + // Embedded source files sorted ascending by path. + source_files:[iree.hal.debug.SourceFileDef]; +} + +root_type ExecutableDef;