Skip to content

Commit

Permalink
[cuda] Wire up basic creating devices, allocators, and buffers (#14011)
Browse files Browse the repository at this point in the history
This commit ports over the existing CUDA HAL driver code to enable
creating devices, allocators, and buffers. It's mostly NFC (no functional
change): just symbol renaming and stubbing out various functionality
as unimplemented.

With this commit, we are able to pick up CTS tests.

Progress towards #13245
  • Loading branch information
antiagainst authored Jun 9, 2023
1 parent 967ab3b commit c4e01e9
Show file tree
Hide file tree
Showing 9 changed files with 615 additions and 9 deletions.
5 changes: 5 additions & 0 deletions experimental/cuda2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@ iree_cc_library(
"cuda_allocator.h"
"cuda_buffer.c"
"cuda_buffer.h"
"cuda_device.c"
"cuda_device.h"
"cuda_driver.c"
"memory_pools.c"
"memory_pools.h"
DEPS
::dynamic_symbols
iree::base
iree::base::internal
iree::base::internal::arena
iree::base::core_headers
iree::base::tracing
iree::hal
iree::hal::utils::buffer_transfer
iree::schemas::cuda_executable_def_c_fbs
PUBLIC
)
Expand Down
29 changes: 28 additions & 1 deletion experimental/cuda2/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ extern "C" {
#endif // __cplusplus

//===----------------------------------------------------------------------===//
// iree_hal_cuda_device_t
// iree_hal_cuda2_device_t
//===----------------------------------------------------------------------===//

// Parameters defining a CUmemoryPool.
Expand All @@ -40,6 +40,32 @@ typedef struct iree_hal_cuda2_memory_pooling_params_t {
iree_hal_cuda2_memory_pool_params_t other;
} iree_hal_cuda2_memory_pooling_params_t;

// Parameters configuring an iree_hal_cuda2_device_t.
// Must be initialized with iree_hal_cuda2_device_params_initialize prior to
// use.
typedef struct iree_hal_cuda2_device_params_t {
// Number of queues exposed on the device.
// Each queue acts as a separate synchronization scope where all work executes
// concurrently unless prohibited by semaphores.
iree_host_size_t queue_count;

// Total size of each block in the device shared block pool.
// Larger sizes lower overhead and keep transient allocations from hitting the
// heap, at the cost of increased memory consumption.
iree_host_size_t arena_block_size;

// Whether to use async allocations when the device reports them as
// available. Defaults to true when the device supports it.
// NOTE(review): the previous wording was "even if reported as available";
// presumably setting this to false disables async allocations even on
// devices that support them - confirm against the device implementation.
bool async_allocations;

// Parameters for each CUmemoryPool used for queue-ordered allocations.
iree_hal_cuda2_memory_pooling_params_t memory_pools;
} iree_hal_cuda2_device_params_t;

// Initializes |out_params| to default values.
// An iree_hal_cuda2_device_params_t must be passed through this function
// before it is used.
IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize(
iree_hal_cuda2_device_params_t* out_params);

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_driver_t
//===----------------------------------------------------------------------===//
Expand All @@ -62,6 +88,7 @@ IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize(
IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create(
iree_string_view_t identifier,
const iree_hal_cuda2_driver_options_t* options,
const iree_hal_cuda2_device_params_t* default_params,
iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);

#ifdef __cplusplus
Expand Down
24 changes: 24 additions & 0 deletions experimental/cuda2/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Instantiates the HAL CTS test suite for the experimental cuda2 driver.
# The driver is registered through iree_hal_cuda2_driver_module_register and
# executables are compiled with the "cuda" compiler target backend.
# Only the tests listed under INCLUDED_TESTS (allocator, buffer_mapping,
# driver) are included in this suite.
iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
DRIVER_REGISTRATION_HDR
"experimental/cuda2/registration/driver_module.h"
DRIVER_REGISTRATION_FN
"iree_hal_cuda2_driver_module_register"
COMPILER_TARGET_BACKEND
"cuda"
EXECUTABLE_FORMAT
"\"PTXE\""
DEPS
iree::experimental::cuda2::registration
INCLUDED_TESTS
"allocator"
"buffer_mapping"
"driver"
)
Loading

0 comments on commit c4e01e9

Please sign in to comment.