Skip to content

Commit

Permalink
[cuda] Move to hal/drivers and wire up BUILD files
Browse files Browse the repository at this point in the history
This commit moves the CUDA HAL driver rewrite to the
`hal/drivers` directory given it's functional and ready for
normal usage. By this we can start run tests with CI
to make sure it does not regress. Further improvements
can happen directly in this directory.

This provides an easy route for trying out the rewrite before
eventually replace the existing HAL driver.

Along the way wired up BUILD configurations.
  • Loading branch information
antiagainst committed Nov 7, 2023
1 parent e14dff4 commit a5c70a0
Show file tree
Hide file tree
Showing 66 changed files with 894 additions and 550 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@
# Runtime
/runtime/src/iree/ @benvanik
/runtime/src/iree/hal/cts/ @ScottTodd
/runtime/src/iree/hal/drivers/cuda2/ @antiagainst
/runtime/src/iree/hal/drivers/metal/ @antiagainst
/runtime/src/iree/hal/drivers/vulkan/ @antiagainst @ScottTodd
16 changes: 6 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,10 @@ option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL driv
# not cross compiling. Note: a CUDA-compatible GPU with drivers is still
# required to actually run CUDA workloads.
set(IREE_HAL_DRIVER_CUDA_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
set(IREE_HAL_DRIVER_CUDA2_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
if(NOT IREE_CUDA_AVAILABLE OR CMAKE_CROSSCOMPILING)
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
endif()

# Vulkan support is enabled by default if the platform might support Vulkan.
Expand All @@ -251,6 +253,7 @@ if(NOT APPLE OR NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
endif()

option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA2 "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
Expand Down Expand Up @@ -306,6 +309,9 @@ message(STATUS "IREE HAL drivers:")
if(IREE_HAL_DRIVER_CUDA)
message(STATUS " - cuda")
endif()
if(IREE_HAL_DRIVER_CUDA2)
message(STATUS " - cuda2")
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
message(STATUS " - local-sync")
endif()
Expand Down Expand Up @@ -341,16 +347,6 @@ if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY)
message(STATUS " - system-library")
endif()

#-------------------------------------------------------------------------------
# Experimental next-generation CUDA HAL driver
# Enable with: -DIREE_EXTERNAL_HAL_DRIVERS=cuda2
#-------------------------------------------------------------------------------

set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/experimental/cuda2")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experimental/cuda2")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_TARGET "iree::experimental::cuda2::registration")
set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_REGISTER "iree_hal_cuda2_driver_module_register")

#-------------------------------------------------------------------------------
# Experimental ROCM HAL driver
# Enable with: -DIREE_EXTERNAL_HAL_DRIVERS=rocm
Expand Down
21 changes: 0 additions & 21 deletions experimental/cuda2/registration/CMakeLists.txt

This file was deleted.

7 changes: 0 additions & 7 deletions experimental/cuda2/tests/CMakeLists.txt

This file was deleted.

167 changes: 0 additions & 167 deletions experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt

This file was deleted.

129 changes: 0 additions & 129 deletions experimental/cuda2/tests/tosa_ops/CMakeLists.txt

This file was deleted.

5 changes: 4 additions & 1 deletion runtime/src/iree/hal/drivers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ iree_runtime_cc_library(
"//runtime/src/iree/base",
"//runtime/src/iree/hal",
] + select({
":cuda_enabled": ["//runtime/src/iree/hal/drivers/cuda/registration"],
":cuda_enabled": [
"//runtime/src/iree/hal/drivers/cuda/registration",
"//runtime/src/iree/hal/drivers/cuda2/registration",
],
"//conditions:default": [],
}) +
select({
Expand Down
Loading

0 comments on commit a5c70a0

Please sign in to comment.