diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 9b4705e2c9c1..2175db72ba82 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -22,7 +22,6 @@ #include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h" #include "iree/compiler/PluginAPI/Client.h" #include "iree/compiler/Utils/FlatbufferUtils.h" -#include "iree/compiler/Utils/ModuleUtils.h" #include "iree/compiler/Utils/ToolUtils.h" #include "iree/schemas/hip_executable_def_builder.h" #include "llvm/ADT/StringExtras.h" @@ -37,17 +36,12 @@ #include "llvm/Passes/StandardInstrumentations.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" @@ -136,68 +130,6 @@ struct ROCmOptions { } }; -// Extracts the amdgpu chipset version from the chip architecture in the -// executable target attribute. -static FailureOr -getChipsetVersion(ExecutableTargetAttr targetAttr) { - IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(targetAttr); - if (!gpuTarget) - return failure(); - - return amdgpu::Chipset::parse(gpuTarget.getArch()); -} - -// Set attributes on `funcOp` in order to use upstream's translation of -// ROCDL dialect attributes to LLVM. 
Primarily this is `rocdl.kernel` -// (sets the calling convention and workgroup size uniformity) but this will -// also set both forms of workgroup size metadata from `exportOp` (if it is set) -// and will set the waves_per_eq flag where relevant. Finally, it will mark -// kernel arguments `inreg` to enable argument preloading on supported -// architectures. -static void annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp, - ExecutableExportOp exportOp, - ExecutableTargetAttr targetAttr, - OpBuilder &builder) { - auto *rocdlDialect = - funcOp.getContext()->getLoadedDialect(); - UnitAttr unitAttr = builder.getUnitAttr(); - rocdlDialect->getKernelAttrHelper().setAttr(funcOp, unitAttr); - std::optional workgroupSizeAttr = exportOp.getWorkgroupSize(); - if (workgroupSizeAttr && workgroupSizeAttr->size() <= 3) { - std::array wgSizes; - int32_t flatWgSize = 1; - for (auto [value, attr] : llvm::zip_equal( - wgSizes, workgroupSizeAttr->getAsRange())) { - value = attr.getInt(); - flatWgSize *= value; - } - rocdlDialect->getReqdWorkGroupSizeAttrHelper().setAttr( - funcOp, builder.getDenseI32ArrayAttr(wgSizes)); - rocdlDialect->getFlatWorkGroupSizeAttrHelper().setAttr( - funcOp, - builder.getStringAttr(Twine(flatWgSize) + "," + Twine(flatWgSize))); - } - - if (std::optional attr = - getConfigIntegerAttr(targetAttr, "waves_per_eu")) { - rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, *attr); - } - - // Kernel argument preloading is only supported on gfx940 and newer targets - // from the CDNA family. This is enabled using the `inreg` function argument - // attribute. 
- FailureOr chipset = getChipsetVersion(targetAttr); - if (failed(chipset)) - return; - if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) - return; - - auto inRegAttrName = - builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName()); - for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) - funcOp.setArgAttr(i, inRegAttrName, unitAttr); -} - static void dumpModuleToPath(StringRef path, StringRef baseName, StringRef suffix, StringRef extension, llvm::Module &module) { @@ -318,8 +250,6 @@ class ROCMTargetBackend final : public TargetBackend { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); - registry.insert(); } void @@ -407,11 +337,8 @@ class ROCMTargetBackend final : public TargetBackend { // Collect all the entry point names. auto exportOps = llvm::to_vector_of( variantOp.getExportOps()); - llvm::StringMap exportOpMap; std::optional subgroupSize; for (IREE::HAL::ExecutableExportOp exportOp : exportOps) { - exportOpMap[exportOp.getSymName()] = exportOp; - // TODO: put this either on the variant or propagate as a function // attribute instead - today this *must* be consistent across all exports // and it shouldn't need to be. @@ -436,7 +363,9 @@ class ROCMTargetBackend final : public TargetBackend { if (!variantOp.getObjects().has_value()) { return variantOp.emitOpError() << "no objects defined for external variant"; - } else if (variantOp.getObjects()->getValue().size() != 1) { + } + + if (variantOp.getObjects()->getValue().size() != 1) { // For now we assume there will be exactly one object file. // In the future we will want to perform a linking step here and ideally // support _also_ linking in the codegen results. @@ -457,17 +386,6 @@ class ROCMTargetBackend final : public TargetBackend { // Perform the translation in a separate context to avoid any // multi-threading issues. llvm::LLVMContext context; - - // Set up attributes so upstream's conversions work right. 
- for (auto func : innerModuleOp.getOps()) { - // Un-exported functions are library functions or otherwise - // not kernels, so don't need these annotations. - if (!exportOpMap.contains(func.getName())) - continue; - annotateKernelForTranslation(func, exportOpMap[func.getName()], - targetAttr, executableBuilder); - } - std::unique_ptr llvmModule = mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName); if (!llvmModule) { @@ -486,10 +404,10 @@ class ROCMTargetBackend final : public TargetBackend { for (NamedAttribute funcAttr : funcAttrs) { auto value = dyn_cast(funcAttr.getValue()); if (!value) { - return variantOp->emitError("llvm_func_attrs attribute must be " - "adictionary of strings. Attribute " + - llvm::Twine(funcAttr.getName()) + - " is not a StringAttr."); + return variantOp->emitError() + << "llvm_func_attrs attribute must be a dictionary of " + "strings. Attribute " + << funcAttr.getName() << " is not a StringAttr."; } llvmFunc->addFnAttr(funcAttr.getName(), value.getValue()); } diff --git a/compiler/plugins/target/ROCM/test/external_function_validation.mlir b/compiler/plugins/target/ROCM/test/external_function_validation.mlir index dcebcce888c5..b84455bd793d 100644 --- a/compiler/plugins/target/ROCM/test/external_function_validation.mlir +++ b/compiler/plugins/target/ROCM/test/external_function_validation.mlir @@ -31,7 +31,7 @@ builtin.module { } builtin.module { llvm.func @external_func() attributes {sym_visibility = "private"} - llvm.func @test() { + llvm.func @test() attributes { rocdl.kernel } { llvm.call @external_func() : () -> () llvm.return } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index c6fefa699e2d..f8fd1f02cb9b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -107,6 +107,7 @@ iree_compiler_cc_library( "LLVMGPUVectorLowering.cpp", "LLVMGPUVectorToGPU.cpp", 
"Passes.cpp", + "ROCDLAnnotateKernelForTranslation.cpp", "ROCDLKernelConfig.cpp", "ROCDLLowerExecutableTarget.cpp", "ROCDLSelectLoweringStrategy.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index aaee76a02ea3..14d2825a9b11 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -92,6 +92,7 @@ iree_cc_library( "LLVMGPUVectorLowering.cpp" "LLVMGPUVectorToGPU.cpp" "Passes.cpp" + "ROCDLAnnotateKernelForTranslation.cpp" "ROCDLKernelConfig.cpp" "ROCDLLowerExecutableTarget.cpp" "ROCDLSelectLoweringStrategy.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index b63dfa7e111d..15518e94bb7b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -33,6 +33,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -1099,6 +1100,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, if (forROCDL) { // convert to ROCDL. modulePassManager.addPass(createConvertToROCDLPass()); + modulePassManager.addNestedPass( + createROCDLAnnotateKernelForTranslationPass()); } else { // convert to NVVM. 
modulePassManager.addPass(createConvertToNVVMPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp new file mode 100644 index 000000000000..0a6eca5cfe1f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp @@ -0,0 +1,132 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include "iree/compiler/Codegen/Common/PassUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_ROCDLANNOTATEKERNELFORTRANSLATIONPASS +#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h.inc" + +namespace { +// Extracts the amdgpu chipset version from the chip architecture in the +// executable target attribute. Returns failure when the target carries no GPU +// target attribute or when `amdgpu::Chipset::parse` rejects its architecture. +static FailureOr +getChipsetVersion(IREE::HAL::ExecutableTargetAttr targetAttr) { + IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(targetAttr); + if (!gpuTarget) + return failure(); + return amdgpu::Chipset::parse(gpuTarget.getArch()); +} + +// Set attributes on `funcOp` in order to use upstream's translation of +// ROCDL dialect attributes to LLVM. 
Primarily this is `rocdl.kernel` +// (sets the calling convention and workgroup size uniformity) but this will +// also set both forms of workgroup size metadata from `exportOp` (if it is set) +// and will set the waves_per_eu flag where relevant. Finally, it will mark +// kernel arguments `inreg` to enable argument preloading on supported +// architectures. +static LogicalResult +annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp, + IREE::HAL::ExecutableVariantOp variantOp, + IREE::HAL::ExecutableExportOp exportOp) { + OpBuilder builder(funcOp); + auto *rocdlDialect = + funcOp.getContext()->getLoadedDialect(); + assert(rocdlDialect && "ROCDL dialect not loaded"); + UnitAttr unitAttr = builder.getUnitAttr(); + rocdlDialect->getKernelAttrHelper().setAttr(funcOp, unitAttr); + std::optional workgroupSizeAttr = exportOp.getWorkgroupSize(); + if (workgroupSizeAttr && workgroupSizeAttr->size() <= 3) { + std::array wgSizes; + int32_t flatWgSize = 1; + for (auto [value, attr] : llvm::zip_equal( + wgSizes, workgroupSizeAttr->getAsRange())) { + value = attr.getInt(); + flatWgSize *= value; + } + rocdlDialect->getReqdWorkGroupSizeAttrHelper().setAttr( + funcOp, builder.getDenseI32ArrayAttr(wgSizes)); + rocdlDialect->getFlatWorkGroupSizeAttrHelper().setAttr( + funcOp, + builder.getStringAttr(Twine(flatWgSize) + "," + Twine(flatWgSize))); + } + + IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTarget(); + if (std::optional attr = + getConfigIntegerAttr(targetAttr, "waves_per_eu")) { + rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, *attr); + } + + // Kernel argument preloading is only supported on gfx940 and newer targets + // from the CDNA family. This is enabled using the `inreg` function argument + // attribute. 
+ FailureOr chipset = getChipsetVersion(targetAttr); + if (failed(chipset)) + return variantOp.emitError() << "failed to parse amdgpu chipset"; + + if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) + return success(); + + auto inRegAttrName = + builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName()); + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) + funcOp.setArgAttr(i, inRegAttrName, unitAttr); + + return success(); +} + +/// Annotates each exported `llvm.func` kernel with the ROCDL attributes that +/// the translation to LLVM IR relies on; un-exported functions are skipped. +struct ROCDLAnnotateKernelForTranslationPass final + : impl::ROCDLAnnotateKernelForTranslationPassBase< + ROCDLAnnotateKernelForTranslationPass> { + void runOnOperation() override { + LLVM::LLVMFuncOp funcOp = getOperation(); + StringRef funcName = funcOp.getName(); + + auto variantOp = funcOp->getParentOfType(); + if (!variantOp) { + funcOp.emitError() << "cannot find parent hal.executable.variant op"; + return signalPassFailure(); + } + + IREE::HAL::ExecutableExportOp exportOp; + // Try to find the matching executable export op. + for (IREE::HAL::ExecutableExportOp candidate : variantOp.getExportOps()) { + if (candidate.getSymName() == funcName) { + exportOp = candidate; + break; + } + } + + // Un-exported functions are library functions or otherwise not kernels, so + // don't need these annotations. 
+ if (!exportOp) + return; + + if (failed(annotateKernelForTranslation(funcOp, variantOp, exportOp))) { + return signalPassFailure(); + } + } +}; +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h index 696a6c0fab19..fe7320ce087f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h @@ -8,6 +8,7 @@ #define IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLPASSES_H_ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" namespace mlir::iree_compiler { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td index bf91b6ebd084..b13d2a5efae3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td @@ -13,6 +13,12 @@ include "mlir/Pass/PassBase.td" // ROCDL Passes (keep alphabetical) //===----------------------------------------------------------------------===// +def ROCDLAnnotateKernelForTranslationPass : Pass< + "iree-rocdl-annotate-kernel-for-translation", "LLVM::LLVMFuncOp"> { + let summary = "Set function attributes before translating to LLVM IR"; + let dependentDialects = ["ROCDL::ROCDLDialect"]; +} + def ROCDLLowerExecutableTargetPass : InterfacePass< "iree-rocdl-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Lower an IREE hal.executable.variant op using a suitable " @@ -25,5 +31,4 @@ def ROCDLSelectLoweringStrategyPass : "hal.executable.variant op"; } - #endif // IREE_CODEGEN_LLVMGPU_ROCDLPASSES diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel index 4c25ee453397..5789cefb0fc9 100644 --- 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "annotate_kernel_for_translation.mlir", "config_tile_and_fuse.mlir", "config_vector_distribute.mlir", "config_user_vector_distribute.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt index fb1d8edef4e7..e843564040fc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "annotate_kernel_for_translation.mlir" "config_tile_and_fuse.mlir" "config_user_vector_distribute.mlir" "config_vector_distribute.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir new file mode 100644 index 000000000000..825b28cbc028 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir @@ -0,0 +1,122 @@ +// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(llvm.func(iree-rocdl-annotate-kernel-for-translation)))))' \ +// RUN: --split-input-file %s | FileCheck %s + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", + {iree.gpu.target = #iree_gpu.target>, + ukernels = "none"}> +#pipeline_layout = #hal.pipeline.layout], + flags = Indirect> +builtin.module { + hal.executable public @test { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) { + hal.executable.export public @test ordinal(0) layout(#pipeline_layout) + attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : 
index]} { + ^bb0(%arg0: !hal.device): + %c128 = arith.constant 128 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + hal.return %c128, %c2, %c1 : index, index, index + } + builtin.module { + llvm.func @test() { + llvm.return + } + llvm.func @test_not_exported() { + llvm.return + } + } + } + } +} + +// CHECK-LABEL: llvm.func @test() attributes { +// CHECK-SAME: rocdl.flat_work_group_size = "256,256" +// CHECK-SAME: rocdl.kernel +// CHECK-SAME: rocdl.reqd_work_group_size = array +// +// CHECK-LABEL: llvm.func @test_not_exported() { + +// ----- + +// Check that we annotate kernel arguments on gfx940-series. + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", + {iree.gpu.target = #iree_gpu.target>, + ukernels = "none"}> +#pipeline_layout = #hal.pipeline.layout], + flags = Indirect> +builtin.module { + hal.executable public @test_kern_arg { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) { + hal.executable.export public @test_kern_arg ordinal(0) layout(#pipeline_layout) + attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]} { + ^bb0(%arg0: !hal.device): + %c128 = arith.constant 128 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + hal.return %c128, %c2, %c1 : index, index, index + } + builtin.module { + llvm.func @test_kern_arg(%arg0: i32) { + llvm.return + } + } + } + } +} + +// CHECK-LABEL: llvm.func @test_kern_arg +// CHECK-SAME: (%{{.+}}: i32 {llvm.inreg}) + +// ----- + +// Check that we *do not* annotate kernel arguments on gfx90a (not supported by the firmware). 
+ +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", + {iree.gpu.target = #iree_gpu.target>, + ukernels = "none"}> +#pipeline_layout = #hal.pipeline.layout], + flags = Indirect> +builtin.module { + hal.executable public @test_no_kern_arg { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) { + hal.executable.export public @test_no_kern_arg ordinal(0) layout(#pipeline_layout) + attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]} { + ^bb0(%arg0: !hal.device): + %c128 = arith.constant 128 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + hal.return %c128, %c2, %c1 : index, index, index + } + builtin.module { + llvm.func @test_no_kern_arg(%arg0: i32) { + llvm.return + } + } + } + } +} + +// CHECK-LABEL: llvm.func @test_no_kern_arg +// CHECK-SAME: (%{{.+}}: i32)