Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into windows-runner-staging
Browse files Browse the repository at this point in the history
  • Loading branch information
Elias Joseph committed Oct 31, 2024
2 parents 3bfbf0c + 2ec9017 commit 059538f
Show file tree
Hide file tree
Showing 292 changed files with 8,051 additions and 3,713 deletions.
1 change: 0 additions & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
# Experimental
# It's experimental, but we still don't want any old directory added here.
/experimental/ @benvanik @stellaraccident
/experimental/rocm/ @benvanik
/experimental/web/ @ScottTodd
/experimental/webgpu/ @benvanik @ScottTodd

Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ jobs:
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand All @@ -241,9 +241,9 @@ jobs:
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand Down
4 changes: 0 additions & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
[submodule "third_party/vulkan_headers"]
path = third_party/vulkan_headers
url = https://github.com/KhronosGroup/Vulkan-Headers.git
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
branch = stable
[submodule "third_party/benchmark"]
path = third_party/benchmark
url = https://github.com/google/benchmark.git
Expand Down
37 changes: 24 additions & 13 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,30 @@ endif()
# MLIR/LLVM Dependency
#-------------------------------------------------------------------------------

# Both the IREE and MLIR Python bindings require pybind11. We initialize it here
# at the top level so that everything uses ours consistently.
if(IREE_BUILD_PYTHON_BINDINGS AND IREE_BUILD_COMPILER)
set(pybind11_VERSION 2.13.6)
include(FetchContent)
FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11
GIT_TAG v${pybind11_VERSION}
)
set(PYBIND11_FINDPYTHON ON)
FetchContent_MakeAvailable(pybind11)
# pybind11 source fetches do not include find_package integration, which is
# a shame since sub-projects can require that to work. If we were using
# CMake 3.24, we could just add OVERRIDE_FIND_PACKAGE to the
# FetchContent_Declare call above and it would take care of doing the
# following to let subsequent sub-project find_package calls to resolve
# successfully.
set(pybind11_DIR "${pybind11_BINARY_DIR}")
file(WRITE "${pybind11_BINARY_DIR}/pybind11Config.cmake" "")
file(WRITE "${pybind11_BINARY_DIR}/pybind11ConfigVersion.cmake"
"set(PACKAGE_VERSION ${pybind11_VERSION})\nset(PACKAGE_VERSION_COMPATIBLE TRUE)")
endif()

if(NOT IREE_BUILD_COMPILER)
message(STATUS "Not adding LLVM/MLIR because the configuration does not require it")
else()
Expand Down Expand Up @@ -921,19 +945,6 @@ if(IREE_BUILD_TESTS)
include(iree_configure_testing)
endif()

if(IREE_BUILD_PYTHON_BINDINGS)
# The compiler uses pybind11
if(IREE_BUILD_COMPILER)
if(NOT TARGET pybind11::module)
message(STATUS "Using bundled pybind11")
set(PYBIND11_FINDPYTHON ON)
add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL)
else()
message(STATUS "Not including bundled pybind11 (already configured)")
endif()
endif()
endif()

if(IREE_TARGET_BACKEND_METAL_SPIRV)
# SPIRV-Cross is needed to cross compile SPIR-V into MSL source code.
iree_set_spirv_cross_cmake_options()
Expand Down
7 changes: 7 additions & 0 deletions build_tools/bazel/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ def configure_iree_submodule_deps(iree_repo_alias = "@", iree_path = "./"):
path = paths.join(iree_path, "third_party/nccl"),
)

maybe(
native.new_local_repository,
name = "hsa_runtime_headers",
build_file = iree_repo_alias + "//:build_tools/third_party/hsa-runtime-headers/BUILD.overlay",
path = paths.join(iree_path, "third_party/hsa-runtime-headers"),
)

maybe(
native.new_local_repository,
name = "webgpu_headers",
Expand Down
1 change: 1 addition & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(self, repo_map: Dict[str, str]):
"@com_google_googletest//:gtest": ["gmock", "gtest"],
"@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"],
"@cpuinfo": ["${IREE_CPUINFO_TARGET}"],
"@hsa_runtime_headers": ["hsa_runtime::headers"],
"@webgpu_headers": [],
}
)
Expand Down
6 changes: 6 additions & 0 deletions build_tools/cmake/build_and_test_byo_llvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ echo "Setting up venv at $VENV_DIR"
python3 -m venv "$VENV_DIR"
source "$VENV_DIR/bin/activate"
python -m pip install -r runtime/bindings/python/iree/runtime/build_requirements.txt
python -m pip install -r third_party/llvm-project/mlir/python/requirements.txt
# Note: IREE's Python bindings for Python 3.13 are build with support for
# free-threading for which support was added to pybind with version 2.13.0.
# Therefore, we upgrade to a more recent version and avoid mixing of different
# pybind versions.
python -m pip install pybind11==2.13.6

# Note: by using the `build_llvm` action here, we are exercising byo_llvm.sh's
# ability to build LLVM... from our own third_party/llvm-project. That's not
Expand Down
3 changes: 3 additions & 0 deletions build_tools/llvm/byo_llvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ do_build_mlir() {

cmake_options="-DLLVM_DIR='${main_install_dir}/lib/cmake/llvm'"
cmake_options="${cmake_options} -DPython3_EXECUTABLE='$(which $python3_command)'"
# Note: Building the MLIR Python bindings requires the installation of
# dependencies as specified in `mlir/python/requirements.txt`, which among
# others include pybind11.
cmake_options="${cmake_options} -DMLIR_ENABLE_BINDINGS_PYTHON=ON"
cmake_options="${cmake_options} -DCMAKE_INSTALL_PREFIX=${mlir_install_dir}"
cmake_options="${cmake_options} -C $TD/mlir_config.cmake"
Expand Down
16 changes: 16 additions & 0 deletions build_tools/third_party/hsa-runtime-headers/BUILD.overlay
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

package(default_visibility = ["//visibility:public"])

cc_library(
name = "hsa_runtime_headers",
hdrs = glob([
"include/hsa/*.h",
]),
include_prefix = "third_party/hsa-runtime-headers/",
includes = ["include"],
)
28 changes: 28 additions & 0 deletions build_tools/third_party/hsa-runtime-headers/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(HSA_RUNTIME_HEADERS_ROOT "${IREE_ROOT_DIR}/third_party/hsa-runtime-headers/")

external_cc_library(
PACKAGE
hsa_runtime
NAME
headers
ROOT
${HSA_RUNTIME_HEADERS_ROOT}
SYSTEM_INCLUDES
${HSA_RUNTIME_HEADERS_ROOT}/include/
PUBLIC
)

iree_install_targets(
TARGETS
hsa_runtime_headers
COMPONENT
IREEBundledLibraries
EXPORT_SET
Runtime
)
15 changes: 13 additions & 2 deletions compiler/plugins/input/Torch/PluginRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,19 @@ struct TorchSession
OpPassManager &passManager, std::string_view typeMnemonic) override {
if (typeMnemonic == "onnx") {
// ONNX input is a pre-processing step to torch.
passManager.addNestedPass<func::FuncOp>(
mlir::torch::onnx_c::createTorchOnnxToTorchPass());
mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions;
// The `aten.flatten.using_ints` and `aten.unflatten.int` are added to the
// list of backend legal ops so that they are not decomposed into the
// `aten.view` op during the run of `DecomposeComplexOps` pass. The issue
// with this is that the `aten.view` op eventually lowers to
// `tensor.reshape` op while there exists a direct torch->linalg lowering
// for both the flatten/unflatten ops which lowers to
// `tensor.collapse_shape/expand_shape` op, and this is a more preferred
// path for the downstream pipeline.
torchOnnxPipelineOptions.backendLegalOps = {"aten.flatten.using_ints",
"aten.unflatten.int"};
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
passManager, torchOnnxPipelineOptions);
}

if (typeMnemonic == "torch" || typeMnemonic == "onnx") {
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ class CUDATargetBackend final : public TargetBackend {
buildLLVMGPUCodegenPassPipeline(passManager, false);
}

void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMGPULinkingPassPipeline(passManager, "cuda");
}

LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
Expand Down
44 changes: 36 additions & 8 deletions compiler/plugins/target/CUDA/test/smoketest.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX

#map = affine_map<(d0) -> (d0)>

module attributes {
hal.device.targets = [
#hal.device.target<"cuda", [
Expand All @@ -11,13 +9,13 @@ module attributes {
]
} {

stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
stream.executable public @add_dispatch_executable {
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
Expand All @@ -26,7 +24,7 @@ stream.executable public @add_dispatch_0 {
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
Expand All @@ -36,12 +34,42 @@ stream.executable public @add_dispatch_0 {
}
}

stream.executable public @mul_dispatch_executable {
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
%0 = tensor.empty() : tensor<16xf32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.mulf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
return
}
}
}

}

// PTX: .entry add_dispatch_0
// PTX: .entry add_dispatch
// PTX: .maxntid 64, 1, 1
// PTX: add.rn.f32

// CHECK: hal.executable.binary public @cuda_nvptx_fb attributes {
// PTX: .entry mul_dispatch
// PTX: .maxntid 64, 1, 1
// PTX: mul.rn.f32

// CHECK: hal.executable public @smoketest_linked
// CHECK-NEXT: hal.executable.binary public @cuda_nvptx_fb attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = "cuda-nvptx-fb"
2 changes: 1 addition & 1 deletion compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
}

void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMCPULinkingPassPipeline(passManager);
buildLLVMCPULinkingPassPipeline(passManager, "llvm-cpu");
}

// Gets the LLVM target from |variantOp|.
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs",
"//runtime/src/iree/schemas:executable_debug_info_c_fbs",
"//runtime/src/iree/schemas:hip_executable_def_c_fbs",
"@llvm-project//llvm:AMDGPUCodeGen",
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iree_cc_library(
iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::amdgpu_executable_def_c_fbs
iree::schemas::executable_debug_info_c_fbs
iree::schemas::hip_executable_def_c_fbs
PUBLIC
Expand Down
Loading

0 comments on commit 059538f

Please sign in to comment.