Skip to content

Commit

Permalink
[hip][cuda] Merge the tracing implementations. (iree-org#18299)
Browse files Browse the repository at this point in the history
These were entirely copy-pasted off one another and it's not likely they
will have to diverge in the future.

---------

Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
  • Loading branch information
AWoloszyn authored Aug 22, 2024
1 parent 8e42839 commit 86ecf39
Show file tree
Hide file tree
Showing 24 changed files with 930 additions and 1,270 deletions.
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ iree_runtime_cc_library(
"stream_command_buffer.h",
"timepoint_pool.c",
"timepoint_pool.h",
"tracing.c",
"tracing.h",
],
hdrs = [
"api.h",
Expand All @@ -69,6 +67,7 @@ iree_runtime_cc_library(
"//runtime/src/iree/hal/utils:memory_file",
"//runtime/src/iree/hal/utils:resource_set",
"//runtime/src/iree/hal/utils:semaphore_base",
"//runtime/src/iree/hal/utils:stream_tracing",
"//runtime/src/iree/schemas:cuda_executable_def_c_fbs",
],
)
Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ iree_cc_library(
"stream_command_buffer.h"
"timepoint_pool.c"
"timepoint_pool.h"
"tracing.c"
"tracing.h"
DEPS
::dynamic_symbols
iree::base
Expand All @@ -66,6 +64,7 @@ iree_cc_library(
iree::hal::utils::memory_file
iree::hal::utils::resource_set
iree::hal::utils::semaphore_base
iree::hal::utils::stream_tracing
iree::schemas::cuda_executable_def_c_fbs
PUBLIC
)
Expand Down
165 changes: 152 additions & 13 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/stream_command_buffer.h"
#include "iree/hal/drivers/cuda/timepoint_pool.h"
#include "iree/hal/drivers/cuda/tracing.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/deferred_work_queue.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"
#include "iree/hal/utils/stream_tracing.h"

//===----------------------------------------------------------------------===//
// iree_hal_cuda_device_t
Expand Down Expand Up @@ -62,7 +62,7 @@ typedef struct iree_hal_cuda_device_t {
// The CUstream used to issue device kernels and allocations.
CUstream dispatch_cu_stream;

iree_hal_cuda_tracing_context_t* tracing_context;
iree_hal_stream_tracing_context_t* tracing_context;

iree_allocator_t host_allocator;

Expand Down Expand Up @@ -259,6 +259,108 @@ iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer(
return status;
}

typedef struct iree_hal_cuda_tracing_device_interface_t {
iree_hal_stream_tracing_device_interface_t base;
CUdevice cu_device;
CUcontext cu_context;
CUstream dispatch_cu_stream;
iree_allocator_t host_allocator;
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols;
} iree_hal_cuda_tracing_device_interface_t;
static const iree_hal_stream_tracing_device_interface_vtable_t
iree_hal_cuda_tracing_device_interface_vtable_t;

void iree_hal_cuda_tracing_device_interface_destroy(
iree_hal_stream_tracing_device_interface_t* base_device_interface) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

iree_allocator_free(device_interface->host_allocator, device_interface);
}

iree_status_t iree_hal_cuda_tracing_device_interface_synchronize_native_event(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_event_t base_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventSynchronize((CUevent)base_event));
}

iree_status_t iree_hal_cuda_tracing_device_interface_create_native_event(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_event_t* base_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventCreate((CUevent*)base_event, CU_EVENT_DEFAULT));
}

iree_status_t iree_hal_cuda_tracing_device_interface_query_native_event(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_event_t base_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventQuery((CUevent)base_event));
}

void iree_hal_cuda_tracing_device_interface_event_elapsed_time(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
float* relative_millis, iree_hal_stream_tracing_native_event_t start_event,
iree_hal_stream_tracing_native_event_t end_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

IREE_CUDA_IGNORE_ERROR(
device_interface->cuda_symbols,
cuEventElapsedTime(relative_millis, (CUevent)start_event,
(CUevent)end_event));
}

void iree_hal_cuda_tracing_device_interface_destroy_native_event(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_event_t base_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

IREE_CUDA_IGNORE_ERROR(device_interface->cuda_symbols,
cuEventDestroy((CUevent)base_event));
}

iree_status_t iree_hal_cuda_tracing_device_interface_record_native_event(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_event_t base_event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventRecord((CUevent)base_event, device_interface->dispatch_cu_stream));
}

iree_status_t
iree_hal_cuda_tracing_device_interface_add_graph_event_record_node(
iree_hal_stream_tracing_device_interface_t* base_device_interface,
iree_hal_stream_tracing_native_graph_node_t* out_node,
iree_hal_stream_tracing_native_graph_t graph,
iree_hal_stream_tracing_native_graph_node_t* dependency_nodes,
size_t dependency_nodes_count,
iree_hal_stream_tracing_native_event_t event) {
iree_hal_cuda_tracing_device_interface_t* device_interface =
(iree_hal_cuda_tracing_device_interface_t*)base_device_interface;

return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuGraphAddEventRecordNode((CUgraphNode*)out_node, (CUgraph)graph,
(CUgraphNode*)dependency_nodes,
dependency_nodes_count, (CUevent)event));
}

static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
iree_hal_device_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_device_vtable);
Expand Down Expand Up @@ -346,18 +448,36 @@ static iree_status_t iree_hal_cuda_device_create_internal(

// Enable tracing for the (currently only) stream - no-op if disabled.
if (iree_status_is_ok(status) && device->params.stream_tracing) {
if (device->params.stream_tracing >= IREE_HAL_CUDA_TRACING_VERBOSITY_MAX ||
device->params.stream_tracing < IREE_HAL_CUDA_TRACING_VERBOSITY_OFF) {
if (device->params.stream_tracing >= IREE_HAL_TRACING_VERBOSITY_MAX ||
device->params.stream_tracing < IREE_HAL_TRACING_VERBOSITY_OFF) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"invalid stream_tracing argument: expected to be between %d and %d",
IREE_HAL_CUDA_TRACING_VERBOSITY_OFF,
IREE_HAL_CUDA_TRACING_VERBOSITY_MAX);
IREE_HAL_TRACING_VERBOSITY_OFF, IREE_HAL_TRACING_VERBOSITY_MAX);
}

iree_hal_cuda_tracing_device_interface_t* tracing_device_interface = NULL;
status = iree_allocator_malloc(
host_allocator, sizeof(iree_hal_cuda_tracing_device_interface_t),
(void**)&tracing_device_interface);

if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
iree_hal_device_release((iree_hal_device_t*)device);
return status;
}
status = iree_hal_cuda_tracing_context_allocate(
device->cuda_symbols, device->identifier, dispatch_stream,
device->params.stream_tracing, &device->block_pool, host_allocator,
&device->tracing_context);

tracing_device_interface->base.vtable =
&iree_hal_cuda_tracing_device_interface_vtable_t;
tracing_device_interface->cu_context = context;
tracing_device_interface->cu_device = cu_device;
tracing_device_interface->dispatch_cu_stream = dispatch_stream;
tracing_device_interface->host_allocator = host_allocator;
tracing_device_interface->cuda_symbols = cuda_symbols;

status = iree_hal_stream_tracing_context_allocate(
(iree_hal_stream_tracing_device_interface_t*)tracing_device_interface,
device->identifier, device->params.stream_tracing, &device->block_pool,
host_allocator, &device->tracing_context);
}

// Memory pool support is conditional.
Expand Down Expand Up @@ -505,7 +625,7 @@ static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
// Destroy memory pools that hold on to reserved memory.
iree_hal_cuda_memory_pools_deinitialize(&device->memory_pools);

iree_hal_cuda_tracing_context_free(device->tracing_context);
iree_hal_stream_tracing_context_free(device->tracing_context);

// Destroy various pools for synchronization.
if (device->timepoint_pool) {
Expand Down Expand Up @@ -947,8 +1067,8 @@ static iree_status_t iree_hal_cuda_device_queue_write(
}

static void iree_hal_cuda_device_collect_tracing_context(void* user_data) {
iree_hal_cuda_tracing_context_collect(
(iree_hal_cuda_tracing_context_t*)user_data);
iree_hal_stream_tracing_context_collect(
(iree_hal_stream_tracing_context_t*)user_data);
}

static iree_status_t iree_hal_cuda_device_queue_execute(
Expand Down Expand Up @@ -1074,3 +1194,22 @@ static const iree_hal_deferred_work_queue_device_interface_vtable_t
.submit_command_buffer =
iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer,
};

static const iree_hal_stream_tracing_device_interface_vtable_t
iree_hal_cuda_tracing_device_interface_vtable_t = {
.destroy = iree_hal_cuda_tracing_device_interface_destroy,
.synchronize_native_event =
iree_hal_cuda_tracing_device_interface_synchronize_native_event,
.create_native_event =
iree_hal_cuda_tracing_device_interface_create_native_event,
.query_native_event =
iree_hal_cuda_tracing_device_interface_query_native_event,
.event_elapsed_time =
iree_hal_cuda_tracing_device_interface_event_elapsed_time,
.destroy_native_event =
iree_hal_cuda_tracing_device_interface_destroy_native_event,
.record_native_event =
iree_hal_cuda_tracing_device_interface_record_native_event,
.add_graph_event_record_node =
iree_hal_cuda_tracing_device_interface_add_graph_event_record_node,
};
Loading

0 comments on commit 86ecf39

Please sign in to comment.