Skip to content

Commit

Permalink
[cuda] Move to hal/drivers and wire up BUILD files
Browse files Browse the repository at this point in the history
This commit moves the CUDA HAL driver rewrite to the
`hal/drivers` directory, given that it is functional and ready for
normal usage. With this, we can start running tests in CI
to make sure it does not regress. Further improvements
can happen directly in this directory.

This provides an easy route for trying out the rewrite before
it eventually replaces the existing HAL driver.

Along the way, this commit also wires up the BUILD configurations.
  • Loading branch information
antiagainst committed Aug 25, 2023
1 parent eba7eac commit 921cf3b
Show file tree
Hide file tree
Showing 64 changed files with 693 additions and 389 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,6 @@
# Runtime
/runtime/src/iree/ @benvanik
/runtime/src/iree/hal/cts/ @ScottTodd
/runtime/src/iree/hal/drivers/cuda2/ @antiagainst
/runtime/src/iree/hal/drivers/metal/ @antiagainst
/runtime/src/iree/hal/drivers/vulkan/ @antiagainst @ScottTodd
15 changes: 6 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL driv
# not cross compiling. Note: a CUDA-compatible GPU with drivers is still
# required to actually run CUDA workloads.
set(IREE_HAL_DRIVER_CUDA_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
set(IREE_HAL_DRIVER_CUDA2_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
# Override both CUDA driver defaults ('cuda' and 'cuda2') to OFF when CUDA is
# not available or when cross compiling.
if(NOT IREE_CUDA_AVAILABLE OR CMAKE_CROSSCOMPILING)
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
endif()

# Vulkan support is enabled by default if the platform might support Vulkan.
Expand All @@ -243,6 +245,7 @@ if(NOT APPLE OR NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
endif()

# Per-driver enable switches for the runtime HAL drivers. The last argument of
# each option() is its default: the CUDA and Vulkan drivers use their
# driver-specific *_DEFAULT variables, while local-sync and local-task use
# IREE_HAL_DRIVER_DEFAULTS directly.
option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA2 "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
Expand Down Expand Up @@ -298,6 +301,9 @@ message(STATUS "IREE HAL drivers:")
# Emit one STATUS line per enabled HAL driver so the configure output lists
# which drivers are enabled.
if(IREE_HAL_DRIVER_CUDA)
message(STATUS " - cuda")
endif()
if(IREE_HAL_DRIVER_CUDA2)
message(STATUS " - cuda2")
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
message(STATUS " - local-sync")
endif()
Expand Down Expand Up @@ -333,15 +339,6 @@ if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY)
message(STATUS " - system-library")
endif()

#-------------------------------------------------------------------------------
# Experimental next-generation CUDA HAL driver
#-------------------------------------------------------------------------------

set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/experimental/cuda2")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experimental/cuda2")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_TARGET "iree::experimental::cuda2::registration")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_REGISTER "iree_hal_cuda2_driver_module_register")

#-------------------------------------------------------------------------------
# Experimental ROCM HAL driver
#-------------------------------------------------------------------------------
Expand Down
21 changes: 0 additions & 21 deletions experimental/cuda2/registration/CMakeLists.txt

This file was deleted.

7 changes: 0 additions & 7 deletions experimental/cuda2/tests/CMakeLists.txt

This file was deleted.

85 changes: 0 additions & 85 deletions experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt

This file was deleted.

66 changes: 0 additions & 66 deletions experimental/cuda2/tests/tosa_ops/CMakeLists.txt

This file was deleted.

5 changes: 4 additions & 1 deletion runtime/src/iree/hal/drivers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ iree_runtime_cc_library(
"//runtime/src/iree/base",
"//runtime/src/iree/hal",
] + select({
":cuda_enabled": ["//runtime/src/iree/hal/drivers/cuda/registration"],
":cuda_enabled": [
"//runtime/src/iree/hal/drivers/cuda/registration",
"//runtime/src/iree/hal/drivers/cuda2/registration",
],
"//conditions:default": [],
}) +
select({
Expand Down
4 changes: 4 additions & 0 deletions runtime/src/iree/hal/drivers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ if(IREE_HAL_DRIVER_CUDA)
add_subdirectory(cuda)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration)
endif()
# When the 'cuda2' driver is enabled, add its subdirectory to the build and
# append its registration target to the _INIT_INTERNAL_DEPS list.
if(IREE_HAL_DRIVER_CUDA2)
add_subdirectory(cuda2)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda2::registration)
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
add_subdirectory(local_sync)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::local_sync::registration)
Expand Down
113 changes: 113 additions & 0 deletions runtime/src/iree/hal/drivers/cuda2/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")

# Package-wide defaults: every target is publicly visible, the
# "layering_check" feature is enabled, and the license type is "notice".
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

# The cuda2 HAL driver library.
#
# Only "api.h" is exported through `hdrs` (it is also listed in `srcs`); every
# other header is listed in `srcs` only. CUDA/NCCL symbol access comes from the
# sibling ":dynamic_symbols" target rather than from direct header deps here.
iree_runtime_cc_library(
name = "cuda2",
srcs = [
"api.h",
"cuda_allocator.c",
"cuda_allocator.h",
"cuda_buffer.c",
"cuda_buffer.h",
"cuda_device.c",
"cuda_device.h",
"cuda_driver.c",
"event_pool.c",
"event_pool.h",
"event_semaphore.c",
"event_semaphore.h",
"graph_command_buffer.c",
"graph_command_buffer.h",
"memory_pools.c",
"memory_pools.h",
"native_executable.c",
"native_executable.h",
"nccl_channel.c",
"nccl_channel.h",
"nop_executable_cache.c",
"nop_executable_cache.h",
"pending_queue_actions.c",
"pending_queue_actions.h",
"pipeline_layout.c",
"pipeline_layout.h",
"timepoint_pool.c",
"timepoint_pool.h",
"tracing.c",
"tracing.h",
],
hdrs = [
"api.h",
],
deps = [
":dynamic_symbols",
"//runtime/src/iree/base",
"//runtime/src/iree/base:core_headers",
"//runtime/src/iree/base/internal",
"//runtime/src/iree/base/internal:arena",
"//runtime/src/iree/base/internal:event_pool",
"//runtime/src/iree/base/internal:synchronization",
"//runtime/src/iree/base/internal/flatcc:parsing",
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/utils:buffer_transfer",
"//runtime/src/iree/hal/utils:collective_batch",
"//runtime/src/iree/hal/utils:deferred_command_buffer",
"//runtime/src/iree/hal/utils:file_transfer",
"//runtime/src/iree/hal/utils:memory_file",
"//runtime/src/iree/hal/utils:resource_set",
"//runtime/src/iree/hal/utils:semaphore_base",
"//runtime/src/iree/schemas:cuda_executable_def_c_fbs",
],
)

# CUDA and NCCL dynamic symbol and status utility library (per the file names),
# built against the "@iree_cuda//:headers" and "@nccl//:headers" targets and
# "//runtime/src/iree/base/internal:dynamic_library".
#
# The *_dynamic_symbol_table.h files are declared as `textual_hdrs` rather than
# `hdrs`. NOTE(review): presumably they are macro tables meant to be included
# textually and not compiled standalone - confirm against the headers.
iree_runtime_cc_library(
name = "dynamic_symbols",
srcs = [
"cuda_dynamic_symbols.c",
"cuda_headers.h",
"cuda_status_util.c",
"nccl_dynamic_symbols.c",
"nccl_headers.h",
"nccl_status_util.c",
],
hdrs = [
"cuda_dynamic_symbols.h",
"cuda_status_util.h",
"nccl_dynamic_symbols.h",
"nccl_status_util.h",
],
textual_hdrs = [
"cuda_dynamic_symbol_table.h",
"nccl_dynamic_symbol_table.h",
],
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:dynamic_library",
"@iree_cuda//:headers",
"@nccl//:headers",
],
)

# gtest-based test for the ":dynamic_symbols" library.
# NOTE(review): the "driver=cuda2" tag is presumably used to select or filter
# tests by HAL driver - confirm how the test infrastructure consumes it.
iree_runtime_cc_test(
name = "dynamic_symbols_test",
srcs = [
"dynamic_symbols_test.cc",
],
tags = ["driver=cuda2"],
deps = [
":dynamic_symbols",
"//runtime/src/iree/base",
"//runtime/src/iree/testing:gtest",
"//runtime/src/iree/testing:gtest_main",
],
)
Loading

0 comments on commit 921cf3b

Please sign in to comment.