[cuda] Avoid sorting when composing kernel arguments (iree-org#15325)
This commit changes the graph command buffer to keep track of a fixed
list of descriptor sets. This simplifies the logic of pushing
descriptors/constants: we now serialize kernel parameters only at
dispatch time, so we don't pay that overhead when multiple push
descriptor/constant commands are issued before dispatching, and we no
longer need to sort descriptors.


Progress towards iree-org#13245
antiagainst authored and ramiro050 committed Dec 19, 2023
1 parent 5783dda commit 1c10beb
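
Conceptually, the command buffer now owns a small fixed table of descriptor sets and writes each pushed binding straight into its slot, indexed by binding number, so the table is always in binding order and no sort is ever needed. A minimal standalone sketch of the idea (hypothetical names and capacities, not the actual IREE code):

#include <stddef.h>
#include <stdint.h>
#include <cuda.h>

// Hypothetical capacities; the real driver defines its own limits.
#define SKETCH_MAX_DESCRIPTOR_SET_COUNT 4
#define SKETCH_MAX_BINDING_COUNT 16

typedef struct {
  // One slot per binding number; pushes index directly into this array.
  CUdeviceptr bindings[SKETCH_MAX_BINDING_COUNT];
} sketch_descriptor_set_t;

typedef struct {
  sketch_descriptor_set_t descriptor_sets[SKETCH_MAX_DESCRIPTOR_SET_COUNT];
} sketch_command_buffer_state_t;

// Pushing a set is an O(binding_count) scatter by binding number. Repeated
// pushes before a dispatch just overwrite slots in place, so no work piles
// up until dispatch time serializes the arguments.
static void sketch_push_descriptor_set(sketch_command_buffer_state_t* state,
                                       uint32_t set, size_t binding_count,
                                       const uint32_t* binding_numbers,
                                       const CUdeviceptr* device_ptrs) {
  for (size_t i = 0; i < binding_count; ++i) {
    state->descriptor_sets[set].bindings[binding_numbers[i]] = device_ptrs[i];
  }
}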
Showing 5 changed files with 212 additions and 147 deletions.
222 changes: 112 additions & 110 deletions experimental/cuda2/graph_command_buffer.c
@@ -6,26 +6,18 @@

#include "experimental/cuda2/graph_command_buffer.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "experimental/cuda2/cuda_buffer.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/native_executable.h"
#include "experimental/cuda2/nccl_channel.h"
#include "experimental/cuda2/pipeline_layout.h"
#include "iree/base/api.h"
#include "iree/hal/utils/collective_batch.h"
#include "iree/hal/utils/resource_set.h"

// The maximal number of descriptor bindings supported in the CUDA HAL driver.
#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
// The maximal number of kernel arguments supported in the CUDA HAL driver for
// descriptor bindings and push constants.
#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128

// Command buffer implementation that directly records into CUDA graphs.
// The command buffer records the commands on the calling thread without
// additional threading indirection.
@@ -45,8 +37,8 @@ typedef struct iree_hal_cuda2_graph_command_buffer_t {

CUcontext cu_context;
// The CUDA graph under construction.
CUgraph graph;
CUgraphExec exec;
CUgraph cu_graph;
CUgraphExec cu_graph_exec;

// The last node added to the command buffer.
// We need to track it as we are currently serializing all the nodes (each
@@ -56,12 +48,13 @@ typedef struct iree_hal_cuda2_graph_command_buffer_t {
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;

int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];

// The current set of kernel arguments.
void* current_descriptor[];
// The current bound descriptor sets.
struct {
CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT];
} iree_hal_cuda2_graph_command_buffer_t;
// + Additional inline allocation for holding all kernel arguments.

static const iree_hal_command_buffer_vtable_t
iree_hal_cuda2_graph_command_buffer_vtable;
@@ -94,11 +87,8 @@ iree_status_t iree_hal_cuda2_graph_command_buffer_create(
IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_cuda2_graph_command_buffer_t* command_buffer = NULL;
size_t total_size = sizeof(*command_buffer) +
IREE_HAL_CUDA_MAX_KERNEL_ARG * sizeof(void*) +
IREE_HAL_CUDA_MAX_KERNEL_ARG * sizeof(CUdeviceptr);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, total_size,
z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer),
(void**)&command_buffer));

iree_hal_command_buffer_initialize(
@@ -108,16 +98,10 @@ iree_status_t iree_hal_cuda2_graph_command_buffer_create(
command_buffer->symbols = cuda_symbols;
iree_arena_initialize(block_pool, &command_buffer->arena);
command_buffer->cu_context = context;
command_buffer->graph = NULL;
command_buffer->exec = NULL;
command_buffer->cu_graph = NULL;
command_buffer->cu_graph_exec = NULL;
command_buffer->last_node = NULL;

CUdeviceptr* device_ptrs = (CUdeviceptr*)(command_buffer->current_descriptor +
IREE_HAL_CUDA_MAX_KERNEL_ARG);
for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) {
command_buffer->current_descriptor[i] = &device_ptrs[i];
}

iree_status_t status =
iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);

@@ -147,15 +131,15 @@ static void iree_hal_cuda2_graph_command_buffer_destroy(
// Drop any pending collective batches before we tear things down.
iree_hal_collective_batch_clear(&command_buffer->collective_batch);

if (command_buffer->graph != NULL) {
if (command_buffer->cu_graph != NULL) {
IREE_CUDA_IGNORE_ERROR(command_buffer->symbols,
cuGraphDestroy(command_buffer->graph));
command_buffer->graph = NULL;
cuGraphDestroy(command_buffer->cu_graph));
command_buffer->cu_graph = NULL;
}
if (command_buffer->exec != NULL) {
if (command_buffer->cu_graph_exec != NULL) {
IREE_CUDA_IGNORE_ERROR(command_buffer->symbols,
cuGraphExecDestroy(command_buffer->exec));
command_buffer->exec = NULL;
cuGraphExecDestroy(command_buffer->cu_graph_exec));
command_buffer->cu_graph_exec = NULL;
}
command_buffer->last_node = NULL;

@@ -177,7 +161,7 @@ CUgraphExec iree_hal_cuda2_graph_command_buffer_handle(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_cuda2_graph_command_buffer_t* command_buffer =
iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer);
return command_buffer->exec;
return command_buffer->cu_graph_exec;
}

// Flushes any pending batched collective operations.
@@ -226,15 +210,15 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_begin(
iree_hal_cuda2_graph_command_buffer_t* command_buffer =
iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer);

if (command_buffer->graph != NULL) {
if (command_buffer->cu_graph != NULL) {
return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
"command buffer cannot be re-recorded");
}

// Create a new empty graph to record into.
IREE_CUDA_RETURN_IF_ERROR(command_buffer->symbols,
cuGraphCreate(&command_buffer->graph, /*flags=*/0),
"cuGraphCreate");
IREE_CUDA_RETURN_IF_ERROR(
command_buffer->symbols,
cuGraphCreate(&command_buffer->cu_graph, /*flags=*/0), "cuGraphCreate");

return iree_ok_status();
}
@@ -255,15 +239,15 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_end(
CUgraphNode error_node = NULL;
iree_status_t status = IREE_CURESULT_TO_STATUS(
command_buffer->symbols,
cuGraphInstantiate(&command_buffer->exec, command_buffer->graph,
&error_node,
cuGraphInstantiate(&command_buffer->cu_graph_exec,
command_buffer->cu_graph, &error_node,
/*logBuffer=*/NULL,
/*bufferSize=*/0));
if (iree_status_is_ok(status)) {
// No longer need the source graph used for construction.
IREE_CUDA_IGNORE_ERROR(command_buffer->symbols,
cuGraphDestroy(command_buffer->graph));
command_buffer->graph = NULL;
cuGraphDestroy(command_buffer->cu_graph));
command_buffer->cu_graph = NULL;
}

iree_hal_resource_set_freeze(command_buffer->resource_set);
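
For context, the begin/end paths above follow the standard CUDA graph lifecycle: record into a CUgraph, instantiate a CUgraphExec, then drop the construction graph. A minimal sketch of the end step (error handling elided; this uses the same pre-CUDA-12 cuGraphInstantiate signature as the call in the diff):

#include <stddef.h>
#include <cuda.h>

// Instantiates an executable graph and releases the construction graph.
// The CUgraphExec is self-contained, so the source CUgraph can go away.
static CUgraphExec sketch_finalize_graph(CUgraph graph) {
  CUgraphExec exec = NULL;
  CUgraphNode error_node = NULL;
  cuGraphInstantiate(&exec, graph, &error_node, /*logBuffer=*/NULL,
                     /*bufferSize=*/0);
  cuGraphDestroy(graph);
  return exec;  // launch later with cuGraphLaunch(exec, stream)
}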
@@ -416,7 +400,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_fill_buffer(
size_t numNode = command_buffer->last_node ? 1 : 0;
IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->symbols,
cuGraphAddMemsetNode(&command_buffer->last_node, command_buffer->graph,
cuGraphAddMemsetNode(&command_buffer->last_node, command_buffer->cu_graph,
dep, numNode, &params, command_buffer->cu_context),
"cuGraphAddMemsetNode");

@@ -471,7 +455,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_update_buffer(

IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->symbols,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->cu_graph,
dep, numNode, &params, command_buffer->cu_context),
"cuGraphAddMemcpyNode");

@@ -522,7 +506,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_copy_buffer(

IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->symbols,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->cu_graph,
dep, numNode, &params, command_buffer->cu_context),
"cuGraphAddMemcpyNode");

@@ -550,77 +534,44 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_push_constants(
iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer);
iree_host_size_t constant_base_index = offset / sizeof(int32_t);
for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) {
command_buffer->push_constant[i + constant_base_index] =
command_buffer->push_constants[i + constant_base_index] =
((uint32_t*)values)[i];
}
return iree_ok_status();
}

typedef struct {
// The original index into the iree_hal_descriptor_set_binding_t array.
uint32_t index;
// The descriptor binding number.
uint32_t binding;
} iree_hal_cuda2_binding_mapping_t;

// Compares two iree_hal_cuda2_binding_mapping_t according to the descriptor
// binding number.
static int compare_binding_index(const void* a, const void* b) {
const iree_hal_cuda2_binding_mapping_t buffer_a =
*(const iree_hal_cuda2_binding_mapping_t*)a;
const iree_hal_cuda2_binding_mapping_t buffer_b =
*(const iree_hal_cuda2_binding_mapping_t*)b;
return buffer_a.binding < buffer_b.binding ? -1 : 1;
}

static iree_status_t iree_hal_cuda2_graph_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
iree_host_size_t binding_count,
const iree_hal_descriptor_set_binding_t* bindings) {
IREE_ASSERT_LT(binding_count, IREE_HAL_CUDA_MAX_BINDING_COUNT,
"binding count larger than the max expected");
if (binding_count > IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
return iree_make_status(
IREE_STATUS_RESOURCE_EXHAUSTED,
"exceeded available binding slots for push "
"descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. maximal %d",
set, binding_count, IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT);
}

iree_hal_cuda2_graph_command_buffer_t* command_buffer =
iree_hal_cuda2_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);

iree_host_size_t base_binding =
iree_hal_cuda2_pipeline_layout_base_binding_index(pipeline_layout, set);

// Convention with the compiler side. We map descriptor bindings to kernel
// argument. We compact the descriptor binding number ranges to get a dense
// set of kernel arguments and keep them ordered based on the descriptor
// binding index.
iree_hal_cuda2_binding_mapping_t
sorted_bindings[IREE_HAL_CUDA_MAX_BINDING_COUNT];
CUdeviceptr* current_bindings = command_buffer->descriptor_sets[set].bindings;
for (iree_host_size_t i = 0; i < binding_count; i++) {
sorted_bindings[i].index = i;
sorted_bindings[i].binding = bindings[i].binding;
}
// Sort the binding based on the binding index and map the (base offset +
// array index) to the kernel argument index.
// TODO: remove this sort - it's thankfully small (1-8 on average) but we
// should be able to avoid it like we do on the CPU side with a bitmap.
qsort(sorted_bindings, binding_count,
sizeof(iree_hal_cuda2_binding_mapping_t), compare_binding_index);

for (iree_host_size_t i = 0; i < binding_count; i++) {
const iree_hal_descriptor_set_binding_t* binding =
&bindings[sorted_bindings[i].index];
const iree_hal_descriptor_set_binding_t* binding = &bindings[i];
CUdeviceptr device_ptr = 0;
if (binding->buffer) {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&binding->buffer));

CUdeviceptr device_buffer = iree_hal_cuda2_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding->buffer));
iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
device_ptr = device_buffer + offset + binding->offset;
};
*((CUdeviceptr*)command_buffer->current_descriptor[base_binding + i]) =
device_ptr;
if (binding->buffer) {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&binding->buffer));
}
current_bindings[binding->binding] = device_ptr;
}

IREE_TRACE_ZONE_END(z0);
@@ -641,35 +592,86 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(

// Lookup kernel parameters used for side-channeling additional launch
// information from the compiler.
iree_hal_cuda2_kernel_params_t kernel_params;
iree_hal_cuda2_kernel_info_t kernel_info;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_native_executable_entry_point_kernel_params(
executable, entry_point, &kernel_params));
z0, iree_hal_cuda2_native_executable_entry_point_kernel_info(
executable, entry_point, &kernel_info));

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&executable));

// Patch the push constants in the kernel arguments.
iree_host_size_t num_constants =
iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_params.layout);
// The total number of descriptors across all descriptor sets.
iree_host_size_t descriptor_count =
iree_hal_cuda2_pipeline_layout_total_binding_count(kernel_info.layout);
// The total number of push constants.
iree_host_size_t push_constant_count =
iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_info.layout);
// We append push constants to the end of descriptors to form a linear chain
// of kernel arguments.
iree_host_size_t kernel_params_count = descriptor_count + push_constant_count;
iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);

// Per CUDA API requirements, we need two levels of indirection for passing
// kernel arguments in.
// "If the kernel has N parameters, then kernelParams needs to be an array
// of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1],
// points to the region of memory from which the actual parameter will be
// copied."
//
// (From the cuGraphAddKernelNode API doc in
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b)
//
// It means each kernel_params[i] is itself a pointer to the corresponding
// element at the *second* inline allocation at the end of the current
// segment.
iree_host_size_t total_size = kernel_params_length * 2;
uint8_t* storage_base = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, total_size,
(void**)&storage_base));
void** params_ptr = (void**)storage_base;

// Set up kernel arguments to point to the payload slots.
CUdeviceptr* payload_ptr =
(CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length);
for (size_t i = 0; i < kernel_params_count; i++) {
params_ptr[i] = &payload_ptr[i];
}

// Copy descriptors from all sets to the end of the current segment for later
// access.
iree_host_size_t set_count =
iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout);
for (iree_host_size_t i = 0; i < set_count; ++i) {
iree_host_size_t binding_count =
iree_hal_cuda2_descriptor_set_layout_binding_count(
iree_hal_cuda2_pipeline_layout_descriptor_set_layout(
kernel_info.layout, i));
iree_host_size_t index = iree_hal_cuda2_pipeline_layout_base_binding_index(
kernel_info.layout, i);
memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings,
binding_count * sizeof(CUdeviceptr));
}

// Append the push constants to the kernel arguments.
iree_host_size_t base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_params.layout);
for (iree_host_size_t i = 0; i < num_constants; i++) {
*((uint32_t*)command_buffer->current_descriptor[base_index + i]) =
command_buffer->push_constant[i];
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout);
for (iree_host_size_t i = 0; i < push_constant_count; i++) {
*((uint32_t*)params_ptr[base_index + i]) =
command_buffer->push_constants[i];
}

CUDA_KERNEL_NODE_PARAMS params = {
.func = kernel_params.function,
.blockDimX = kernel_params.block_size[0],
.blockDimY = kernel_params.block_size[1],
.blockDimZ = kernel_params.block_size[2],
.func = kernel_info.function,
.blockDimX = kernel_info.block_size[0],
.blockDimY = kernel_info.block_size[1],
.blockDimZ = kernel_info.block_size[2],
.gridDimX = workgroup_x,
.gridDimY = workgroup_y,
.gridDimZ = workgroup_z,
.kernelParams = command_buffer->current_descriptor,
.sharedMemBytes = kernel_params.shared_memory_size,
.kernelParams = params_ptr,
.sharedMemBytes = kernel_info.shared_memory_size,
};

// Serialize all the nodes for now.
@@ -678,7 +680,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(

IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->symbols,
cuGraphAddKernelNode(&command_buffer->last_node, command_buffer->graph,
cuGraphAddKernelNode(&command_buffer->last_node, command_buffer->cu_graph,
dep, numNodes, &params),
"cuGraphAddKernelNode");
