From 1c10bebe327bbf21babdd0414a47f86daecc100e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 30 Oct 2023 19:32:44 -0700 Subject: [PATCH] [cuda] Avoid sorting when composing kernel arguments (#15325) This commit changes the graph command buffer to keep track of a fixed list of descriptor sets. This simplifies the logic of pushing descriptors/constants; we only perform kernel parameter serialization at the dispatch time. This means we don't pay the overhead if multiple push descriptor/constant commands are issued before dispatching. Also we don't need to sort descriptors anymore. Progress towards https://github.com/openxla/iree/issues/13245 --- experimental/cuda2/graph_command_buffer.c | 222 +++++++++++----------- experimental/cuda2/native_executable.c | 31 ++- experimental/cuda2/native_executable.h | 10 +- experimental/cuda2/pipeline_layout.c | 64 +++++-- experimental/cuda2/pipeline_layout.h | 32 +++- 5 files changed, 212 insertions(+), 147 deletions(-) diff --git a/experimental/cuda2/graph_command_buffer.c b/experimental/cuda2/graph_command_buffer.c index fb7a1b7853d36..2dc3594c39839 100644 --- a/experimental/cuda2/graph_command_buffer.c +++ b/experimental/cuda2/graph_command_buffer.c @@ -6,7 +6,6 @@ #include "experimental/cuda2/graph_command_buffer.h" -#include #include #include @@ -14,18 +13,11 @@ #include "experimental/cuda2/cuda_dynamic_symbols.h" #include "experimental/cuda2/cuda_status_util.h" #include "experimental/cuda2/native_executable.h" -#include "experimental/cuda2/nccl_channel.h" #include "experimental/cuda2/pipeline_layout.h" #include "iree/base/api.h" #include "iree/hal/utils/collective_batch.h" #include "iree/hal/utils/resource_set.h" -// The maximal number of descriptor bindings supported in the CUDA HAL driver. -#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64 -// The maximal number of kernel arguments supported in the CUDA HAL driver for -// descriptor bindings and push constants. -#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128 - // Command buffer implementation that directly records into CUDA graphs. // The command buffer records the commands on the calling thread without // additional threading indirection. @@ -45,8 +37,8 @@ typedef struct iree_hal_cuda2_graph_command_buffer_t { CUcontext cu_context; // The CUDA graph under construction. - CUgraph graph; - CUgraphExec exec; + CUgraph cu_graph; + CUgraphExec cu_graph_exec; // The last node added to the command buffer. // We need to track it as we are currently serializing all the nodes (each @@ -56,12 +48,13 @@ typedef struct iree_hal_cuda2_graph_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; - int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; + int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - // The current set of kernel arguments. - void* current_descriptor[]; + // The current bound descriptor sets. + struct { + CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; + } descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT]; } iree_hal_cuda2_graph_command_buffer_t; -// + Additional inline allocation for holding all kernel arguments. static const iree_hal_command_buffer_vtable_t iree_hal_cuda2_graph_command_buffer_vtable; @@ -94,11 +87,8 @@ iree_status_t iree_hal_cuda2_graph_command_buffer_create( IREE_TRACE_ZONE_BEGIN(z0); iree_hal_cuda2_graph_command_buffer_t* command_buffer = NULL; - size_t total_size = sizeof(*command_buffer) + - IREE_HAL_CUDA_MAX_KERNEL_ARG * sizeof(void*) + - IREE_HAL_CUDA_MAX_KERNEL_ARG * sizeof(CUdeviceptr); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, total_size, + z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), (void**)&command_buffer)); iree_hal_command_buffer_initialize( @@ -108,16 +98,10 @@ iree_status_t iree_hal_cuda2_graph_command_buffer_create( command_buffer->symbols = cuda_symbols; iree_arena_initialize(block_pool, &command_buffer->arena); command_buffer->cu_context = context; - command_buffer->graph = NULL; - command_buffer->exec = NULL; + command_buffer->cu_graph = NULL; + command_buffer->cu_graph_exec = NULL; command_buffer->last_node = NULL; - CUdeviceptr* device_ptrs = (CUdeviceptr*)(command_buffer->current_descriptor + - IREE_HAL_CUDA_MAX_KERNEL_ARG); - for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) { - command_buffer->current_descriptor[i] = &device_ptrs[i]; - } - iree_status_t status = iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set); @@ -147,15 +131,15 @@ static void iree_hal_cuda2_graph_command_buffer_destroy( // Drop any pending collective batches before we tear things down. iree_hal_collective_batch_clear(&command_buffer->collective_batch); - if (command_buffer->graph != NULL) { + if (command_buffer->cu_graph != NULL) { IREE_CUDA_IGNORE_ERROR(command_buffer->symbols, - cuGraphDestroy(command_buffer->graph)); - command_buffer->graph = NULL; + cuGraphDestroy(command_buffer->cu_graph)); + command_buffer->cu_graph = NULL; } - if (command_buffer->exec != NULL) { + if (command_buffer->cu_graph_exec != NULL) { IREE_CUDA_IGNORE_ERROR(command_buffer->symbols, - cuGraphExecDestroy(command_buffer->exec)); - command_buffer->exec = NULL; + cuGraphExecDestroy(command_buffer->cu_graph_exec)); + command_buffer->cu_graph_exec = NULL; } command_buffer->last_node = NULL; @@ -177,7 +161,7 @@ CUgraphExec iree_hal_cuda2_graph_command_buffer_handle( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_cuda2_graph_command_buffer_t* command_buffer = iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer); - return command_buffer->exec; + return command_buffer->cu_graph_exec; } // Flushes any pending batched collective operations. @@ -226,15 +210,15 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_begin( iree_hal_cuda2_graph_command_buffer_t* command_buffer = iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer); - if (command_buffer->graph != NULL) { + if (command_buffer->cu_graph != NULL) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "command buffer cannot be re-recorded"); } // Create a new empty graph to record into. - IREE_CUDA_RETURN_IF_ERROR(command_buffer->symbols, - cuGraphCreate(&command_buffer->graph, /*flags=*/0), - "cuGraphCreate"); + IREE_CUDA_RETURN_IF_ERROR( + command_buffer->symbols, + cuGraphCreate(&command_buffer->cu_graph, /*flags=*/0), "cuGraphCreate"); return iree_ok_status(); } @@ -255,15 +239,15 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_end( CUgraphNode error_node = NULL; iree_status_t status = IREE_CURESULT_TO_STATUS( command_buffer->symbols, - cuGraphInstantiate(&command_buffer->exec, command_buffer->graph, - &error_node, + cuGraphInstantiate(&command_buffer->cu_graph_exec, + command_buffer->cu_graph, &error_node, /*logBuffer=*/NULL, /*bufferSize=*/0)); if (iree_status_is_ok(status)) { // No longer need the source graph used for construction. IREE_CUDA_IGNORE_ERROR(command_buffer->symbols, - cuGraphDestroy(command_buffer->graph)); - command_buffer->graph = NULL; + cuGraphDestroy(command_buffer->cu_graph)); + command_buffer->cu_graph = NULL; } iree_hal_resource_set_freeze(command_buffer->resource_set); @@ -416,7 +400,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_fill_buffer( size_t numNode = command_buffer->last_node ? 1 : 0; IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( z0, command_buffer->symbols, - cuGraphAddMemsetNode(&command_buffer->last_node, command_buffer->graph, + cuGraphAddMemsetNode(&command_buffer->last_node, command_buffer->cu_graph, dep, numNode, ¶ms, command_buffer->cu_context), "cuGraphAddMemsetNode"); @@ -471,7 +455,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_update_buffer( IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( z0, command_buffer->symbols, - cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph, + cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->cu_graph, dep, numNode, ¶ms, command_buffer->cu_context), "cuGraphAddMemcpyNode"); @@ -522,7 +506,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_copy_buffer( IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( z0, command_buffer->symbols, - cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph, + cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->cu_graph, dep, numNode, ¶ms, command_buffer->cu_context), "cuGraphAddMemcpyNode"); @@ -550,77 +534,44 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_push_constants( iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer); iree_host_size_t constant_base_index = offset / sizeof(int32_t); for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { - command_buffer->push_constant[i + constant_base_index] = + command_buffer->push_constants[i + constant_base_index] = ((uint32_t*)values)[i]; } return iree_ok_status(); } -typedef struct { - // The original index into the iree_hal_descriptor_set_binding_t array. - uint32_t index; - // The descriptor binding number. - uint32_t binding; -} iree_hal_cuda2_binding_mapping_t; - -// Compares two iree_hal_cuda2_binding_mapping_t according to the descriptor -// binding number. -static int compare_binding_index(const void* a, const void* b) { - const iree_hal_cuda2_binding_mapping_t buffer_a = - *(const iree_hal_cuda2_binding_mapping_t*)a; - const iree_hal_cuda2_binding_mapping_t buffer_b = - *(const iree_hal_cuda2_binding_mapping_t*)b; - return buffer_a.binding < buffer_b.binding ? -1 : 1; -} - static iree_status_t iree_hal_cuda2_graph_command_buffer_push_descriptor_set( iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, iree_host_size_t binding_count, const iree_hal_descriptor_set_binding_t* bindings) { - IREE_ASSERT_LT(binding_count, IREE_HAL_CUDA_MAX_BINDING_COUNT, - "binding count larger than the max expected"); + if (binding_count > IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) { + return iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "exceeded available binding slots for push " + "descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. maximal %d", + set, binding_count, IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT); + } + iree_hal_cuda2_graph_command_buffer_t* command_buffer = iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - iree_host_size_t base_binding = - iree_hal_cuda2_pipeline_layout_base_binding_index(pipeline_layout, set); - - // Convention with the compiler side. We map descriptor bindings to kernel - // argument. We compact the descriptor binding number ranges to get a dense - // set of kernel arguments and keep them ordered based on the descriptor - // binding index. - iree_hal_cuda2_binding_mapping_t - sorted_bindings[IREE_HAL_CUDA_MAX_BINDING_COUNT]; + CUdeviceptr* current_bindings = command_buffer->descriptor_sets[set].bindings; for (iree_host_size_t i = 0; i < binding_count; i++) { - sorted_bindings[i].index = i; - sorted_bindings[i].binding = bindings[i].binding; - } - // Sort the binding based on the binding index and map the (base offset + - // array index) to the kernel argument index. - // TODO: remove this sort - it's thankfully small (1-8 on average) but we - // should be able to avoid it like we do on the CPU side with a bitmap. - qsort(sorted_bindings, binding_count, - sizeof(iree_hal_cuda2_binding_mapping_t), compare_binding_index); - - for (iree_host_size_t i = 0; i < binding_count; i++) { - const iree_hal_descriptor_set_binding_t* binding = - &bindings[sorted_bindings[i].index]; + const iree_hal_descriptor_set_binding_t* binding = &bindings[i]; CUdeviceptr device_ptr = 0; if (binding->buffer) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &binding->buffer)); + CUdeviceptr device_buffer = iree_hal_cuda2_buffer_device_pointer( iree_hal_buffer_allocated_buffer(binding->buffer)); iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); device_ptr = device_buffer + offset + binding->offset; }; - *((CUdeviceptr*)command_buffer->current_descriptor[base_binding + i]) = - device_ptr; - if (binding->buffer) { - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, - &binding->buffer)); - } + current_bindings[binding->binding] = device_ptr; } IREE_TRACE_ZONE_END(z0); @@ -641,35 +592,86 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch( // Lookup kernel parameters used for side-channeling additional launch // information from the compiler. - iree_hal_cuda2_kernel_params_t kernel_params; + iree_hal_cuda2_kernel_info_t kernel_info; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda2_native_executable_entry_point_kernel_params( - executable, entry_point, &kernel_params)); + z0, iree_hal_cuda2_native_executable_entry_point_kernel_info( + executable, entry_point, &kernel_info)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable)); - // Patch the push constants in the kernel arguments. - iree_host_size_t num_constants = - iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_params.layout); + // The total number of descriptors across all descriptor sets. + iree_host_size_t descriptor_count = + iree_hal_cuda2_pipeline_layout_total_binding_count(kernel_info.layout); + // The total number of push constants. + iree_host_size_t push_constant_count = + iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_info.layout); + // We append push constants to the end of descriptors to form a linear chain + // of kernel arguments. + iree_host_size_t kernel_params_count = descriptor_count + push_constant_count; + iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); + + // Per CUDA API requirements, we need two levels of indirection for passing + // kernel arguments in. + // "If the kernel has N parameters, then kernelParams needs to be an array + // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1], + // points to the region of memory from which the actual parameter will be + // copied." + // + // (From the cuGraphAddKernelNode API doc in + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b) + // + // It means each kernel_params[i] is itself a pointer to the corresponding + // element at the *second* inline allocation at the end of the current + // segment. + iree_host_size_t total_size = kernel_params_length * 2; + uint8_t* storage_base = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_arena_allocate(&command_buffer->arena, total_size, + (void**)&storage_base)); + void** params_ptr = (void**)storage_base; + + // Set up kernel arguments to point to the payload slots. + CUdeviceptr* payload_ptr = + (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length); + for (size_t i = 0; i < kernel_params_count; i++) { + params_ptr[i] = &payload_ptr[i]; + } + + // Copy descriptors from all sets to the end of the current segment for later + // access. + iree_host_size_t set_count = + iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout); + for (iree_host_size_t i = 0; i < set_count; ++i) { + iree_host_size_t binding_count = + iree_hal_cuda2_descriptor_set_layout_binding_count( + iree_hal_cuda2_pipeline_layout_descriptor_set_layout( + kernel_info.layout, i)); + iree_host_size_t index = iree_hal_cuda2_pipeline_layout_base_binding_index( + kernel_info.layout, i); + memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings, + binding_count * sizeof(CUdeviceptr)); + } + + // Append the push constants to the kernel arguments. iree_host_size_t base_index = - iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_params.layout); - for (iree_host_size_t i = 0; i < num_constants; i++) { - *((uint32_t*)command_buffer->current_descriptor[base_index + i]) = - command_buffer->push_constant[i]; + iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout); + for (iree_host_size_t i = 0; i < push_constant_count; i++) { + *((uint32_t*)params_ptr[base_index + i]) = + command_buffer->push_constants[i]; } CUDA_KERNEL_NODE_PARAMS params = { - .func = kernel_params.function, - .blockDimX = kernel_params.block_size[0], - .blockDimY = kernel_params.block_size[1], - .blockDimZ = kernel_params.block_size[2], + .func = kernel_info.function, + .blockDimX = kernel_info.block_size[0], + .blockDimY = kernel_info.block_size[1], + .blockDimZ = kernel_info.block_size[2], .gridDimX = workgroup_x, .gridDimY = workgroup_y, .gridDimZ = workgroup_z, - .kernelParams = command_buffer->current_descriptor, - .sharedMemBytes = kernel_params.shared_memory_size, + .kernelParams = params_ptr, + .sharedMemBytes = kernel_info.shared_memory_size, }; // Serialize all the nodes for now. @@ -678,7 +680,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch( IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( z0, command_buffer->symbols, - cuGraphAddKernelNode(&command_buffer->last_node, command_buffer->graph, + cuGraphAddKernelNode(&command_buffer->last_node, command_buffer->cu_graph, dep, numNodes, ¶ms), "cuGraphAddKernelNode"); diff --git a/experimental/cuda2/native_executable.c b/experimental/cuda2/native_executable.c index 48794479d3506..db74f75a0a333 100644 --- a/experimental/cuda2/native_executable.c +++ b/experimental/cuda2/native_executable.c @@ -33,7 +33,7 @@ typedef struct iree_hal_cuda2_native_executable_t { iree_host_size_t entry_point_count; // The list of entry point data pointers, pointing to trailing inline // allocation after the end of this struct. - iree_hal_cuda2_kernel_params_t entry_points[]; + iree_hal_cuda2_kernel_info_t entry_points[]; } iree_hal_cuda2_native_executable_t; // + Additional inline allocation for holding entry point information. @@ -225,20 +225,20 @@ iree_status_t iree_hal_cuda2_native_executable_create( if (!iree_status_is_ok(status)) break; // Package required parameters for kernel launches for each entry point. - iree_hal_cuda2_kernel_params_t* params = &executable->entry_points[i]; - params->layout = executable_params->pipeline_layouts[i]; - iree_hal_pipeline_layout_retain(params->layout); - params->function = function; - params->block_size[0] = block_sizes_vec[i].x; - params->block_size[1] = block_sizes_vec[i].y; - params->block_size[2] = block_sizes_vec[i].z; - params->shared_memory_size = shared_memory_sizes[i]; + iree_hal_cuda2_kernel_info_t* info = &executable->entry_points[i]; + info->layout = executable_params->pipeline_layouts[i]; + iree_hal_pipeline_layout_retain(info->layout); + info->function = function; + info->block_size[0] = block_sizes_vec[i].x; + info->block_size[1] = block_sizes_vec[i].y; + info->block_size[2] = block_sizes_vec[i].z; + info->shared_memory_size = shared_memory_sizes[i]; // Stash the entry point name in the string table for use when tracing. IREE_TRACE({ iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name); memcpy(string_table_buffer, entry_name, entry_name_length); - params->function_name = + info->function_name = iree_make_string_view(string_table_buffer, entry_name_length); string_table_buffer += entry_name_length; }); @@ -253,9 +253,9 @@ iree_status_t iree_hal_cuda2_native_executable_create( flatbuffers_string_t filename = iree_hal_cuda_FileLineLocDef_filename_get(source_loc); uint32_t line = iree_hal_cuda_FileLineLocDef_line_get(source_loc); - params->source_filename = + info->source_filename = iree_make_string_view(filename, flatbuffers_string_len(filename)); - params->source_line = line; + info->source_line = line; } }); } @@ -290,9 +290,9 @@ static void iree_hal_cuda2_native_executable_destroy( IREE_TRACE_ZONE_END(z0); } -iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_params( +iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_info( iree_hal_executable_t* base_executable, int32_t entry_point, - iree_hal_cuda2_kernel_params_t* out_params) { + iree_hal_cuda2_kernel_info_t* out_info) { iree_hal_cuda2_native_executable_t* executable = iree_hal_cuda2_native_executable_cast(base_executable); if (entry_point >= executable->entry_point_count) { @@ -301,8 +301,7 @@ iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_params( "only contains %" PRIhsz " entry points", entry_point, executable->entry_point_count); } - memcpy(out_params, &executable->entry_points[entry_point], - sizeof(*out_params)); + memcpy(out_info, &executable->entry_points[entry_point], sizeof(*out_info)); return iree_ok_status(); } diff --git a/experimental/cuda2/native_executable.h b/experimental/cuda2/native_executable.h index c0ac1e0e56146..2b633fe529acb 100644 --- a/experimental/cuda2/native_executable.h +++ b/experimental/cuda2/native_executable.h @@ -19,7 +19,7 @@ extern "C" { #endif // __cplusplus -typedef struct iree_hal_cuda2_kernel_params_t { +typedef struct iree_hal_cuda2_kernel_info_t { iree_hal_pipeline_layout_t* layout; CUfunction function; uint32_t block_size[3]; @@ -28,7 +28,7 @@ typedef struct iree_hal_cuda2_kernel_params_t { IREE_TRACE(iree_string_view_t function_name;) IREE_TRACE(iree_string_view_t source_filename;) IREE_TRACE(uint32_t source_line;) -} iree_hal_cuda2_kernel_params_t; +} iree_hal_cuda2_kernel_info_t; // Creates an IREE executable from a CUDA PTX module. The module may contain // several kernels that can be extracted along with the associated block size. @@ -37,11 +37,11 @@ iree_status_t iree_hal_cuda2_native_executable_create( const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable); -// Returns the kernel launch parameters for the given |entry_point| in the +// Returns the kernel launch information for the given |entry_point| in the // |executable|. -iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_params( +iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_info( iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_cuda2_kernel_params_t* out_params); + iree_hal_cuda2_kernel_info_t* out_info); #ifdef __cplusplus } // extern "C" diff --git a/experimental/cuda2/pipeline_layout.c b/experimental/cuda2/pipeline_layout.c index a9d13a4cf880f..bd64d037d903e 100644 --- a/experimental/cuda2/pipeline_layout.c +++ b/experimental/cuda2/pipeline_layout.c @@ -38,6 +38,14 @@ iree_hal_cuda2_descriptor_set_layout_cast( return (iree_hal_cuda2_descriptor_set_layout_t*)base_value; } +static const iree_hal_cuda2_descriptor_set_layout_t* +iree_hal_cuda2_descriptor_set_layout_const_cast( + const iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_cuda2_descriptor_set_layout_vtable); + return (const iree_hal_cuda2_descriptor_set_layout_t*)base_value; +} + iree_status_t iree_hal_cuda2_descriptor_set_layout_create( iree_hal_descriptor_set_layout_flags_t flags, iree_host_size_t binding_count, @@ -67,9 +75,10 @@ iree_status_t iree_hal_cuda2_descriptor_set_layout_create( } iree_host_size_t iree_hal_cuda2_descriptor_set_layout_binding_count( - iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { - iree_hal_cuda2_descriptor_set_layout_t* descriptor_set_layout = - iree_hal_cuda2_descriptor_set_layout_cast(base_descriptor_set_layout); + const iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + const iree_hal_cuda2_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_cuda2_descriptor_set_layout_const_cast( + base_descriptor_set_layout); return descriptor_set_layout->binding_count; } @@ -128,6 +137,13 @@ static iree_hal_cuda2_pipeline_layout_t* iree_hal_cuda2_pipeline_layout_cast( return (iree_hal_cuda2_pipeline_layout_t*)base_value; } +static const iree_hal_cuda2_pipeline_layout_t* +iree_hal_cuda2_pipeline_layout_const_cast( + const iree_hal_pipeline_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_pipeline_layout_vtable); + return (const iree_hal_cuda2_pipeline_layout_t*)base_value; +} + iree_status_t iree_hal_cuda2_pipeline_layout_create( iree_host_size_t set_layout_count, iree_hal_descriptor_set_layout_t* const* set_layouts, @@ -195,24 +211,48 @@ static void iree_hal_cuda2_pipeline_layout_destroy( IREE_TRACE_ZONE_END(z0); } +iree_host_size_t iree_hal_cuda2_pipeline_layout_descriptor_set_count( + const iree_hal_pipeline_layout_t* base_pipeline_layout) { + const iree_hal_cuda2_pipeline_layout_t* pipeline_layout = + iree_hal_cuda2_pipeline_layout_const_cast(base_pipeline_layout); + return pipeline_layout->set_layout_count; +} + +const iree_hal_descriptor_set_layout_t* +iree_hal_cuda2_pipeline_layout_descriptor_set_layout( + const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { + const iree_hal_cuda2_pipeline_layout_t* pipeline_layout = + iree_hal_cuda2_pipeline_layout_const_cast(base_pipeline_layout); + if (set < pipeline_layout->set_layout_count) { + return pipeline_layout->set_layouts[set].set_layout; + } + return NULL; +} + iree_host_size_t iree_hal_cuda2_pipeline_layout_base_binding_index( - iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { - iree_hal_cuda2_pipeline_layout_t* pipeline_layout = - iree_hal_cuda2_pipeline_layout_cast(base_pipeline_layout); + const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { + const iree_hal_cuda2_pipeline_layout_t* pipeline_layout = + iree_hal_cuda2_pipeline_layout_const_cast(base_pipeline_layout); return pipeline_layout->set_layouts[set].base_index; } +iree_host_size_t iree_hal_cuda2_pipeline_layout_total_binding_count( + const iree_hal_pipeline_layout_t* base_pipeline_layout) { + return iree_hal_cuda2_pipeline_layout_push_constant_index( + base_pipeline_layout); +} + iree_host_size_t iree_hal_cuda2_pipeline_layout_push_constant_index( - iree_hal_pipeline_layout_t* base_pipeline_layout) { - iree_hal_cuda2_pipeline_layout_t* pipeline_layout = - iree_hal_cuda2_pipeline_layout_cast(base_pipeline_layout); + const iree_hal_pipeline_layout_t* base_pipeline_layout) { + const iree_hal_cuda2_pipeline_layout_t* pipeline_layout = + iree_hal_cuda2_pipeline_layout_const_cast(base_pipeline_layout); return pipeline_layout->push_constant_base_index; } iree_host_size_t iree_hal_cuda2_pipeline_layout_push_constant_count( - iree_hal_pipeline_layout_t* base_pipeline_layout) { - iree_hal_cuda2_pipeline_layout_t* pipeline_layout = - iree_hal_cuda2_pipeline_layout_cast(base_pipeline_layout); + const iree_hal_pipeline_layout_t* base_pipeline_layout) { + const iree_hal_cuda2_pipeline_layout_t* pipeline_layout = + iree_hal_cuda2_pipeline_layout_const_cast(base_pipeline_layout); return pipeline_layout->push_constant_count; } diff --git a/experimental/cuda2/pipeline_layout.h b/experimental/cuda2/pipeline_layout.h index 1bf18972a37be..f2d48c80dc9b8 100644 --- a/experimental/cuda2/pipeline_layout.h +++ b/experimental/cuda2/pipeline_layout.h @@ -14,6 +14,17 @@ extern "C" { #endif // __cplusplus +// The max number of bindings per descriptor set allowed in the CUDA HAL +// implementation. +#define IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT 16 + +// The max number of descriptor sets allowed in the CUDA HAL implementation. +// +// This depends on the general descriptor set planning in IREE and should adjust +// with it. +#define IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT 4 + +// The max number of push constants supported by the CUDA HAL implementation. #define IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT 64 // Note that IREE HAL uses a descriptor binding model for expressing resources @@ -50,7 +61,7 @@ iree_status_t iree_hal_cuda2_descriptor_set_layout_create( // Returns the binding count for the given descriptor set layout. iree_host_size_t iree_hal_cuda2_descriptor_set_layout_binding_count( - iree_hal_descriptor_set_layout_t* descriptor_set_layout); + const iree_hal_descriptor_set_layout_t* descriptor_set_layout); //===----------------------------------------------------------------------===// // iree_hal_cuda2_pipeline_layout_t @@ -67,17 +78,30 @@ iree_status_t iree_hal_cuda2_pipeline_layout_create( iree_host_size_t push_constant_count, iree_allocator_t host_allocator, iree_hal_pipeline_layout_t** out_pipeline_layout); +// Returns the total number of sets in the given |pipeline_layout|. +iree_host_size_t iree_hal_cuda2_pipeline_layout_descriptor_set_count( + const iree_hal_pipeline_layout_t* pipeline_layout); + +// Returns the descriptor set layout of the given |set| in |pipeline_layout|. +const iree_hal_descriptor_set_layout_t* +iree_hal_cuda2_pipeline_layout_descriptor_set_layout( + const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); + // Returns the base kernel argument index for the given set. iree_host_size_t iree_hal_cuda2_pipeline_layout_base_binding_index( - iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); + const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); + +// Returns the total number of descriptor bindings across all sets. +iree_host_size_t iree_hal_cuda2_pipeline_layout_total_binding_count( + const iree_hal_pipeline_layout_t* pipeline_layout); // Returns the kernel argument index for push constant data. iree_host_size_t iree_hal_cuda2_pipeline_layout_push_constant_index( - iree_hal_pipeline_layout_t* pipeline_layout); + const iree_hal_pipeline_layout_t* pipeline_layout); // Returns the number of push constants in the pipeline layout. iree_host_size_t iree_hal_cuda2_pipeline_layout_push_constant_count( - iree_hal_pipeline_layout_t* pipeline_layout); + const iree_hal_pipeline_layout_t* pipeline_layout); #ifdef __cplusplus } // extern "C"