Skip to content

Commit

Permalink
Enabling linking in the ROCM/CUDA compiler targets. (#18936)
Browse files Browse the repository at this point in the history
This does exactly what the LLVMCPU side does - which is bad for compile
time (serializes LLVM codegen) but much better for runtime. Future
improvements should move LLVM codegen to the linking phase so it can
happen in parallel and then perform the linking using LLVM's linker
(each executable turned into a .o and then combined into a .so, or into
last-level bitcode if we then just want serialization to go from bitcode to
machine code). This is definitely a compile-time regression but we can't
keep pessimizing runtime.
  • Loading branch information
benvanik authored Oct 29, 2024
1 parent a321be2 commit 49ffdac
Show file tree
Hide file tree
Showing 26 changed files with 669 additions and 167 deletions.
4 changes: 4 additions & 0 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ class CUDATargetBackend final : public TargetBackend {
buildLLVMGPUCodegenPassPipeline(passManager, false);
}

// Overrides the TargetBackend linking hook for the CUDA backend by forwarding
// |passManager| to buildLLVMGPULinkingPassPipeline together with the "cuda"
// string.
// NOTE(review): "cuda" is presumably the backend/target name the shared
// LLVMGPU linking pipeline uses to pick which executables to link; confirm
// against the helper's declaration.
void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMGPULinkingPassPipeline(passManager, "cuda");
}

LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
Expand Down
44 changes: 36 additions & 8 deletions compiler/plugins/target/CUDA/test/smoketest.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX

#map = affine_map<(d0) -> (d0)>

module attributes {
hal.device.targets = [
#hal.device.target<"cuda", [
Expand All @@ -11,13 +9,13 @@ module attributes {
]
} {

stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
stream.executable public @add_dispatch_executable {
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
Expand All @@ -26,7 +24,7 @@ stream.executable public @add_dispatch_0 {
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
Expand All @@ -36,12 +34,42 @@ stream.executable public @add_dispatch_0 {
}
}

// Second executable in this module. It exports @mul_dispatch, whose body loads
// the two readonly tensor<16xf32> bindings, multiplies them elementwise with
// arith.mulf inside a linalg.generic, and stores the result to the writeonly
// binding.
// NOTE(review): presumably added next to @add_dispatch_executable so the
// linking pipeline has more than one executable to merge; confirm.
stream.executable public @mul_dispatch_executable {
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
%0 = tensor.empty() : tensor<16xf32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.mulf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
return
}
}
}

}

// PTX: .entry add_dispatch_0
// PTX: .entry add_dispatch
// PTX: .maxntid 64, 1, 1
// PTX: add.rn.f32

// CHECK: hal.executable.binary public @cuda_nvptx_fb attributes {
// PTX: .entry mul_dispatch
// PTX: .maxntid 64, 1, 1
// PTX: mul.rn.f32

// CHECK: hal.executable public @smoketest_linked
// CHECK-NEXT: hal.executable.binary public @cuda_nvptx_fb attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = "cuda-nvptx-fb"
2 changes: 1 addition & 1 deletion compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
}

void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMCPULinkingPassPipeline(passManager);
buildLLVMCPULinkingPassPipeline(passManager, "llvm-cpu");
}

// Gets the LLVM target from |variantOp|.
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ class ROCMTargetBackend final : public TargetBackend {
buildLLVMGPUCodegenPassPipeline(passManager, true);
}

// Overrides the TargetBackend linking hook for the ROCM backend by forwarding
// |passManager| to buildLLVMGPULinkingPassPipeline together with the "rocm"
// string.
// NOTE(review): "rocm" is presumably the backend/target name the shared
// LLVMGPU linking pipeline uses to pick which executables to link; confirm
// against the helper's declaration.
void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMGPULinkingPassPipeline(passManager, "rocm");
}

// Performs optimizations on |module| (including LTO-style whole-program
// ones). Inspired by code section in
// https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
Expand Down
48 changes: 37 additions & 11 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

module attributes {
hal.device.targets = [
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
#hal.device.target<"amdgpu", [
#hal.executable.target<"rocm", "amdgcn-amd-amdhsa">
]> : !hal.device
]
} {

stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
stream.executable public @add_dispatch_executable {
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
Expand All @@ -23,7 +23,7 @@ stream.executable public @add_dispatch_0 {
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
Expand All @@ -33,11 +33,37 @@ stream.executable public @add_dispatch_0 {
}
}

// Second executable in this module. It exports @mul_dispatch, whose body loads
// the two readonly tensor<16xf32> bindings, multiplies them elementwise with
// arith.mulf inside a linalg.generic, and stores the result to the writeonly
// binding.
// NOTE(review): presumably added next to @add_dispatch_executable so the
// linking pipeline has more than one executable to merge; confirm.
stream.executable public @mul_dispatch_executable {
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
%0 = tensor.empty() : tensor<16xf32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.mulf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
return
}
}
}

}

// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes {
// CHECK: hal.executable public @smoketest_linked
// CHECK: hal.executable.binary public @amdgcn_amd_amdhsa attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = "rocm-hsaco-fb"
// CHECK-SAME: format = "amdgcn-amd-amdhsa"

// -----

Expand All @@ -52,13 +78,13 @@ module attributes {
]
} {

stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
stream.executable public @executable {
stream.executable.export @export workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
} loc(#loc)
builtin.module {
func.func @add_dispatch_0() {
func.func @export() {
return
} loc(#loc)
} loc(#loc)
Expand Down
Loading

0 comments on commit 49ffdac

Please sign in to comment.