From a2df2cc2493b72c3ecbc8ff45ad67d3ed027d285 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 23 Jan 2024 11:10:46 -0800 Subject: [PATCH] [cuda] Collect tracing events after command buffer completion (#16158) Now that we have proper async execution in the cuda HAL driver, command buffers may not execute immediately after enqueuing, so we should not collect the tracing events there. Instead, we should collect them in a deferred and async manner, once we know the command buffers have completed. --- runtime/src/iree/hal/drivers/cuda2/cuda_device.c | 12 +++++++++--- .../iree/hal/drivers/cuda2/pending_queue_actions.c | 12 ++++++++++++ .../iree/hal/drivers/cuda2/pending_queue_actions.h | 9 +++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c index ac519a2f3104..1d0a74a481b7 100644 --- a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c @@ -742,6 +742,11 @@ static iree_status_t iree_hal_cuda2_device_queue_write( return loop_status; } +static void iree_hal_cuda2_device_collect_tracing_context(void* user_data) { + iree_hal_cuda2_tracing_context_collect( + (iree_hal_cuda2_tracing_context_t*)user_data); +} + static iree_status_t iree_hal_cuda2_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -753,15 +758,16 @@ static iree_status_t iree_hal_cuda2_device_queue_execute( iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution( base_device, device->dispatch_cu_stream, device->callback_cu_stream, - device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list, - command_buffer_count, command_buffers); + device->pending_queue_actions, + iree_hal_cuda2_device_collect_tracing_context, device->tracing_context, + wait_semaphore_list, signal_semaphore_list, command_buffer_count, + 
command_buffers); if (iree_status_is_ok(status)) { // Try to advance the pending workload queue. status = iree_hal_cuda2_pending_queue_actions_issue( device->pending_queue_actions); } - iree_hal_cuda2_tracing_context_collect(device->tracing_context); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c index 4886ebc757db..ab76728f2be3 100644 --- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c +++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c @@ -49,6 +49,12 @@ typedef struct iree_hal_cuda2_queue_action_t { // Retained to make sure it outlives the current action. iree_hal_cuda2_pending_queue_actions_t* owning_actions; + // The callback to run after completing this action and before freeing + // all resources. + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback; + // User data to pass into the callback. + void* callback_user_data; + iree_hal_cuda2_queue_action_kind_t kind; union { struct { @@ -403,6 +409,8 @@ static void iree_hal_cuda2_free_semaphore_list( iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( iree_hal_device_t* device, CUstream dispatch_stream, CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions, + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback, + void* callback_user_data, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, iree_host_size_t command_buffer_count, @@ -417,6 +425,8 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( (void**)&action)); action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION; + action->cleanup_callback = cleanup_callback; + action->callback_user_data = callback_user_data; action->device = device; action->dispatch_cu_stream = dispatch_stream; action->callback_cu_stream = callback_stream; @@ -604,6 +614,8 @@ static void 
iree_hal_cuda2_pending_queue_actions_cleanup_execution( iree_allocator_t host_allocator = actions->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + action->cleanup_callback(action->callback_user_data); + iree_hal_resource_set_free(action->resource_set); iree_hal_cuda2_free_semaphore_list(host_allocator, &action->wait_semaphore_list); diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h index 1484c2bda8ff..574d4c39a6ad 100644 --- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h +++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h @@ -45,11 +45,20 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_create( // Destroys the pending |actions| queue. void iree_hal_cuda2_pending_queue_actions_destroy(iree_hal_resource_t* actions); +// Callback to execute user code after action completion but before resource +// releasing. +// +// Data behind |user_data| must remain alive until the action is released. +typedef void(IREE_API_PTR* iree_hal_cuda2_pending_action_cleanup_callback_t)( + void* user_data); + // Enqueues the given list of |command_buffers| that waits on // |wait_semaphore_list| and signals |signal_semaphore_lsit|. iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( iree_hal_device_t* device, CUstream dispatch_stream, CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions, + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback, + void* callback_user_data, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, iree_host_size_t command_buffer_count,