[cuda] Remove if ok status check nesting when possible (#14811)
We can immediately return if allocation fails.

Also fixed a symbol issue for compilation.
antiagainst authored Aug 25, 2023
1 parent b76b6df commit 10f9e61
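The refactor is identical in all three files: instead of keeping the allocation status in a local and nesting every initialization step under if (iree_status_is_ok(status)), the new code returns as soon as iree_allocator_malloc fails and runs the happy path with no extra nesting. The standalone sketch below shows that shape; the status_t type, RETURN_IF_ERROR macro, and mock_malloc/create_* helpers are stand-ins invented for illustration and are not IREE APIs.

/* Standalone mock of the refactor; iree_status_t and the IREE helpers are
 * stubbed with invented names so only the control-flow shape matches the
 * real change. */
#include <stdio.h>
#include <stdlib.h>

typedef int status_t; /* stand-in for iree_status_t */
enum { OK = 0, RESOURCE_EXHAUSTED = 1 };

static status_t mock_malloc(size_t size, void** out_ptr) {
  *out_ptr = malloc(size);
  return *out_ptr ? OK : RESOURCE_EXHAUSTED;
}

/* Before: every initialization step nested under the status check, and the
 * saved status carried all the way to the final return. */
static status_t create_nested(int** out_value) {
  int* value = NULL;
  status_t status = mock_malloc(sizeof(*value), (void**)&value);
  if (status == OK) {
    *value = 42;
    *out_value = value;
  }
  return status;
}

/* After: return immediately when the allocation fails; the happy path reads
 * straight through and ends with an unconditional OK. */
#define RETURN_IF_ERROR(expr)          \
  do {                                 \
    status_t status_ = (expr);         \
    if (status_ != OK) return status_; \
  } while (0)

static status_t create_early_return(int** out_value) {
  int* value = NULL;
  RETURN_IF_ERROR(mock_malloc(sizeof(*value), (void**)&value));
  *value = 42;
  *out_value = value;
  return OK;
}

int main(void) {
  int* value = NULL;
  if (create_early_return(&value) == OK) {
    printf("created: %d\n", *value);
    free(value);
  }
  (void)create_nested; /* kept only for side-by-side comparison */
  return 0;
}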
Showing 3 changed files with 66 additions and 70 deletions.
38 changes: 19 additions & 19 deletions experimental/cuda2/cuda_allocator.c
@@ -88,8 +88,8 @@ iree_status_t iree_hal_cuda2_allocator_create(
   int supports_read_only_host_register = 0;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0,
-      CU_RESULT_TO_STATUS(
-          context->syms,
+      IREE_CURESULT_TO_STATUS(
+          cuda_symbols,
           cuDeviceGetAttribute(
               &supports_read_only_host_register,
               CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED, device),
@@ -99,25 +99,25 @@ iree_status_t iree_hal_cuda2_allocator_create(
           : "no READ_ONLY_HOST_REGISTER_SUPPORTED");
 
   iree_hal_cuda2_allocator_t* allocator = NULL;
-  iree_status_t status = iree_allocator_malloc(
-      host_allocator, sizeof(*allocator), (void**)&allocator);
-  if (iree_status_is_ok(status)) {
-    iree_hal_resource_initialize(&iree_hal_cuda2_allocator_vtable,
-                                 &allocator->resource);
-    allocator->device = device;
-    allocator->stream = stream;
-    allocator->pools = pools;
-    allocator->symbols = cuda_symbols;
-    allocator->host_allocator = host_allocator;
-    allocator->supports_concurrent_managed_access =
-        supports_concurrent_managed_access != 0;
-    allocator->supports_read_only_host_register =
-        supports_read_only_host_register != 0;
-    *out_allocator = (iree_hal_allocator_t*)allocator;
-  }
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*allocator),
+                                (void**)&allocator));
+
+  iree_hal_resource_initialize(&iree_hal_cuda2_allocator_vtable,
+                               &allocator->resource);
+  allocator->device = device;
+  allocator->stream = stream;
+  allocator->pools = pools;
+  allocator->symbols = cuda_symbols;
+  allocator->host_allocator = host_allocator;
+  allocator->supports_concurrent_managed_access =
+      supports_concurrent_managed_access != 0;
+  allocator->supports_read_only_host_register =
+      supports_read_only_host_register != 0;
+  *out_allocator = (iree_hal_allocator_t*)allocator;
 
   IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_ok_status();
 }
 
 static void iree_hal_cuda2_allocator_destroy(
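Note that the pattern above leans on IREE_RETURN_AND_END_ZONE_IF_ERROR closing the z0 tracing zone before propagating the failure status, so the early return cannot leave the zone opened at the top of the function unbalanced. The sketch below spells out that assumed behavior, inferred from the call sites in this diff; it is not the macro's actual definition in IREE.

// Assumed behavior only (hedged sketch, not IREE's definition): evaluate the
// expression, and on a non-OK status end the given tracing zone before
// returning the error to the caller. The success path falls through, so the
// function's final IREE_TRACE_ZONE_END(z0) still runs exactly once.
#define IREE_RETURN_AND_END_ZONE_IF_ERROR_SKETCH(zone_id, ...) \
  do {                                                         \
    iree_status_t status_ = (__VA_ARGS__);                     \
    if (!iree_status_is_ok(status_)) {                         \
      IREE_TRACE_ZONE_END(zone_id);                            \
      return status_;                                          \
    }                                                          \
  } while (0)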
33 changes: 16 additions & 17 deletions experimental/cuda2/nccl_channel.c
@@ -121,25 +121,24 @@ iree_status_t iree_hal_cuda2_nccl_channel_create(
                                      "ncclCommInitRankConfig");
 
   iree_hal_cuda2_nccl_channel_t* channel = NULL;
-  iree_status_t status =
-      iree_allocator_malloc(host_allocator, sizeof(*channel), (void**)&channel);
-
-  if (iree_status_is_ok(status)) {
-    iree_hal_resource_initialize(&iree_hal_cuda2_nccl_channel_vtable,
-                                 &channel->resource);
-    channel->cuda_symbols = cuda_symbols;
-    channel->nccl_symbols = nccl_symbols;
-    channel->host_allocator = host_allocator;
-    channel->parent_channel = NULL;
-    channel->rank = rank;
-    channel->count = count;
-    channel->comm = comm;
-    IREE_TRACE(channel->id_hash = id_hash);
-    *out_channel = (iree_hal_channel_t*)channel;
-  }
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*channel),
+                                (void**)&channel));
+
+  iree_hal_resource_initialize(&iree_hal_cuda2_nccl_channel_vtable,
+                               &channel->resource);
+  channel->cuda_symbols = cuda_symbols;
+  channel->nccl_symbols = nccl_symbols;
+  channel->host_allocator = host_allocator;
+  channel->parent_channel = NULL;
+  channel->rank = rank;
+  channel->count = count;
+  channel->comm = comm;
+  IREE_TRACE(channel->id_hash = id_hash);
+  *out_channel = (iree_hal_channel_t*)channel;
 
   IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_ok_status();
 }
 
 static void iree_hal_cuda2_nccl_channel_destroy(
65 changes: 31 additions & 34 deletions experimental/cuda2/pipeline_layout.c
@@ -51,21 +51,19 @@ iree_status_t iree_hal_cuda2_descriptor_set_layout_create(
   *out_descriptor_set_layout = NULL;
 
   iree_hal_cuda2_descriptor_set_layout_t* descriptor_set_layout = NULL;
-  iree_status_t status =
-      iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout),
-                            (void**)&descriptor_set_layout);
-
-  if (iree_status_is_ok(status)) {
-    iree_hal_resource_initialize(&iree_hal_cuda2_descriptor_set_layout_vtable,
-                                 &descriptor_set_layout->resource);
-    descriptor_set_layout->host_allocator = host_allocator;
-    descriptor_set_layout->binding_count = binding_count;
-    *out_descriptor_set_layout =
-        (iree_hal_descriptor_set_layout_t*)descriptor_set_layout;
-  }
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout),
+                                (void**)&descriptor_set_layout));
+
+  iree_hal_resource_initialize(&iree_hal_cuda2_descriptor_set_layout_vtable,
+                               &descriptor_set_layout->resource);
+  descriptor_set_layout->host_allocator = host_allocator;
+  descriptor_set_layout->binding_count = binding_count;
+  *out_descriptor_set_layout =
+      (iree_hal_descriptor_set_layout_t*)descriptor_set_layout;
 
   IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_ok_status();
 }
 
 iree_host_size_t iree_hal_cuda2_descriptor_set_layout_binding_count(
@@ -155,30 +153,29 @@ iree_status_t iree_hal_cuda2_pipeline_layout_create(
   iree_host_size_t total_size =
       sizeof(*pipeline_layout) +
       set_layout_count * sizeof(*pipeline_layout->set_layouts);
-  iree_status_t status = iree_allocator_malloc(host_allocator, total_size,
-                                               (void**)&pipeline_layout);
-
-  if (iree_status_is_ok(status)) {
-    iree_hal_resource_initialize(&iree_hal_cuda2_pipeline_layout_vtable,
-                                 &pipeline_layout->resource);
-    pipeline_layout->host_allocator = host_allocator;
-    pipeline_layout->set_layout_count = set_layout_count;
-    iree_host_size_t base_index = 0;
-    for (iree_host_size_t i = 0; i < set_layout_count; ++i) {
-      pipeline_layout->set_layouts[i].set_layout = set_layouts[i];
-      // Copy and retain all descriptor sets so we don't lose them.
-      iree_hal_descriptor_set_layout_retain(set_layouts[i]);
-      pipeline_layout->set_layouts[i].base_index = base_index;
-      base_index +=
-          iree_hal_cuda2_descriptor_set_layout_binding_count(set_layouts[i]);
-    }
-    pipeline_layout->push_constant_base_index = base_index;
-    pipeline_layout->push_constant_count = push_constant_count;
-    *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout;
-  }
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, total_size,
+                                (void**)&pipeline_layout));
+
+  iree_hal_resource_initialize(&iree_hal_cuda2_pipeline_layout_vtable,
+                               &pipeline_layout->resource);
+  pipeline_layout->host_allocator = host_allocator;
+  pipeline_layout->set_layout_count = set_layout_count;
+  iree_host_size_t base_index = 0;
+  for (iree_host_size_t i = 0; i < set_layout_count; ++i) {
+    pipeline_layout->set_layouts[i].set_layout = set_layouts[i];
+    // Copy and retain all descriptor sets so we don't lose them.
+    iree_hal_descriptor_set_layout_retain(set_layouts[i]);
+    pipeline_layout->set_layouts[i].base_index = base_index;
+    base_index +=
+        iree_hal_cuda2_descriptor_set_layout_binding_count(set_layouts[i]);
+  }
+  pipeline_layout->push_constant_base_index = base_index;
+  pipeline_layout->push_constant_count = push_constant_count;
+  *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout;
 
   IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_ok_status();
 }
 
 static void iree_hal_cuda2_pipeline_layout_destroy(
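The pipeline layout path keeps the base_index packing intact: each set layout's bindings occupy a contiguous range in a single flat index space, and the push constants start immediately after the last binding. The small standalone program below illustrates only that arithmetic; the binding counts are hypothetical, while the real code obtains them from iree_hal_cuda2_descriptor_set_layout_binding_count.

#include <stdio.h>

int main(void) {
  // Hypothetical per-set binding counts; the real values come from the
  // descriptor set layouts passed to iree_hal_cuda2_pipeline_layout_create.
  const int binding_counts[] = {2, 3, 1};
  const int set_layout_count = 3;
  int base_index = 0;
  for (int i = 0; i < set_layout_count; ++i) {
    // Set i's bindings map to flat indices [base_index, base_index + count).
    printf("set %d: base_index = %d, binding_count = %d\n", i, base_index,
           binding_counts[i]);
    base_index += binding_counts[i];
  }
  // Push constants are packed right after the last binding, mirroring
  // pipeline_layout->push_constant_base_index = base_index above.
  printf("push_constant_base_index = %d\n", base_index);
  return 0;
}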
