diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc index bf464c0d75f0..e23e91bc67b0 100644 --- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc +++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc @@ -216,6 +216,12 @@ iree_hal_vulkan_populate_enabled_device_extensions( } else if (strcmp(extension_name, VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) == 0) { extensions.buffer_device_address = true; + } else if (strcmp(extension_name, VK_KHR_8BIT_STORAGE_EXTENSION_NAME) == + 0) { + extensions.shader_8bit_storage = true; + } else if (strcmp(extension_name, + VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) == 0) { + extensions.shader_float16_int8 = true; } } return extensions; diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h index b404f3229c8f..789dd71a1df1 100644 --- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h +++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h @@ -86,6 +86,10 @@ typedef struct iree_hal_vulkan_device_extensions_t { bool external_memory_host : 1; // VK_KHR_buffer_device_address is enabled. bool buffer_device_address : 1; + // VK_KHR_8bit_storage is enabled. + bool shader_8bit_storage : 1; + // VK_KHR_shader_float16_int8 is enabled. + bool shader_float16_int8 : 1; } iree_hal_vulkan_device_extensions_t; // Returns a bitfield with all of the provided extension names. 
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index d93b2599efb0..2a596be6dc6b 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -242,6 +242,19 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_query_extensibility_set( ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME); + // VK_KHR_8bit_storage: + // This extension allows use of 8-bit types in uniform and storage buffers, + // and push constant blocks. It was promoted to core in Vulkan 1.2. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_8BIT_STORAGE_EXTENSION_NAME); + + // VK_KHR_shader_float16_int8: + // This extension allows use of 16-bit floating-point types and 8-bit integer + // types in shaders for arithmetic operations. It was promoted to core in + // Vulkan 1.2. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME); + //===--------------------------------------------------------------------===// // Optional debugging features //===--------------------------------------------------------------------===// @@ -919,13 +932,45 @@ iree_status_t iree_hal_vulkan_device_create( VkPhysicalDeviceFeatures2 available_features2; memset(&available_features2, 0, sizeof(available_features2)); available_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + // + Buffer device address features.
VkPhysicalDeviceBufferDeviceAddressFeatures available_buffer_device_address_features; memset(&available_buffer_device_address_features, 0, sizeof(available_buffer_device_address_features)); available_buffer_device_address_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES; + available_buffer_device_address_features.pNext = available_features2.pNext; available_features2.pNext = &available_buffer_device_address_features; + + // + Shader 16 bit storage features. + VkPhysicalDevice16BitStorageFeatures available_16bit_storage_features; + memset(&available_16bit_storage_features, 0, + sizeof(available_16bit_storage_features)); + available_16bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES; + available_16bit_storage_features.pNext = available_features2.pNext; + available_features2.pNext = &available_16bit_storage_features; + + // + Shader 8 bit storage features. + VkPhysicalDevice8BitStorageFeatures available_8bit_storage_features; + memset(&available_8bit_storage_features, 0, + sizeof(available_8bit_storage_features)); + available_8bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES; + available_8bit_storage_features.pNext = available_features2.pNext; + available_features2.pNext = &available_8bit_storage_features; + + // + Shader float16 and int8 features. 
+ VkPhysicalDeviceShaderFloat16Int8Features + available_shader_float16_int8_features; + memset(&available_shader_float16_int8_features, 0, + sizeof(available_shader_float16_int8_features)); + available_shader_float16_int8_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES; + available_shader_float16_int8_features.pNext = available_features2.pNext; + available_features2.pNext = &available_shader_float16_int8_features; + instance_syms->vkGetPhysicalDeviceFeatures2(physical_device, &available_features2); const VkPhysicalDeviceFeatures* available_features = @@ -950,6 +995,9 @@ iree_status_t iree_hal_vulkan_device_create( if (available_features->shaderInt64) { enabled_features2.features.shaderInt64 = VK_TRUE; } + if (available_features->shaderInt16) { + enabled_features2.features.shaderInt16 = VK_TRUE; + } iree_hal_vulkan_features_t enabled_features = 0; @@ -1030,6 +1078,18 @@ iree_status_t iree_hal_vulkan_device_create( subgroup_control_features.subgroupSizeControl = VK_TRUE; } + // Enable all available 16- or 8-bit integer/floating-point features. + available_16bit_storage_features.pNext = enabled_features2.pNext; + enabled_features2.pNext = &available_16bit_storage_features; + if (enabled_device_extensions.shader_8bit_storage) { + available_8bit_storage_features.pNext = enabled_features2.pNext; + enabled_features2.pNext = &available_8bit_storage_features; + } + if (enabled_device_extensions.shader_float16_int8) { + available_shader_float16_int8_features.pNext = enabled_features2.pNext; + enabled_features2.pNext = &available_shader_float16_int8_features; + } + auto logical_device = new VkDeviceHandle( instance_syms, physical_device, enabled_features, enabled_device_extensions,