Skip to content

Commit

Permalink
[runtime][hip][cuda] Fix waiting on a semaphore on the host (#17073)
Browse files Browse the repository at this point in the history
When waiting on a semaphore we first attempt to wait on a timepoint
event on the dispatch stream.
This is not correct as the semaphore is signaled in a callback on the
host stream when signaling that an action is complete.
This callback happens after the dispatch stream event, so the semaphore
value may not be updated yet.

The fix is to directly wait on the host semaphore and not on the
dispatch stream event.

Also add a test that catches this failure.
  • Loading branch information
sogartar authored Apr 17, 2024
1 parent fdfe344 commit cd282de
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 39 deletions.
97 changes: 94 additions & 3 deletions runtime/src/iree/hal/cts/semaphore_submission_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,100 @@ TEST_P(semaphore_submission_test,
iree_hal_semaphore_release(device_signal_semaphore);
}

// TODO: test device -> device synchronization: submit two batches with a
// semaphore singal -> wait dependency.
//
// Test device -> device synchronization: submit two batches with a
// semaphore signal -> wait dependency.
TEST_P(semaphore_submission_test, IntermediateSemaphoreBetweenDeviceBatches) {
  // Submits two batches chained through an intermediate semaphore and checks
  // that both semaphores reach their payload values.
  //
  // The signaling relationship is
  //   command_buffer1 -> semaphore1 -> command_buffer2 -> semaphore2
  //
  // The batches are deliberately queued in reverse order: command_buffer2
  // (which waits on semaphore1) is submitted before command_buffer1 (which
  // signals semaphore1).

  // Create first command buffer. Nothing is recorded between begin/end, so it
  // carries no commands of its own.
  iree_hal_command_buffer_t* command_buffer1 = NULL;
  IREE_ASSERT_OK(iree_hal_command_buffer_create(
      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
      /*binding_capacity=*/0, &command_buffer1));
  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer1));
  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer1));

  // Create second command buffer (also left without recorded commands).
  iree_hal_command_buffer_t* command_buffer2 = NULL;
  IREE_ASSERT_OK(iree_hal_command_buffer_create(
      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
      /*binding_capacity=*/0, &command_buffer2));
  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer2));
  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer2));

  // Semaphore to signal when command_buffer1 is done and to wait to
  // start executing command_buffer2. Created with initial value 0; the
  // payload value used for both the signal and the wait is 1.
  iree_hal_semaphore_t* semaphore1 = NULL;
  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0, &semaphore1));
  uint64_t semaphore1_value = 1;
  iree_hal_semaphore_list_t semaphore1_list = {/*count=*/1, &semaphore1,
                                               &semaphore1_value};

  // Semaphore to signal when all work (command_buffer2) is done. Created with
  // initial value 0; the signaled payload value is 1.
  iree_hal_semaphore_t* semaphore2 = NULL;
  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0, &semaphore2));
  uint64_t semaphore2_value = 1;
  iree_hal_semaphore_list_t semaphore2_list = {/*count=*/1, &semaphore2,
                                               &semaphore2_value};

  // Dispatch the second command buffer first. It waits on semaphore1, which
  // nothing has been queued to signal yet.
  IREE_ASSERT_OK(iree_hal_device_queue_execute(
      device_, IREE_HAL_QUEUE_AFFINITY_ANY,
      /*wait_semaphore_list=*/semaphore1_list,
      /*signal_semaphore_list=*/semaphore2_list, 1, &command_buffer2));

  // Make sure that the intermediate and second semaphores have not advanced
  // since only command_buffer2 is queued.
  uint64_t semaphore2_value_after_queueing_command_buffer2;
  IREE_ASSERT_OK(iree_hal_semaphore_query(
      semaphore2, &semaphore2_value_after_queueing_command_buffer2));
  EXPECT_EQ(static_cast<uint64_t>(0),
            semaphore2_value_after_queueing_command_buffer2);
  uint64_t semaphore1_value_after_queueing_command_buffer2;
  IREE_ASSERT_OK(iree_hal_semaphore_query(
      semaphore1, &semaphore1_value_after_queueing_command_buffer2));
  EXPECT_EQ(static_cast<uint64_t>(0),
            semaphore1_value_after_queueing_command_buffer2);

  // Submit the first command buffer. It has no wait dependencies (empty wait
  // list) and signals semaphore1 on completion.
  iree_hal_semaphore_list_t command_buffer1_wait_semaphore_list = {
      /*count=*/0, nullptr, nullptr};
  IREE_ASSERT_OK(iree_hal_device_queue_execute(
      device_, IREE_HAL_QUEUE_AFFINITY_ANY,
      /*wait_semaphore_list=*/command_buffer1_wait_semaphore_list,
      /*signal_semaphore_list=*/semaphore1_list, 1, &command_buffer1));

  // Wait on the intermediate semaphore and check its value.
  IREE_ASSERT_OK(
      iree_hal_semaphore_wait(semaphore1, semaphore1_value,
                              iree_make_deadline(IREE_TIME_INFINITE_FUTURE)));
  uint64_t semaphore1_value_after_command_buffer1_has_done_executing;
  IREE_ASSERT_OK(iree_hal_semaphore_query(
      semaphore1, &semaphore1_value_after_command_buffer1_has_done_executing));
  uint64_t expected_semaphore1_value = semaphore1_value;
  EXPECT_EQ(expected_semaphore1_value,
            semaphore1_value_after_command_buffer1_has_done_executing);

  // Wait on the second semaphore and check its value.
  IREE_ASSERT_OK(
      iree_hal_semaphore_wait(semaphore2, semaphore2_value,
                              iree_make_deadline(IREE_TIME_INFINITE_FUTURE)));
  uint64_t semaphore2_value_after_command_buffer2_has_done_executing;
  IREE_ASSERT_OK(iree_hal_semaphore_query(
      semaphore2, &semaphore2_value_after_command_buffer2_has_done_executing));
  uint64_t expected_semaphore2_value = semaphore2_value;
  EXPECT_EQ(expected_semaphore2_value,
            semaphore2_value_after_command_buffer2_has_done_executing);

  // Release everything created by this test.
  iree_hal_command_buffer_release(command_buffer1);
  iree_hal_command_buffer_release(command_buffer2);
  iree_hal_semaphore_release(semaphore1);
  iree_hal_semaphore_release(semaphore2);
}

// TODO: test device -> device synchronization: submit multiple batches with
// multiple later batches waiting on the same signaling from a former batch.
//
Expand Down
18 changes: 1 addition & 17 deletions runtime/src/iree/hal/drivers/cuda/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,23 +282,6 @@ static iree_status_t iree_hal_cuda_semaphore_wait(
}
iree_slim_mutex_unlock(&semaphore->mutex);

iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);

// Slow path: try to see if we can have a device CUevent to wait on. This
// should happen outside of the lock given that acquiring has its own internal
// locks. This is faster than waiting on a host timepoint.
iree_hal_cuda_event_t* wait_event = NULL;
if (iree_hal_cuda_semaphore_acquire_event_host_wait(&semaphore->base, value,
&wait_event)) {
IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, semaphore->symbols,
cuEventSynchronize(iree_hal_cuda_event_handle(wait_event)),
"cuEventSynchronize");
iree_hal_cuda_event_release(wait_event);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

// Slow path: acquire a timepoint. This should happen outside of the lock too
// given that acquiring has its own internal locks.
iree_hal_cuda_timepoint_t* timepoint = NULL;
Expand All @@ -312,6 +295,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait(
// Wait until the timepoint resolves.
// If satisfied the timepoint is automatically cleaned up and we are done. If
// the deadline is reached before satisfied then we have to clean it up.
iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
status = iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns);
if (!iree_status_is_ok(status)) {
iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base);
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ iree_status_t iree_hal_cuda_pending_queue_actions_issue(
}
action->events[action->event_count++] = wait_event;

// Remove the wait timepoint as we have a correspnding event that we
// Remove the wait timepoint as we have a corresponding event that we
// will wait on.
iree_hal_semaphore_list_remove_element(&action->wait_semaphore_list, i);
--i;
Expand Down
18 changes: 1 addition & 17 deletions runtime/src/iree/hal/drivers/hip/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -281,23 +281,6 @@ static iree_status_t iree_hal_hip_semaphore_wait(
}
iree_slim_mutex_unlock(&semaphore->mutex);

iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);

// Slow path: try to see if we can have a device hipEvent_t to wait on. This
// should happen outside of the lock given that acquiring has its own internal
// locks. This is faster than waiting on a host timepoint.
iree_hal_hip_event_t* wait_event = NULL;
if (iree_hal_hip_semaphore_acquire_event_host_wait(&semaphore->base, value,
&wait_event)) {
IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
z0, semaphore->symbols,
hipEventSynchronize(iree_hal_hip_event_handle(wait_event)),
"hipEventSynchronize");
iree_hal_hip_event_release(wait_event);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

// Slow path: acquire a timepoint. This should happen outside of the lock too
// given that acquiring has its own internal locks.
iree_hal_hip_timepoint_t* timepoint = NULL;
Expand All @@ -311,6 +294,7 @@ static iree_status_t iree_hal_hip_semaphore_wait(
// Wait until the timepoint resolves.
// If satisfied the timepoint is automatically cleaned up and we are done. If
// the deadline is reached before satisfied then we have to clean it up.
iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
status = iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns);
if (!iree_status_is_ok(status)) {
iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base);
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue(
}
action->events[action->event_count++] = wait_event;

// Remove the wait timepoint as we have a correspnding event that we
// Remove the wait timepoint as we have a corresponding event that we
// will wait on.
iree_hal_semaphore_list_remove_element(&action->wait_semaphore_list, i);
--i;
Expand Down

0 comments on commit cd282de

Please sign in to comment.