Skip to content

Commit

Permalink
[cuda] Avoid sorting descriptors in stream command buffer (iree-org#1…
Browse files Browse the repository at this point in the history
…5437)

This commit changes the stream command buffer to keep track of a fixed
list of descriptor sets. This simplifies the logic of pushing
descriptors/constants; we only perform kernel parameter serialization at
the dispatch time. This means we don't pay the overhead if multiple push
descriptor/constant commands are issued before dispatching. Also we
don't need to sort descriptors anymore.

Progress towards iree-org#13245
  • Loading branch information
antiagainst authored and ramiro050 committed Dec 19, 2023
1 parent f316462 commit 0036c68
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 75 deletions.
7 changes: 6 additions & 1 deletion experimental/cuda2/graph_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_push_descriptor_set(
iree_hal_buffer_allocated_buffer(binding->buffer));
iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
device_ptr = device_buffer + offset + binding->offset;
};
}
current_bindings[binding->binding] = device_ptr;
}

Expand Down Expand Up @@ -665,6 +665,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(
iree_host_size_t set_count =
iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout);
for (iree_host_size_t i = 0; i < set_count; ++i) {
// TODO: cache this information in the kernel info to avoid recomputation.
iree_host_size_t binding_count =
iree_hal_cuda2_descriptor_set_layout_binding_count(
iree_hal_cuda2_pipeline_layout_descriptor_set_layout(
Expand All @@ -678,6 +679,10 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(
// Append the push constants to the kernel arguments.
iree_host_size_t base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout);
// As commented in the above, what each kernel parameter points to is a
// CUdeviceptr, which as the size of a pointer on the target machine. we are
// just storing a 32-bit value for the push constant here instead. So we must
// process one element each type, for 64-bit machines.
for (iree_host_size_t i = 0; i < push_constant_count; i++) {
*((uint32_t*)params_ptr[base_index + i]) =
command_buffer->push_constants[i];
Expand Down
179 changes: 105 additions & 74 deletions experimental/cuda2/stream_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
#include "iree/hal/utils/collective_batch.h"
#include "iree/hal/utils/resource_set.h"

#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
// Kernel arguments contains binding and push constants.
#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128

typedef struct iree_hal_cuda2_stream_command_buffer_t {
iree_hal_command_buffer_t base;
iree_allocator_t host_allocator;
Expand All @@ -42,11 +38,13 @@ typedef struct iree_hal_cuda2_stream_command_buffer_t {
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;

// The current set push constants.
int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];

// The current set of kernel arguments.
void* current_descriptors[IREE_HAL_CUDA_MAX_KERNEL_ARG];
CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_KERNEL_ARG];
// The current bound descriptor sets.
struct {
CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT];
} iree_hal_cuda2_stream_command_buffer_t;

static const iree_hal_command_buffer_vtable_t
Expand Down Expand Up @@ -100,10 +98,6 @@ iree_status_t iree_hal_cuda2_stream_command_buffer_create(
command_buffer->cu_stream = stream;
iree_arena_initialize(block_pool, &command_buffer->arena);

for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) {
command_buffer->current_descriptors[i] = &command_buffer->device_ptrs[i];
}

iree_status_t status =
iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);

Expand Down Expand Up @@ -467,61 +461,38 @@ static iree_status_t iree_hal_cuda2_stream_command_buffer_push_constants(
return iree_ok_status();
}

// Tie together the binding index and its index in |bindings| array.
typedef struct {
uint32_t index;
uint32_t binding;
} iree_hal_cuda2_binding_mapping_t;

// Helper to sort the binding based on their binding index.
static int compare_binding_index(const void* a, const void* b) {
const iree_hal_cuda2_binding_mapping_t buffer_a =
*(const iree_hal_cuda2_binding_mapping_t*)a;
const iree_hal_cuda2_binding_mapping_t buffer_b =
*(const iree_hal_cuda2_binding_mapping_t*)b;
return buffer_a.binding < buffer_b.binding ? -1 : 1;
}

static iree_status_t iree_hal_cuda2_stream_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
iree_host_size_t binding_count,
const iree_hal_descriptor_set_binding_t* bindings) {
if (binding_count > IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
return iree_make_status(
IREE_STATUS_RESOURCE_EXHAUSTED,
"exceeded available binding slots for push "
"descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. maximal %d",
set, binding_count, IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT);
}

iree_hal_cuda2_stream_command_buffer_t* command_buffer =
iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);

iree_host_size_t base_binding =
iree_hal_cuda2_pipeline_layout_base_binding_index(pipeline_layout, set);

// Convention with the compiler side. We map bindings to kernel argument.
// We compact the bindings to get a dense set of arguments and keep them order
// based on the binding index.
// Sort the binding based on the binding index and map the array index to the
// argument index.
iree_hal_cuda2_binding_mapping_t
binding_used[IREE_HAL_CUDA_MAX_BINDING_COUNT];
CUdeviceptr* current_bindings = command_buffer->descriptor_sets[set].bindings;
for (iree_host_size_t i = 0; i < binding_count; i++) {
iree_hal_cuda2_binding_mapping_t buffer = {i, bindings[i].binding};
binding_used[i] = buffer;
}
// TODO: remove this sort - it's thankfully small (1-8 on average) but we
// should be able to avoid it like we do on the CPU side with a bitmap.
qsort(binding_used, binding_count, sizeof(iree_hal_cuda2_binding_mapping_t),
compare_binding_index);
assert(binding_count < IREE_HAL_CUDA_MAX_BINDING_COUNT &&
"binding count larger than the max expected.");

for (iree_host_size_t i = 0; i < binding_count; i++) {
iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
CUdeviceptr device_ptr =
binding.buffer
? (iree_hal_cuda2_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding.buffer)) +
iree_hal_buffer_byte_offset(binding.buffer) + binding.offset)
: 0;
*((CUdeviceptr*)command_buffer->current_descriptors[i + base_binding]) =
device_ptr;
const iree_hal_descriptor_set_binding_t* binding = &bindings[i];
CUdeviceptr device_ptr = 0;
if (binding->buffer) {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&binding->buffer));

CUdeviceptr device_buffer = iree_hal_cuda2_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding->buffer));
iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
device_ptr = device_buffer + offset + binding->offset;
}
current_bindings[binding->binding] = device_ptr;
}

IREE_TRACE_ZONE_END(z0);
Expand All @@ -542,34 +513,94 @@ static iree_status_t iree_hal_cuda2_stream_command_buffer_dispatch(

// Lookup kernel parameters used for side-channeling additional launch
// information from the compiler.
iree_hal_cuda2_kernel_info_t kernel_params;
iree_hal_cuda2_kernel_info_t kernel_info;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_native_executable_entry_point_kernel_info(
executable, entry_point, &kernel_params));
executable, entry_point, &kernel_info));

IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer->tracing_context, command_buffer->cu_stream,
kernel_params.source_filename.data, kernel_params.source_filename.size,
kernel_params.source_line, /*func_name=*/NULL, 0,
kernel_params.function_name.data, kernel_params.function_name.size);

// Patch the push constants in the kernel arguments.
iree_host_size_t num_constants =
iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_params.layout);
iree_host_size_t constant_base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_params.layout);
for (iree_host_size_t i = 0; i < num_constants; i++) {
*((uint32_t*)command_buffer->current_descriptors[i + constant_base_index]) =
kernel_info.source_filename.data, kernel_info.source_filename.size,
kernel_info.source_line, /*func_name=*/NULL, 0,
kernel_info.function_name.data, kernel_info.function_name.size);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&executable));

// The total number of descriptors across all descriptor sets.
iree_host_size_t descriptor_count =
iree_hal_cuda2_pipeline_layout_total_binding_count(kernel_info.layout);
// The total number of push constants.
iree_host_size_t push_constant_count =
iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_info.layout);
// We append push constants to the end of descriptors to form a linear chain
// of kernel arguments.
iree_host_size_t kernel_params_count = descriptor_count + push_constant_count;
iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);

// Per CUDA API requirements, we need two levels of indirection for passing
// kernel arguments in.
// "If the kernel has N parameters, then kernelParams needs to be an array
// of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1],
// points to the region of memory from which the actual parameter will be
// copied."
//
// (From the cuGraphAddKernelNode API doc in
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b)
//
// It means each kernel_params[i] is itself a pointer to the corresponding
// element at the *second* inline allocation at the end of the current
// segment.
iree_host_size_t total_size = kernel_params_length * 2;
uint8_t* storage_base = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, total_size,
(void**)&storage_base));
void** params_ptr = (void**)storage_base;

// Set up kernel arguments to point to the payload slots.
CUdeviceptr* payload_ptr =
(CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length);
for (size_t i = 0; i < kernel_params_count; i++) {
params_ptr[i] = &payload_ptr[i];
}

// Copy descriptors from all sets to the end of the current segment for later
// access.
iree_host_size_t set_count =
iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout);
for (iree_host_size_t i = 0; i < set_count; ++i) {
// TODO: cache this information in the kernel info to avoid recomputation.
iree_host_size_t binding_count =
iree_hal_cuda2_descriptor_set_layout_binding_count(
iree_hal_cuda2_pipeline_layout_descriptor_set_layout(
kernel_info.layout, i));
iree_host_size_t index = iree_hal_cuda2_pipeline_layout_base_binding_index(
kernel_info.layout, i);
memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings,
binding_count * sizeof(CUdeviceptr));
}

// Append the push constants to the kernel arguments.
iree_host_size_t base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout);
// As commented in the above, what each kernel parameter points to is a
// CUdeviceptr, which as the size of a pointer on the target machine. we are
// just storing a 32-bit value for the push constant here instead. So we must
// process one element each type, for 64-bit machines.
for (iree_host_size_t i = 0; i < push_constant_count; i++) {
*((uint32_t*)params_ptr[base_index + i]) =
command_buffer->push_constants[i];
}

IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->cuda_symbols,
cuLaunchKernel(
kernel_params.function, workgroup_x, workgroup_y, workgroup_z,
kernel_params.block_size[0], kernel_params.block_size[1],
kernel_params.block_size[2], kernel_params.shared_memory_size,
command_buffer->cu_stream, command_buffer->current_descriptors, NULL),
cuLaunchKernel(kernel_info.function, workgroup_x, workgroup_y,
workgroup_z, kernel_info.block_size[0],
kernel_info.block_size[1], kernel_info.block_size[2],
kernel_info.shared_memory_size, command_buffer->cu_stream,
params_ptr, NULL),
"cuLaunchKernel");

IREE_CUDA_TRACE_ZONE_END(command_buffer->tracing_context,
Expand Down

0 comments on commit 0036c68

Please sign in to comment.