Skip to content

Commit

Permalink
[vulkan] Request 8-/16-bit integer/floating-point features (#14848)
Browse files Browse the repository at this point in the history
When the Vulkan implementation supports 8-/16-bit integer or
floating-point features, we can just request them. This helps to address
validation errors regarding them.
  • Loading branch information
antiagainst authored Aug 28, 2023
1 parent 2f9a42c commit aabb73a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
6 changes: 6 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ iree_hal_vulkan_populate_enabled_device_extensions(
} else if (strcmp(extension_name,
VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) == 0) {
extensions.buffer_device_address = true;
} else if (strcmp(extension_name, VK_KHR_8BIT_STORAGE_EXTENSION_NAME) ==
0) {
extensions.shader_8bit_storage = true;
} else if (strcmp(extension_name,
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) == 0) {
extensions.shader_float16_int8 = true;
}
}
return extensions;
Expand Down
4 changes: 4 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ typedef struct iree_hal_vulkan_device_extensions_t {
bool external_memory_host : 1;
// VK_KHR_buffer_device_address is enabled.
bool buffer_device_address : 1;
// VK_KHR_8bit_storage is enabled.
bool shader_8bit_storage : 1;
// VK_KHR_shader_float16_int8 is enabled.
bool shader_float16_int8 : 1;
} iree_hal_vulkan_device_extensions_t;

// Returns a bitfield with all of the provided extension names.
Expand Down
60 changes: 60 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,19 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_query_extensibility_set(
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);

// VK_KHR_8bit_storage:
// This extension allows use of 8-bit types in uniform and storage buffers,
// and push constant blocks. It's promoted to core since Vulkan 1.2.
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_KHR_8BIT_STORAGE_EXTENSION_NAME);

// VK_KHR_shader_float16_int8:
// This extension allows use of 16-bit floating-point types and 8-bit integer
// types in shaders for arithmetic operations. It's promoted to core since
// Vulkan 1.2.
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);

//===--------------------------------------------------------------------===//
// Optional debugging features
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -919,13 +932,45 @@ iree_status_t iree_hal_vulkan_device_create(
VkPhysicalDeviceFeatures2 available_features2;
memset(&available_features2, 0, sizeof(available_features2));
available_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;

// + Buffer device address features.
VkPhysicalDeviceBufferDeviceAddressFeatures
available_buffer_device_address_features;
memset(&available_buffer_device_address_features, 0,
sizeof(available_buffer_device_address_features));
available_buffer_device_address_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES;
available_buffer_device_address_features.pNext = available_features2.pNext;
available_features2.pNext = &available_buffer_device_address_features;

// + Shader 16 bit storage features.
VkPhysicalDevice16BitStorageFeatures available_16bit_storage_features;
memset(&available_16bit_storage_features, 0,
sizeof(available_16bit_storage_features));
available_16bit_storage_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;
available_16bit_storage_features.pNext = available_features2.pNext;
available_features2.pNext = &available_16bit_storage_features;

// + Shader 8 bit storage features.
VkPhysicalDevice8BitStorageFeatures available_8bit_storage_features;
memset(&available_8bit_storage_features, 0,
sizeof(available_8bit_storage_features));
available_8bit_storage_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES;
available_8bit_storage_features.pNext = available_features2.pNext;
available_features2.pNext = &available_8bit_storage_features;

// + Shader float16 and int8 features.
VkPhysicalDeviceShaderFloat16Int8Features
available_shader_float16_int8_features;
memset(&available_shader_float16_int8_features, 0,
sizeof(available_shader_float16_int8_features));
available_shader_float16_int8_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES;
available_shader_float16_int8_features.pNext = available_features2.pNext;
available_features2.pNext = &available_shader_float16_int8_features;

instance_syms->vkGetPhysicalDeviceFeatures2(physical_device,
&available_features2);
const VkPhysicalDeviceFeatures* available_features =
Expand All @@ -950,6 +995,9 @@ iree_status_t iree_hal_vulkan_device_create(
if (available_features->shaderInt64) {
enabled_features2.features.shaderInt64 = VK_TRUE;
}
if (available_features->shaderInt16) {
enabled_features2.features.shaderInt16 = VK_TRUE;
}

iree_hal_vulkan_features_t enabled_features = 0;

Expand Down Expand Up @@ -1030,6 +1078,18 @@ iree_status_t iree_hal_vulkan_device_create(
subgroup_control_features.subgroupSizeControl = VK_TRUE;
}

// Enable all available 16- or 8-bit integer/floating-point features.
available_16bit_storage_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_16bit_storage_features;
if (enabled_device_extensions.shader_8bit_storage) {
available_8bit_storage_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_8bit_storage_features;
}
if (enabled_device_extensions.shader_float16_int8) {
available_shader_float16_int8_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_shader_float16_int8_features;
}

auto logical_device = new VkDeviceHandle(
instance_syms, physical_device, enabled_features,
enabled_device_extensions,
Expand Down

0 comments on commit aabb73a

Please sign in to comment.