From 7b22f09f7b04af7cf062e3fa8937c59197fa3993 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 29 Aug 2024 10:46:07 -0700 Subject: [PATCH] Adding command buffer recording state validation. (#18398) This will error out if command buffers are re-recorded or if a submission is made with a command buffer that has not been recorded or that is still in the recording state. --- runtime/src/iree/hal/command_buffer.c | 8 +++- runtime/src/iree/hal/command_buffer.h | 2 +- .../src/iree/hal/command_buffer_validation.c | 38 +++++++++++++++---- .../src/iree/hal/command_buffer_validation.h | 12 ++++-- runtime/src/iree/hal/device.c | 2 +- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c index cf77e4a0fdeb..44d767a900c9 100644 --- a/runtime/src/iree/hal/command_buffer.c +++ b/runtime/src/iree/hal/command_buffer.c @@ -587,11 +587,17 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( // Validation support //===----------------------------------------------------------------------===// -IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_binding_table( +IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission( iree_hal_command_buffer_t* command_buffer, const iree_hal_buffer_binding_table_t* binding_table) { IREE_ASSERT_ARGUMENT(command_buffer); + // Validate the command buffer has been recorded properly. + IF_VALIDATING(command_buffer, { + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_submission_validation( + command_buffer, VALIDATION_STATE(command_buffer))); + }); + // Only check binding tables when one is required and otherwise ignore any // bindings provided. Require at least as many bindings in the table as there // are used by the command buffer. This may be less than the total capacity diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h index fe9b3e8837e4..7d2cded12dc7 100644 --- a/runtime/src/iree/hal/command_buffer.h +++ b/runtime/src/iree/hal/command_buffer.h @@ -761,7 +761,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( // requirements of |command_buffer| as recorded. If the command buffer does not // use any indirect bindings the table will be ignored. If more bindings than // are used by the command buffer are provided they will be ignored. -IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_binding_table( +IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission( iree_hal_command_buffer_t* command_buffer, const iree_hal_buffer_binding_table_t* binding_table); diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c index 28b9360856d6..832e652a8407 100644 --- a/runtime/src/iree/hal/command_buffer_validation.c +++ b/runtime/src/iree/hal/command_buffer_validation.c @@ -23,11 +23,14 @@ static iree_status_t iree_hal_command_buffer_validate_categories( const iree_hal_command_buffer_t* command_buffer, const iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_command_category_t required_categories) { - if (IREE_UNLIKELY(!validation_state->is_recording)) { + if (IREE_UNLIKELY(!validation_state->has_began || + validation_state->has_ended)) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, - "command buffer is not in a recording state"); + "command buffer is not in a recording state; all " + "recording must happen between a begin/end pair"); } - if (!iree_all_bits_set(command_buffer->allowed_categories, + if (required_categories && + !iree_all_bits_set(command_buffer->allowed_categories, required_categories)) { #if IREE_STATUS_MODE iree_bitfield_string_temp_t temp0, temp1; @@ -220,18 +223,23 @@ void iree_hal_command_buffer_initialize_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* out_validation_state) { out_validation_state->device_allocator = device_allocator; - out_validation_state->is_recording = false; + out_validation_state->has_began = false; + out_validation_state->has_ended = false; out_validation_state->debug_group_depth = 0; } iree_status_t iree_hal_command_buffer_begin_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state) { - if (validation_state->is_recording) { + if (validation_state->has_began && validation_state->has_ended) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "command buffer has already been recorded; " + "re-recording command buffers is not allowed"); + } else if (validation_state->has_began) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "command buffer is already in a recording state"); } - validation_state->is_recording = true; + validation_state->has_began = true; return iree_ok_status(); } @@ -242,11 +250,11 @@ iree_status_t iree_hal_command_buffer_end_validation( return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "unbalanced debug group depth (expected 0, is %d)", validation_state->debug_group_depth); - } else if (!validation_state->is_recording) { + } else if (!validation_state->has_began || validation_state->has_ended) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "command buffer is not in a recording state"); } - validation_state->is_recording = false; + validation_state->has_ended = true; return iree_ok_status(); } @@ -637,6 +645,20 @@ iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( bindings, flags); } +iree_status_t iree_hal_command_buffer_submission_validation( + iree_hal_command_buffer_t* command_buffer, + const iree_hal_command_buffer_validation_state_t* validation_state) { + if (!validation_state->has_began) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "command buffer has not been recorded"); + } else if (!validation_state->has_ended) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "command buffer recording has not been ended and " + "it is still in a recording state"); + } + return iree_ok_status(); +} + iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer, const iree_hal_command_buffer_validation_state_t* validation_state, diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h index 2174c06c3098..dee7bb4085ec 100644 --- a/runtime/src/iree/hal/command_buffer_validation.h +++ b/runtime/src/iree/hal/command_buffer_validation.h @@ -30,10 +30,12 @@ typedef struct iree_hal_command_buffer_validation_state_t { // Allocator from the device the command buffer is targeting. // Used to verify buffer compatibility. iree_hal_allocator_t* device_allocator; - // 1 when in a begin/end recording sequence. - int32_t is_recording : 1; + // 1 when begin has been called. + int32_t has_began : 1; + // 1 when end has been called. + int32_t has_ended : 1; // Debug group depth for tracking proper begin/end pairing. - int32_t debug_group_depth : 31; + int32_t debug_group_depth : 30; // TODO(benvanik): current pipeline layout/descriptor set layout info. // TODO(benvanik): valid push constant bit ranges. // Requirements for each binding table entry. @@ -140,6 +142,10 @@ iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); +iree_status_t iree_hal_command_buffer_submission_validation( + iree_hal_command_buffer_t* command_buffer, + const iree_hal_command_buffer_validation_state_t* validation_state); + iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer, const iree_hal_command_buffer_validation_state_t* validation_state, diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c index 40f261089dad..7ae9abb4c169 100644 --- a/runtime/src/iree/hal/device.c +++ b/runtime/src/iree/hal/device.c @@ -313,7 +313,7 @@ IREE_API_EXPORT iree_status_t iree_hal_device_queue_execute( for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, - iree_hal_command_buffer_validate_binding_table( + iree_hal_command_buffer_validate_submission( command_buffers[i], binding_tables ? &binding_tables[i] : NULL)); }