Enabling linking in the ROCM compiler target.
This does exactly what the LLVMCPU side does - which is bad for
compile time (it serializes LLVM codegen) but much better for runtime.
Future improvements should move LLVM codegen to the linking phase so it
can happen in parallel and then perform the linking using LLVM's linker
(each executable turned into a .o and then combined into a .so, or into
last-level bitcode if we just want the final serialization step to be
bitcode-to-machine-code compilation).
benvanik committed Oct 29, 2024
1 parent 39ca877 commit 77849f2
Showing 16 changed files with 467 additions and 135 deletions.
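
For context before the diffs: a minimal sketch of what the shared GPU linking pipeline referenced below (buildLLVMGPULinkingPassPipeline) plausibly looks like, mirroring the LLVMCPU equivalent. The pass-constructor names are assumptions inferred from the LLVMGPULinkExecutables.cpp and LLVMGPUAssignConstantOrdinals.cpp sources added in this commit, not verbatim from the tree:

```cpp
// Hedged sketch, not verbatim from this commit: the rough shape of a GPU
// linking pipeline, mirroring the LLVMCPU one. Pass constructor names are
// assumptions.
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Pass/PassManager.h"

namespace mlir::iree_compiler {

void buildLLVMGPULinkingPassPipeline(OpPassManager &passManager) {
  // Merge every hal.executable produced for this backend into a single
  // executable so that only one device binary is serialized at the end.
  passManager.addPass(createLLVMGPULinkExecutablesPass());
  // After merging, renumber executable constants so their ordinals stay
  // unique within the combined executable.
  passManager.nest<IREE::HAL::ExecutableOp>()
      .addNestedPass<IREE::HAL::ExecutableVariantOp>(
          createLLVMGPUAssignConstantOrdinalsPass());
}

} // namespace mlir::iree_compiler
```

With a pipeline like this in place, each target backend only needs the one-line buildLinkingPassPipeline override shown in ROCMTarget.cpp below.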
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -266,6 +266,10 @@ class ROCMTargetBackend final : public TargetBackend {
buildLLVMGPUCodegenPassPipeline(passManager, true);
}

void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMGPULinkingPassPipeline(passManager);
}

// Performs optimizations on |module| (including LTO-style whole-program
// ones). Inspired by code section in
// https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
42 changes: 34 additions & 8 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,19 +2,44 @@

module attributes {
hal.device.targets = [
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
#hal.device.target<"amdgpu", [
#hal.executable.target<"rocm", "amdgcn-amd-amdhsa">
]> : !hal.device
]
} {

stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
stream.executable public @add_dispatch_executable {
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
%0 = tensor.empty() : tensor<16xf32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
return
}
}
}

stream.executable public @mul_dispatch_executable {
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
@@ -23,7 +48,7 @@ stream.executable public @add_dispatch_0 {
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
@@ -35,9 +60,10 @@ stream.executable public @add_dispatch_0 {

}

// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes {
// CHECK: hal.executable public @smoketest_linked_llvm_gpu
// CHECK: hal.executable.binary public @amdgcn_amd_amdhsa attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = "rocm-hsaco-fb"
// CHECK-SAME: format = "amdgcn-amd-amdhsa"

// -----

240 changes: 120 additions & 120 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
@@ -1,120 +1,120 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"

bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
unwrap(attr));
}

MlirAttribute
ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory,
bool *noReduceSharedMemoryBankConflicts,
MlirAttribute *reorderWorkgroupsStrategy) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
mlir::Builder b(ctx);
auto prefetchSharedMemoryAttr = mlir::BoolAttr();
if (prefetchSharedMemory) {
prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory);
}
auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr();
if (noReduceSharedMemoryBankConflicts) {
noReduceSharedMemoryBankConflictsAttr =
b.getBoolAttr(*noReduceSharedMemoryBankConflicts);
}
auto strategyAttr =
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr();
if (reorderWorkgroupsStrategy) {
strategyAttr = llvm::dyn_cast<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
unwrap(*reorderWorkgroupsStrategy));
}
return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get(
ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr,
strategyAttr));
}

MlirAttribute
ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) {
auto gpuAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
unwrap(attr));
return wrap(gpuAttr.getPrefetchSharedMemory());
}

MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts(
MlirAttribute attr) {
auto gpuAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
unwrap(attr));
return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts());
}

MlirAttribute
ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) {
auto gpuAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
unwrap(attr));
return wrap(gpuAttr.getReorderWorkgroupsStrategy());
}

MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
return wrap(
mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
}

static_assert(
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::None) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::Swizzle) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::Transpose) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
mlir::iree_compiler::IREE::GPU::
getMaxEnumValForReorderWorkgroupsStrategy(),
"ireeGPUReorderWorkgroupsStrategyEnum and "
"mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
"have diverged");

bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
return llvm::isa<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
unwrap(attr));
}

MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::
getTypeID());
}

MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get(
ctx, static_cast<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>(
value)));
}

ireeGPUReorderWorkgroupsStrategyEnum
ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
"attr is not a GPUReorderWorkgroupsStrategyAttr");
return static_cast<ireeGPUReorderWorkgroupsStrategyEnum>(
llvm::cast<mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
unwrap(attr))
.getValue());
}
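
As a side note, a hypothetical usage sketch of the C API above (not part of this commit): it builds a reorder-workgroups strategy attribute and wraps it in a GPUPipelineOptionsAttr, assuming an MlirContext with the IREE GPU dialect already loaded.

```cpp
// Hypothetical usage sketch, not part of this commit: exercising the C API
// defined above. Assumes `ctx` already has the IREE GPU dialect loaded.
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/IR.h"

MlirAttribute makePipelineOptions(MlirContext ctx) {
  // Enum-valued strategy attribute built through the C API.
  MlirAttribute strategy = ireeGPUReorderWorkgroupsStrategyAttrGet(
      ctx, ireeGPUReorderWorkgroupsStrategyEnumSwizzle);
  // Passing null pointers leaves the corresponding options unset.
  bool prefetchSharedMemory = true;
  return ireeGPUPipelineOptionsAttrGet(
      ctx, &prefetchSharedMemory,
      /*noReduceSharedMemoryBankConflicts=*/nullptr, &strategy);
}
```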
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -91,11 +91,13 @@ iree_compiler_cc_library(
"ConvertToROCDL.cpp",
"ExtractAddressComputationGPUPass.cpp",
"KernelConfig.cpp",
"LLVMGPUAssignConstantOrdinals.cpp",
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPUConvolutionToIGEMM.cpp",
"LLVMGPULinkExecutables.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -76,11 +76,13 @@ iree_cc_library(
"ConvertToROCDL.cpp"
"ExtractAddressComputationGPUPass.cpp"
"KernelConfig.cpp"
"LLVMGPUAssignConstantOrdinals.cpp"
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPUConvolutionToIGEMM.cpp"
"LLVMGPULinkExecutables.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
