Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vulkan] Request 8-/16-bit integer/floating-point features #14848

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ iree_hal_vulkan_populate_enabled_device_extensions(
} else if (strcmp(extension_name,
VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) == 0) {
extensions.buffer_device_address = true;
} else if (strcmp(extension_name, VK_KHR_8BIT_STORAGE_EXTENSION_NAME) ==
0) {
extensions.shader_8bit_storage = true;
} else if (strcmp(extension_name,
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) == 0) {
extensions.shader_float16_int8 = true;
}
}
return extensions;
Expand Down
4 changes: 4 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ typedef struct iree_hal_vulkan_device_extensions_t {
bool external_memory_host : 1;
// VK_KHR_buffer_device_address is enabled.
bool buffer_device_address : 1;
// VK_KHR_8bit_storage is enabled.
bool shader_8bit_storage : 1;
// VK_KHR_shader_float16_int8 is enabled.
bool shader_float16_int8 : 1;
} iree_hal_vulkan_device_extensions_t;

// Returns a bitfield with all of the provided extension names.
Expand Down
60 changes: 60 additions & 0 deletions runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,19 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_query_extensibility_set(
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);

// VK_KHR_8bit_storage:
// This extension allows use of 8-bit types in uniform and storage buffers,
// and push constant blocks. It's promoted to core since Vulkan 1.2.
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_KHR_8BIT_STORAGE_EXTENSION_NAME);

// VK_KHR_shader_float16_int8:
// This extension allows use of 16-bit floating-point types and 8-bit integer
// types in shaders for arithmetic operations. It's promoted to core since
// Vulkan 1.2.
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);

//===--------------------------------------------------------------------===//
// Optional debugging features
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -919,13 +932,45 @@ iree_status_t iree_hal_vulkan_device_create(
VkPhysicalDeviceFeatures2 available_features2;
memset(&available_features2, 0, sizeof(available_features2));
available_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;

// + Buffer device address features.
VkPhysicalDeviceBufferDeviceAddressFeatures
available_buffer_device_address_features;
memset(&available_buffer_device_address_features, 0,
sizeof(available_buffer_device_address_features));
available_buffer_device_address_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES;
available_buffer_device_address_features.pNext = available_features2.pNext;
available_features2.pNext = &available_buffer_device_address_features;

// + Shader 16 bit storage features.
VkPhysicalDevice16BitStorageFeatures available_16bit_storage_features;
memset(&available_16bit_storage_features, 0,
sizeof(available_16bit_storage_features));
available_16bit_storage_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;
available_16bit_storage_features.pNext = available_features2.pNext;
available_features2.pNext = &available_16bit_storage_features;

// + Shader 8 bit storage features.
VkPhysicalDevice8BitStorageFeatures available_8bit_storage_features;
memset(&available_8bit_storage_features, 0,
sizeof(available_8bit_storage_features));
available_8bit_storage_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES;
available_8bit_storage_features.pNext = available_features2.pNext;
available_features2.pNext = &available_8bit_storage_features;

// + Shader float16 and int8 features.
VkPhysicalDeviceShaderFloat16Int8Features
available_shader_float16_int8_features;
memset(&available_shader_float16_int8_features, 0,
sizeof(available_shader_float16_int8_features));
available_shader_float16_int8_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES;
available_shader_float16_int8_features.pNext = available_features2.pNext;
available_features2.pNext = &available_shader_float16_int8_features;

instance_syms->vkGetPhysicalDeviceFeatures2(physical_device,
&available_features2);
const VkPhysicalDeviceFeatures* available_features =
Expand All @@ -950,6 +995,9 @@ iree_status_t iree_hal_vulkan_device_create(
if (available_features->shaderInt64) {
enabled_features2.features.shaderInt64 = VK_TRUE;
}
if (available_features->shaderInt16) {
enabled_features2.features.shaderInt16 = VK_TRUE;
}

iree_hal_vulkan_features_t enabled_features = 0;

Expand Down Expand Up @@ -1030,6 +1078,18 @@ iree_status_t iree_hal_vulkan_device_create(
subgroup_control_features.subgroupSizeControl = VK_TRUE;
}

// Enable all available 16- or 8-bit integer/floating-point features.
available_16bit_storage_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_16bit_storage_features;
if (enabled_device_extensions.shader_8bit_storage) {
available_8bit_storage_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_8bit_storage_features;
}
if (enabled_device_extensions.shader_float16_int8) {
available_shader_float16_int8_features.pNext = enabled_features2.pNext;
enabled_features2.pNext = &available_shader_float16_int8_features;
}

auto logical_device = new VkDeviceHandle(
instance_syms, physical_device, enabled_features,
enabled_device_extensions,
Expand Down
Loading