Skip to content

Commit

Permalink
[cuda] Switch cuda2 on and cuda1 off by default (#16107)
Browse files Browse the repository at this point in the history
This commit switches the cuda2 HAL driver on and
the cuda HAL driver (which is renamed to cuda1) off
by default in CMake. In order to do this, we also
switched cuda2 to use stream-based command buffer
by default to follow cuda1 for simple transition. 

Fixes #13245

benchmark-extra: cuda-large
  • Loading branch information
antiagainst authored Jan 23, 2024
1 parent e31b5a2 commit 3b3cef9
Show file tree
Hide file tree
Showing 15 changed files with 46 additions and 45 deletions.
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL driv
# not cross compiling. Note: a CUDA-compatible GPU with drivers is still
# required to actually run CUDA workloads.
set(IREE_HAL_DRIVER_CUDA_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA1_DEFAULT OFF)
if(NOT IREE_CUDA_AVAILABLE OR CMAKE_CROSSCOMPILING)
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA1_DEFAULT OFF)
endif()

# Vulkan support is enabled by default if the platform might support Vulkan.
Expand All @@ -262,7 +262,7 @@ if(NOT APPLE OR NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
endif()

option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA2 "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_CUDA1 "Enables the 'cuda1' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA1_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
Expand Down Expand Up @@ -318,8 +318,8 @@ message(STATUS "IREE HAL drivers:")
if(IREE_HAL_DRIVER_CUDA)
message(STATUS " - cuda")
endif()
if(IREE_HAL_DRIVER_CUDA2)
message(STATUS " - cuda2")
if(IREE_HAL_DRIVER_CUDA1)
message(STATUS " - cuda1")
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
message(STATUS " - local-sync")
Expand Down
8 changes: 4 additions & 4 deletions runtime/src/iree/hal/drivers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ endif()

set(_INIT_INTERNAL_DEPS)
if(IREE_HAL_DRIVER_CUDA)
add_subdirectory(cuda)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration)
endif()
if(IREE_HAL_DRIVER_CUDA2)
add_subdirectory(cuda2)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda2::registration)
endif()
if(IREE_HAL_DRIVER_CUDA1)
add_subdirectory(cuda)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration)
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
add_subdirectory(local_sync)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::local_sync::registration)
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
DRIVER_REGISTRATION_HDR
"runtime/src/iree/hal/drivers/cuda/registration/driver_module.h"
DRIVER_REGISTRATION_FN
Expand All @@ -28,7 +28,7 @@ iree_hal_cts_test_suite(
# Variant test suite using graph command buffers (--cuda_use_streams=0)
iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/registration/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ iree_runtime_cc_library(
"driver_module.h",
],
defines = [
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1",
"IREE_HAVE_HAL_CUDA1_DRIVER_MODULE=1",
],
tags = ["driver=cuda"],
tags = ["driver=cuda1"],
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:flags",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ iree_cc_library(
iree::hal
iree::hal::drivers::cuda
DEFINES
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1"
"IREE_HAVE_HAL_CUDA1_DRIVER_MODULE=1"
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ static iree_status_t iree_hal_cuda_driver_factory_enumerate(
const iree_hal_driver_info_t** out_driver_infos) {
// NOTE: we could query supported cuda versions or featuresets here.
static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = iree_string_view_literal("cuda"),
.full_name = iree_string_view_literal("CUDA (dynamic)"),
.driver_name = iree_string_view_literal("cuda1"),
.full_name = iree_string_view_literal("deprecated CUDA (dynamic)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -79,7 +79,7 @@ static iree_status_t iree_hal_cuda_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda1"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda2/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand All @@ -31,7 +31,7 @@ iree_hal_cts_test_suite(

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
stream
DRIVER_REGISTRATION_HDR
Expand Down
3 changes: 2 additions & 1 deletion runtime/src/iree/hal/drivers/cuda2/registration/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ iree_runtime_cc_library(
"driver_module.h",
],
defines = [
"IREE_HAVE_HAL_CUDA2_DRIVER_MODULE=1",
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1",
],
tags = ["driver=cuda"],
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:flags",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ iree_cc_library(
iree::hal
iree::hal::drivers::cuda2
DEFINES
"IREE_HAVE_HAL_CUDA2_DRIVER_MODULE=1"
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1"
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "iree/hal/drivers/cuda2/api.h"

IREE_FLAG(
bool, cuda2_use_streams, false,
bool, cuda2_use_streams, true,
"Use CUDA streams (instead of graphs) for executing command buffers.");

IREE_FLAG(bool, cuda2_allow_inline_execution, false,
Expand Down Expand Up @@ -70,8 +70,8 @@ static iree_status_t iree_hal_cuda2_driver_factory_enumerate(
IREE_TRACE_ZONE_BEGIN(z0);

static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = IREE_SVL("cuda2"),
.full_name = IREE_SVL("next-gen NVIDIA CUDA HAL driver (via dylib)"),
.driver_name = IREE_SVL("cuda"),
.full_name = IREE_SVL("NVIDIA CUDA HAL driver (via dylib)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -85,7 +85,7 @@ static iree_status_t iree_hal_cuda2_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);

if (!iree_string_view_equal(driver_name, IREE_SV("cuda2"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
16 changes: 8 additions & 8 deletions runtime/src/iree/hal/drivers/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include "iree/hal/drivers/init.h"

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA1_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
#include "iree/hal/drivers/local_sync/registration/driver_module.h"
Expand Down Expand Up @@ -47,13 +47,13 @@ iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_driver_module_register(registry));
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
z0, iree_hal_cuda_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA1_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ iree_check_single_backend_test_suite(
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
],
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=false"],
tags = [
Expand Down Expand Up @@ -499,7 +499,7 @@ iree_check_single_backend_test_suite(
include = ["*.mlir"],
exclude = [],
),
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=true"],
tags = [
Expand Down Expand Up @@ -589,7 +589,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -609,7 +609,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
INPUT_TYPE
Expand Down Expand Up @@ -453,7 +453,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
INPUT_TYPE
"stablehlo"
RUNNER_ARGS
Expand Down Expand Up @@ -534,7 +534,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -618,7 +618,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -323,7 +323,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -332,7 +332,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down

0 comments on commit 3b3cef9

Please sign in to comment.