Skip to content

Commit

Permalink
[cuda] Collect tracing events after command buffer completion (iree-o…
Browse files Browse the repository at this point in the history
…rg#16158)

Now we have proper async execution in the cuda HAL driver, command
buffers may not execute immediately after enqueuing, so we should not
collect the tracing events there. Instead, we should collect when we
know the command buffers have completed in a deferred and async manner.
  • Loading branch information
antiagainst authored Jan 23, 2024
1 parent 8dae5b5 commit a2df2cc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
12 changes: 9 additions & 3 deletions runtime/src/iree/hal/drivers/cuda2/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ static iree_status_t iree_hal_cuda2_device_queue_write(
return loop_status;
}

static void iree_hal_cuda2_device_collect_tracing_context(void* user_data) {
iree_hal_cuda2_tracing_context_collect(
(iree_hal_cuda2_tracing_context_t*)user_data);
}

static iree_status_t iree_hal_cuda2_device_queue_execute(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
Expand All @@ -753,15 +758,16 @@ static iree_status_t iree_hal_cuda2_device_queue_execute(

iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
base_device, device->dispatch_cu_stream, device->callback_cu_stream,
device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
command_buffer_count, command_buffers);
device->pending_queue_actions,
iree_hal_cuda2_device_collect_tracing_context, device->tracing_context,
wait_semaphore_list, signal_semaphore_list, command_buffer_count,
command_buffers);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda2_pending_queue_actions_issue(
device->pending_queue_actions);
}

iree_hal_cuda2_tracing_context_collect(device->tracing_context);
IREE_TRACE_ZONE_END(z0);
return status;
}
Expand Down
12 changes: 12 additions & 0 deletions runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ typedef struct iree_hal_cuda2_queue_action_t {
// Retained to make sure it outlives the current action.
iree_hal_cuda2_pending_queue_actions_t* owning_actions;

// The callback to run after completing this action and before freeing
// all resources.
iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback;
// User data to pass into the callback.
void* callback_user_data;

iree_hal_cuda2_queue_action_kind_t kind;
union {
struct {
Expand Down Expand Up @@ -403,6 +409,8 @@ static void iree_hal_cuda2_free_semaphore_list(
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback,
void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
Expand All @@ -417,6 +425,8 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
(void**)&action));

action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION;
action->cleanup_callback = cleanup_callback;
action->callback_user_data = callback_user_data;
action->device = device;
action->dispatch_cu_stream = dispatch_stream;
action->callback_cu_stream = callback_stream;
Expand Down Expand Up @@ -604,6 +614,8 @@ static void iree_hal_cuda2_pending_queue_actions_cleanup_execution(
iree_allocator_t host_allocator = actions->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);

action->cleanup_callback(action->callback_user_data);

iree_hal_resource_set_free(action->resource_set);
iree_hal_cuda2_free_semaphore_list(host_allocator,
&action->wait_semaphore_list);
Expand Down
9 changes: 9 additions & 0 deletions runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,20 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_create(
// Destroys the pending |actions| queue.
void iree_hal_cuda2_pending_queue_actions_destroy(iree_hal_resource_t* actions);

// Callback to execute user code after action completion but before resource
// releasing.
//
// Data behind |user_data| must remain alive before the action is released.
typedef void(IREE_API_PTR* iree_hal_cuda2_pending_action_cleanup_callback_t)(
void* user_data);

// Enqueues the given list of |command_buffers| that waits on
// |wait_semaphore_list| and signals |signal_semaphore_lsit|.
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback,
void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
Expand Down

0 comments on commit a2df2cc

Please sign in to comment.