diff --git a/experimental/hip/CMakeLists.txt b/experimental/hip/CMakeLists.txt
new file mode 100644
index 000000000000..32c9dd6a2331
--- /dev/null
+++ b/experimental/hip/CMakeLists.txt
@@ -0,0 +1,91 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Set the root for package namespacing to the current directory.
+set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}")
+set(IREE_PACKAGE_ROOT_PREFIX "iree/experimental/hip")
+
+iree_add_all_subdirs()
+
+# Location of the HIP API headers. A value defined before this file is
+# processed is respected; otherwise the copy under third_party/ is used.
+if(NOT DEFINED ROCM_HEADERS_API_ROOT)
+  set(ROCM_HEADERS_API_ROOT
+    "${IREE_SOURCE_DIR}/third_party/hip-build-deps/include")
+endif()
+
+if(NOT EXISTS "${ROCM_HEADERS_API_ROOT}/hip/hip_version.h")
+  message(SEND_ERROR "Could not find HIP headers at: ${ROCM_HEADERS_API_ROOT}")
+endif()
+
+# Compile definitions selecting the AMD backend of the HIP headers.
+# `__HIP_PLATFORM_HCC__` is the legacy spelling and `__HIP_PLATFORM_AMD__` its
+# replacement; both are passed so older and newer header trees are accepted.
+set(IREE_HIP_PLATFORM_COPTS
+  "-D__HIP_PLATFORM_HCC__=1"
+  "-D__HIP_PLATFORM_AMD__=1"
+)
+
+# HAL driver implementation and its public API header.
+iree_cc_library(
+  NAME
+    hip
+  HDRS
+    "api.h"
+  SRCS
+    "api.h"
+    "hip_driver.c"
+  INCLUDES
+    "${ROCM_HEADERS_API_ROOT}"
+  DEPS
+    ::dynamic_symbols
+    iree::base
+    iree::base::core_headers
+    iree::hal
+  COPTS
+    ${IREE_HIP_PLATFORM_COPTS}
+  PUBLIC
+)
+
+# Loader for the HIP runtime library plus the status conversion helpers.
+iree_cc_library(
+  NAME
+    dynamic_symbols
+  HDRS
+    "dynamic_symbols.h"
+    "status_util.h"
+  TEXTUAL_HDRS
+    "dynamic_symbol_tables.h"
+  SRCS
+    "hip_headers.h"
+    "dynamic_symbols.c"
+    "status_util.c"
+  INCLUDES
+    "${ROCM_HEADERS_API_ROOT}"
+  COPTS
+    ${IREE_HIP_PLATFORM_COPTS}
+  DEPS
+    iree::base
+    iree::base::core_headers
+    iree::base::internal::dynamic_library
+  PUBLIC
+)
+
+iree_cc_test(
+  NAME
+    dynamic_symbols_test
+  SRCS
+    "dynamic_symbols_test.cc"
+  DEPS
+    ::dynamic_symbols
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=hip"
+  COPTS
+    ${IREE_HIP_PLATFORM_COPTS}
+)
diff --git a/experimental/hip/api.h b/experimental/hip/api.h
new file mode 100644
index 000000000000..5dada5f4daa8
--- /dev/null
+++ b/experimental/hip/api.h
@@ -0,0 +1,47 @@
+// Copyright 2023 The IREE Authors
+//
+// 
Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_EXPERIMENTAL_HIP_API_H_
+#define IREE_EXPERIMENTAL_HIP_API_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_driver_t
+//===----------------------------------------------------------------------===//
+
+// Options controlling how the HIP HAL driver is created.
+typedef struct iree_hal_hip_driver_options_t {
+  // Index, within the list of enumerated HIP devices, of the device to use when
+  // the default device is requested.
+  int default_device_index;
+} iree_hal_hip_driver_options_t;
+
+// Initializes |out_options| with the default driver creation options.
+IREE_API_EXPORT void iree_hal_hip_driver_options_initialize(
+    iree_hal_hip_driver_options_t* out_options);
+
+// Creates a HIP HAL driver with the given |options|, from which HIP devices
+// can be enumerated and created with specific parameters.
+//
+// |out_driver| must be released by the caller (see iree_hal_driver_release).
+IREE_API_EXPORT iree_status_t iree_hal_hip_driver_create(
+    iree_string_view_t identifier,
+    const iree_hal_hip_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HIP_API_H_
diff --git a/experimental/hip/dynamic_symbol_tables.h b/experimental/hip/dynamic_symbol_tables.h
new file mode 100644
index 000000000000..552712115df0
--- /dev/null
+++ b/experimental/hip/dynamic_symbol_tables.h
@@ -0,0 +1,66 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// HIP symbols
+//===----------------------------------------------------------------------===//
+IREE_HIP_PFN_DECL(hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
+IREE_HIP_PFN_DECL(hipCtxDestroy, hipCtx_t)
+IREE_HIP_PFN_DECL(hipDeviceGet, hipDevice_t *, int)  // No direct, need to modify
+IREE_HIP_PFN_DECL(hipGetDeviceCount, int *)
+IREE_HIP_PFN_DECL(hipGetDeviceProperties, hipDeviceProp_t *, int)
+IREE_HIP_PFN_DECL(hipDeviceGetName, char *, int,
+                  hipDevice_t)  // No direct, need to modify
+IREE_HIP_PFN_STR_DECL(
+    hipGetErrorName,
+    hipError_t)  // Unlike the other functions, hipGetErrorName(hipError_t)
+                 // returns const char* instead of hipError_t, so it is
+                 // declared with a different macro.
+IREE_HIP_PFN_STR_DECL(
+    hipGetErrorString,
+    hipError_t)  // Unlike the other functions, hipGetErrorString(hipError_t)
+                 // returns const char* instead of hipError_t, so it is
+                 // declared with a different macro.
+IREE_HIP_PFN_DECL(hipInit, unsigned int)
+IREE_HIP_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int,
+                  unsigned int, unsigned int, unsigned int, unsigned int,
+                  unsigned int, hipStream_t, void **, void **)
+IREE_HIP_PFN_DECL(hipMemset, void *, int, size_t)
+IREE_HIP_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t)
+IREE_HIP_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t)
+IREE_HIP_PFN_DECL(hipMemsetD16Async, void *, short, size_t, hipStream_t)
+IREE_HIP_PFN_DECL(hipMemsetD8Async, void *, char, size_t, hipStream_t)
+IREE_HIP_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind)
+IREE_HIP_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind,
+                  hipStream_t)
+IREE_HIP_PFN_DECL(hipMalloc, void **, size_t)
+IREE_HIP_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int)
+IREE_HIP_PFN_DECL(hipFree, void *)
+IREE_HIP_PFN_DECL(hipHostFree, void *)
+IREE_HIP_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
+IREE_HIP_PFN_DECL(hipHostGetDevicePointer, void **, void *, unsigned int)
+IREE_HIP_PFN_DECL(hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
+IREE_HIP_PFN_DECL(hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int,
+                  hipJitOption *, void **)
+IREE_HIP_PFN_DECL(hipModuleLoadData, hipModule_t *, const void *)
+IREE_HIP_PFN_DECL(hipModuleUnload, hipModule_t)
+IREE_HIP_PFN_DECL(hipStreamCreateWithFlags, hipStream_t *, unsigned int)
+IREE_HIP_PFN_DECL(hipStreamDestroy, hipStream_t)
+IREE_HIP_PFN_DECL(hipStreamSynchronize, hipStream_t)
+IREE_HIP_PFN_DECL(hipStreamWaitEvent, hipStream_t, hipEvent_t, unsigned int)
+IREE_HIP_PFN_DECL(hipEventCreate, hipEvent_t *)
+IREE_HIP_PFN_DECL(hipEventDestroy, hipEvent_t)
+IREE_HIP_PFN_DECL(hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
+IREE_HIP_PFN_DECL(hipEventQuery, hipEvent_t)
+IREE_HIP_PFN_DECL(hipEventRecord, hipEvent_t, hipStream_t)
+IREE_HIP_PFN_DECL(hipEventSynchronize, hipEvent_t)
+IREE_HIP_PFN_DECL(hipDeviceGetAttribute, int *, hipDeviceAttribute_t, int)
+IREE_HIP_PFN_DECL(hipFuncSetAttribute, const void *, hipFuncAttribute, int)
+IREE_HIP_PFN_DECL(hipDeviceGetUuid, hipUUID *, hipDevice_t)
+IREE_HIP_PFN_DECL(hipDevicePrimaryCtxRetain, hipCtx_t *, hipDevice_t)
+IREE_HIP_PFN_DECL(hipCtxGetDevice, hipDevice_t *)
+IREE_HIP_PFN_DECL(hipCtxSetCurrent, hipCtx_t)
+IREE_HIP_PFN_DECL(hipDevicePrimaryCtxRelease, hipDevice_t)
diff --git a/experimental/hip/dynamic_symbols.c b/experimental/hip/dynamic_symbols.c
new file mode 100644
index 000000000000..bfb9e7bee32b
--- /dev/null
+++ b/experimental/hip/dynamic_symbols.c
@@ -0,0 +1,92 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/dynamic_symbols.h"
+
+#include <string.h>
+
+#include "experimental/hip/status_util.h"
+#include "iree/base/assert.h"
+#include "iree/base/internal/dynamic_library.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// HIP dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// Candidate file names for the HIP runtime library on the current platform.
+static const char* iree_hal_hip_dylib_names[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "amdhip64.dll",
+#else
+    "libamdhip64.so",
+#endif  // IREE_PLATFORM_WINDOWS
+};
+
+// Resolves all HIP dynamic symbols in `dynamic_symbol_tables.h` into |syms|
+// using the library handle in |syms->dylib|. Returns the first lookup failure
+// encountered, if any.
+static iree_status_t iree_hal_hip_dynamic_symbols_resolve_all(
+    iree_hal_hip_dynamic_symbols_t* syms) {
+#define IREE_HIP_PFN_DECL(hip_symbol_name, ...)                 \
+  {                                                             \
+    static const char* name = #hip_symbol_name;                 \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(    \
+        syms->dylib, name, (void**)&syms->hip_symbol_name));    \
+  }
+// String-returning symbols are looked up exactly like the others; only the
+// pointer type declared for them in dynamic_symbols.h differs.
+#define IREE_HIP_PFN_STR_DECL(hip_symbol_name, ...) \
+  IREE_HIP_PFN_DECL(hip_symbol_name, __VA_ARGS__)
+#include "experimental/hip/dynamic_symbol_tables.h"  // IWYU pragma: keep
+#undef IREE_HIP_PFN_DECL
+#undef IREE_HIP_PFN_STR_DECL
+  return iree_ok_status();
+}
+
+// Loads the HIP runtime library and resolves every symbol in the table. On
+// failure |out_syms| is deinitialized before the error is returned.
+iree_status_t iree_hal_hip_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_hip_dynamic_symbols_t* out_syms) {
+  IREE_ASSERT_ARGUMENT(out_syms);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  memset(out_syms, 0, sizeof(*out_syms));
+  iree_status_t status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(iree_hal_hip_dylib_names), iree_hal_hip_dylib_names,
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->dylib);
+  if (iree_status_is_not_found(status)) {
+    // Replace the loader's NOT_FOUND with a message naming the HIP library.
+    iree_status_ignore(status);
+    status = iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "HIP runtime library 'amdhip64.dll'/'libamdhip64.so' not available; "
+        "please ensure installed and in dynamic library search path");
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hip_dynamic_symbols_resolve_all(out_syms);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_hip_dynamic_symbols_deinitialize(out_syms);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Releases the library handle held by |syms| and zeroes the whole structure,
+// including every resolved function pointer.
+void iree_hal_hip_dynamic_symbols_deinitialize(
+    iree_hal_hip_dynamic_symbols_t* syms) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_dynamic_library_release(syms->dylib);
+  memset(syms, 0, sizeof(*syms));
+
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/experimental/hip/dynamic_symbols.h b/experimental/hip/dynamic_symbols.h
new file mode 100644
index 000000000000..8183e9463600
--- /dev/null
+++ b/experimental/hip/dynamic_symbols.h
@@ -0,0 +1,59 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HIP_DYNAMIC_SYMBOLS_H_
+#define IREE_EXPERIMENTAL_HIP_DYNAMIC_SYMBOLS_H_
+
+#include "experimental/hip/hip_headers.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/dynamic_library.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// iree_dynamic_library_t allows dynamically loading a subset of the HIP API.
+// All symbols listed in `dynamic_symbol_tables.h` are loaded and loading fails
+// if any symbol is unavailable. The function signatures are intended to match
+// the declarations reachable from `hip/hip_runtime.h`.
+
+//===----------------------------------------------------------------------===//
+// HIP dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// HIP driver API dynamic symbols.
+typedef struct iree_hal_hip_dynamic_symbols_t {
+  // The dynamic library handle.
+  iree_dynamic_library_t* dylib;
+
+  // One function pointer member per entry in `dynamic_symbol_tables.h`.
+#define IREE_HIP_PFN_DECL(hipSymbolName, ...) \
+  hipError_t (*hipSymbolName)(__VA_ARGS__);
+#define IREE_HIP_PFN_STR_DECL(hipSymbolName, ...) \
+  const char* (*hipSymbolName)(__VA_ARGS__);
+#include "experimental/hip/dynamic_symbol_tables.h"  // IWYU pragma: export
+#undef IREE_HIP_PFN_DECL
+#undef IREE_HIP_PFN_STR_DECL
+} iree_hal_hip_dynamic_symbols_t;
+
+// Initializes |out_syms| in-place with dynamically loaded HIP symbols.
+// iree_hal_hip_dynamic_symbols_deinitialize must be used to release the
+// library resources.
+iree_status_t iree_hal_hip_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_hip_dynamic_symbols_t* out_syms);
+
+// Deinitializes |syms| by unloading the backing library. All function pointers
+// will be invalidated. They _may_ still work if there are other reasons the
+// library remains loaded, so be careful.
+void iree_hal_hip_dynamic_symbols_deinitialize(
+    iree_hal_hip_dynamic_symbols_t* syms);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HIP_DYNAMIC_SYMBOLS_H_
diff --git a/experimental/hip/dynamic_symbols_test.cc b/experimental/hip/dynamic_symbols_test.cc
new file mode 100644
index 000000000000..cad2310ffefa
--- /dev/null
+++ b/experimental/hip/dynamic_symbols_test.cc
@@ -0,0 +1,52 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/dynamic_symbols.h"
+
+#include <iostream>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace hip {
+namespace {
+
+#define HIP_CHECK_ERRORS(expr)     \
+  {                                \
+    hipError_t status = (expr);    \
+    ASSERT_EQ(hipSuccess, status); \
+  }
+
+// Loads the HIP runtime through the system loader and calls a few basic entry
+// points; the test is skipped when the symbols cannot be loaded.
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  iree_hal_hip_dynamic_symbols_t symbols;
+  iree_status_t status = iree_hal_hip_dynamic_symbols_initialize(
+      iree_allocator_system(), &symbols);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_ignore(status);
+    std::cerr << "Symbols cannot be loaded, skipping test.";
+    GTEST_SKIP();
+  }
+
+  int device_count = 0;
+  HIP_CHECK_ERRORS(symbols.hipInit(0));
+  HIP_CHECK_ERRORS(symbols.hipGetDeviceCount(&device_count));
+  if (device_count > 0) {
+    hipDevice_t device;
+    HIP_CHECK_ERRORS(symbols.hipDeviceGet(&device, /*ordinal=*/0));
+  }
+
+  iree_hal_hip_dynamic_symbols_deinitialize(&symbols);
+}
+
+}  // namespace
+}  // namespace hip
+}  // namespace hal
+}  // namespace iree
diff --git a/experimental/hip/hip_driver.c b/experimental/hip/hip_driver.c
new file mode 100644
index 000000000000..9e79d7d38262
--- /dev/null
+++ b/experimental/hip/hip_driver.c
@@ -0,0 +1,468 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "experimental/hip/api.h"
+#include "experimental/hip/dynamic_symbols.h"
+#include "experimental/hip/status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/assert.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+
+// Maximum device name length supported by the HIP HAL driver.
+#define IREE_HAL_HIP_MAX_DEVICE_NAME_LENGTH 128
+
+// Length, excluding the NUL terminator, of a formatted "GPU-<uuid>" device
+// path: a 4 character prefix followed by 36 characters of UUID text.
+#define IREE_HAL_HIP_DEVICE_PATH_LENGTH (4 + 36)
+
+// Utility macros to convert between HIPDevice and iree_hal_device_id_t.
+#define IREE_HIPDEVICE_TO_DEVICE_ID(device) (iree_hal_device_id_t)((device) + 1)
+#define IREE_DEVICE_ID_TO_HIPDEVICE(device_id) (hipDevice_t)((device_id)-1)
+
+typedef struct iree_hal_hip_driver_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  iree_allocator_t host_allocator;
+
+  // Identifier used for registering the driver in the IREE driver registry.
+  iree_string_view_t identifier;
+  // HIP driver API dynamic symbols to interact with the HIP system.
+  iree_hal_hip_dynamic_symbols_t HIP_symbols;
+
+  // The index of the default HIP device to use if multiple ones are available.
+  int default_device_index;
+} iree_hal_hip_driver_t;
+
+static const iree_hal_driver_vtable_t iree_hal_hip_driver_vtable;
+
+// Casts |base_value| to the HIP driver type after asserting its vtable.
+static iree_hal_hip_driver_t* iree_hal_hip_driver_cast(
+    iree_hal_driver_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_driver_vtable);
+  return (iree_hal_hip_driver_t*)base_value;
+}
+
+IREE_API_EXPORT void iree_hal_hip_driver_options_initialize(
+    iree_hal_hip_driver_options_t* out_options) {
+  IREE_ASSERT_ARGUMENT(out_options);
+  memset(out_options, 0, sizeof(*out_options));
+  out_options->default_device_index = 0;
+}
+
+// Allocates the driver with |identifier| stored inline after the struct and
+// loads the HIP symbols. On failure the partially built driver is released.
+static iree_status_t iree_hal_hip_driver_create_internal(
+    iree_string_view_t identifier,
+    const iree_hal_hip_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  iree_hal_hip_driver_t* driver = NULL;
+  iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+
+  iree_hal_resource_initialize(&iree_hal_hip_driver_vtable,
+                               &driver->resource);
+  driver->host_allocator = host_allocator;
+  iree_string_view_append_to_buffer(
+      identifier, &driver->identifier,
+      (char*)driver + iree_sizeof_struct(*driver));
+  driver->default_device_index = options->default_device_index;
+
+  iree_status_t status = iree_hal_hip_dynamic_symbols_initialize(
+      host_allocator, &driver->HIP_symbols);
+
+  if (iree_status_is_ok(status)) {
+    *out_driver = (iree_hal_driver_t*)driver;
+  } else {
+    iree_hal_driver_release((iree_hal_driver_t*)driver);
+  }
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_hip_driver_create(
+    iree_string_view_t identifier,
+    const iree_hal_hip_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(out_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_hal_hip_driver_create_internal(
+      identifier, options, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Unloads the HIP symbols and frees the driver storage.
+static void iree_hal_hip_driver_destroy(iree_hal_driver_t* base_driver) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+  iree_allocator_t host_allocator = driver->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hip_dynamic_symbols_deinitialize(&driver->HIP_symbols);
+  iree_allocator_free(host_allocator, driver);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Initializes the HIP system.
+static iree_status_t iree_hal_hip_init(iree_hal_hip_driver_t* driver) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status =
+      IREE_HIP_RESULT_TO_STATUS(&driver->HIP_symbols, hipInit(0), "hipInit");
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Populates device information from the given HIP physical device handle.
+// |out_device_info| must point to valid memory and additional data will be
+// appended to |buffer_ptr| and the new pointer is returned.
+// Appends IREE_HAL_HIP_DEVICE_PATH_LENGTH bytes of path followed by the device
+// name, which is shorter than IREE_HAL_HIP_MAX_DEVICE_NAME_LENGTH provided
+// hipDeviceGetName NUL-terminates within the length it is given.
+static iree_status_t iree_hal_hip_populate_device_info(
+    hipDevice_t device, iree_hal_hip_dynamic_symbols_t* syms,
+    uint8_t* buffer_ptr, uint8_t** out_buffer_ptr,
+    iree_hal_device_info_t* out_device_info) {
+  *out_buffer_ptr = buffer_ptr;
+
+  char device_name[IREE_HAL_HIP_MAX_DEVICE_NAME_LENGTH];
+
+  IREE_HIP_RETURN_IF_ERROR(
+      syms, hipDeviceGetName(device_name, sizeof(device_name), device),
+      "hipDeviceGetName");
+  memset(out_device_info, 0, sizeof(*out_device_info));
+  out_device_info->device_id = IREE_HIPDEVICE_TO_DEVICE_ID(device);
+
+  hipUUID device_uuid;
+  IREE_HIP_RETURN_IF_ERROR(syms, hipDeviceGetUuid(&device_uuid, device),
+                           "hipDeviceGetUuid");
+  char device_path_str[IREE_HAL_HIP_DEVICE_PATH_LENGTH + 1] = {0};
+  snprintf(device_path_str, sizeof(device_path_str),
+           "GPU-"
+           "%02x%02x%02x%02x-"
+           "%02x%02x-"
+           "%02x%02x-"
+           "%02x%02x-"
+           "%02x%02x%02x%02x%02x%02x",
+           (uint8_t)device_uuid.bytes[0], (uint8_t)device_uuid.bytes[1],
+           (uint8_t)device_uuid.bytes[2], (uint8_t)device_uuid.bytes[3],
+           (uint8_t)device_uuid.bytes[4], (uint8_t)device_uuid.bytes[5],
+           (uint8_t)device_uuid.bytes[6], (uint8_t)device_uuid.bytes[7],
+           (uint8_t)device_uuid.bytes[8], (uint8_t)device_uuid.bytes[9],
+           (uint8_t)device_uuid.bytes[10], (uint8_t)device_uuid.bytes[11],
+           (uint8_t)device_uuid.bytes[12], (uint8_t)device_uuid.bytes[13],
+           (uint8_t)device_uuid.bytes[14], (uint8_t)device_uuid.bytes[15]);
+  buffer_ptr += iree_string_view_append_to_buffer(
+      iree_make_string_view(device_path_str, IREE_HAL_HIP_DEVICE_PATH_LENGTH),
+      &out_device_info->path, (char*)buffer_ptr);
+
+  iree_string_view_t device_name_str =
+      iree_make_string_view(device_name, strlen(device_name));
+  buffer_ptr += iree_string_view_append_to_buffer(
+      device_name_str, &out_device_info->name, (char*)buffer_ptr);
+
+  *out_buffer_ptr = buffer_ptr;
+  return iree_ok_status();
+}
+
+// Returns true if the device meets all the required capabilities.
+static bool iree_hal_hip_is_valid_device(iree_hal_hip_driver_t* driver,
+                                         hipDevice_t device) {
+  return true;
+}
+
+// Enumerates HIP devices into a single allocation from |host_allocator|: an
+// array of device infos followed by the string storage they reference. On
+// success the allocation is handed to the caller via |out_device_infos|.
+static iree_status_t iree_hal_hip_driver_query_available_devices(
+    iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
+    iree_host_size_t* out_device_info_count,
+    iree_hal_device_info_t** out_device_infos) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(out_device_info_count);
+  IREE_ASSERT_ARGUMENT(out_device_infos);
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Ensure HIP is initialized before querying it.
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hip_init(driver));
+
+  // Query the number of available HIP devices.
+  int device_count = 0;
+  IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(z0, &driver->HIP_symbols,
+                                        hipGetDeviceCount(&device_count),
+                                        "hipGetDeviceCount");
+
+  // Allocate the return infos and populate with the devices. Each device needs
+  // string storage for both its path and its name.
+  iree_hal_device_info_t* device_infos = NULL;
+  iree_host_size_t single_info_size =
+      sizeof(iree_hal_device_info_t) +
+      (IREE_HAL_HIP_DEVICE_PATH_LENGTH + IREE_HAL_HIP_MAX_DEVICE_NAME_LENGTH) *
+          sizeof(char);
+  iree_host_size_t total_size = device_count * single_info_size;
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
+
+  int valid_device_count = 0;
+  if (iree_status_is_ok(status)) {
+    uint8_t* buffer_ptr =
+        (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
+    for (int i = 0; i < device_count; ++i) {
+      hipDevice_t device = 0;
+      status = IREE_HIP_RESULT_TO_STATUS(&driver->HIP_symbols,
+                                         hipDeviceGet(&device, i), "hipDeviceGet");
+      if (!iree_status_is_ok(status)) break;
+      if (!iree_hal_hip_is_valid_device(driver, device)) continue;
+      status = iree_hal_hip_populate_device_info(
+          device, &driver->HIP_symbols, buffer_ptr, &buffer_ptr,
+          &device_infos[valid_device_count]);
+      if (!iree_status_is_ok(status)) break;
+      valid_device_count++;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_device_info_count = valid_device_count;
+    *out_device_infos = device_infos;
+  } else {
+    iree_allocator_free(host_allocator, device_infos);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Dumps detailed information about |device_id| into |builder|. Currently a
+// no-op that always returns OK.
+static iree_status_t iree_hal_hip_driver_dump_device_info(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_string_builder_t* builder) {
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+  if (!device_id) return iree_ok_status();
+  hipDevice_t device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id);
+  // TODO: dump detailed device info.
+  (void)driver;
+  (void)device;
+  return iree_ok_status();
+}
+
+// Selects the device at |default_device_index| among the devices enumerated by
+// iree_hal_hip_driver_query_available_devices and returns it in |out_device|.
+static iree_status_t iree_hal_hip_driver_select_default_device(
+    iree_hal_driver_t* base_driver, iree_hal_hip_dynamic_symbols_t* syms,
+    int default_device_index, iree_allocator_t host_allocator,
+    hipDevice_t* out_device) {
+  iree_hal_device_info_t* device_infos = NULL;
+  iree_host_size_t device_count = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_hip_driver_query_available_devices(
+      base_driver, host_allocator, &device_count, &device_infos));
+
+  iree_status_t status = iree_ok_status();
+  if (device_count == 0) {
+    status = iree_make_status(IREE_STATUS_UNAVAILABLE,
+                              "no compatible HIP devices were found");
+  } else if (default_device_index < 0 ||
+             (iree_host_size_t)default_device_index >= device_count) {
+    status = iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "default device %d not found (of %" PRIhsz " enumerated)",
+        default_device_index, device_count);
+  } else {
+    *out_device = IREE_DEVICE_ID_TO_HIPDEVICE(
+        device_infos[default_device_index].device_id);
+  }
+  iree_allocator_free(host_allocator, device_infos);
+
+  return status;
+}
+
+// Resolves |device_id| (or the driver default) to a HIP device. Device
+// creation itself is not implemented yet, so UNIMPLEMENTED is always returned
+// once device selection succeeds.
+static iree_status_t iree_hal_hip_driver_create_device_by_id(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_host_size_t param_count, const iree_string_pair_t* params,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(params);
+  IREE_ASSERT_ARGUMENT(out_device);
+
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Ensure HIP is initialized before querying it.
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hip_init(driver));
+
+  // Use either the specified device (enumerated earlier) or whatever default
+  // one was specified when the driver was created.
+  hipDevice_t device = 0;
+  if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_hip_driver_select_default_device(
+                base_driver, &driver->HIP_symbols,
+                driver->default_device_index, host_allocator, &device));
+  } else {
+    device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id);
+  }
+  (void)device;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);
+}
+
+// Creates a device for the HIP device whose UUID equals |device_uuid|.
+static iree_status_t iree_hal_hip_driver_create_device_by_uuid(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    const hipUUID* device_uuid, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+
+  // Ensure HIP is initialized before querying it.
+  IREE_RETURN_IF_ERROR(iree_hal_hip_init(driver));
+
+  // HIP doesn't have an API to do this so we need to scan all devices to
+  // find the one with the matching UUID.
+  int device_count = 0;
+  IREE_HIP_RETURN_IF_ERROR(&driver->HIP_symbols,
+                           hipGetDeviceCount(&device_count),
+                           "hipGetDeviceCount");
+  hipDevice_t device = 0;
+  bool found_device = false;
+  for (int i = 0; i < device_count; i++) {
+    IREE_HIP_RETURN_IF_ERROR(&driver->HIP_symbols, hipDeviceGet(&device, i),
+                             "hipDeviceGet");
+    hipUUID query_uuid;
+    IREE_HIP_RETURN_IF_ERROR(&driver->HIP_symbols,
+                             hipDeviceGetUuid(&query_uuid, device),
+                             "hipDeviceGetUuid");
+    if (memcmp(&device_uuid->bytes[0], &query_uuid.bytes[0],
+               sizeof(device_uuid->bytes)) == 0) {
+      found_device = true;
+      break;
+    }
+  }
+  if (!found_device) {
+    return iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "HIP device with UUID GPU-"
+        "%02x%02x%02x%02x-"
+        "%02x%02x-"
+        "%02x%02x-"
+        "%02x%02x-"
+        "%02x%02x%02x%02x%02x%02x"
+        " not found",
+        (uint8_t)device_uuid->bytes[0], (uint8_t)device_uuid->bytes[1],
+        (uint8_t)device_uuid->bytes[2], (uint8_t)device_uuid->bytes[3],
+        (uint8_t)device_uuid->bytes[4], (uint8_t)device_uuid->bytes[5],
+        (uint8_t)device_uuid->bytes[6], (uint8_t)device_uuid->bytes[7],
+        (uint8_t)device_uuid->bytes[8], (uint8_t)device_uuid->bytes[9],
+        (uint8_t)device_uuid->bytes[10], (uint8_t)device_uuid->bytes[11],
+        (uint8_t)device_uuid->bytes[12], (uint8_t)device_uuid->bytes[13],
+        (uint8_t)device_uuid->bytes[14], (uint8_t)device_uuid->bytes[15]);
+  }
+
+  iree_status_t status = iree_hal_hip_driver_create_device_by_id(
+      base_driver, IREE_HIPDEVICE_TO_DEVICE_ID(device), param_count, params,
+      host_allocator, out_device);
+
+  return status;
+}
+
+// Creates a device for the HIP device at |device_index| in enumeration order.
+static iree_status_t iree_hal_hip_driver_create_device_by_index(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    int device_index, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+
+  // Ensure HIP is initialized before querying it.
+  IREE_RETURN_IF_ERROR(iree_hal_hip_init(driver));
+
+  // Query the number of available HIP devices.
+  int device_count = 0;
+  IREE_HIP_RETURN_IF_ERROR(&driver->HIP_symbols,
+                           hipGetDeviceCount(&device_count),
+                           "hipGetDeviceCount");
+  if (device_index < 0 || device_index >= device_count) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "device %d not found (of %d enumerated)",
+                            device_index, device_count);
+  }
+
+  hipDevice_t device = 0;
+  IREE_HIP_RETURN_IF_ERROR(&driver->HIP_symbols,
+                           hipDeviceGet(&device, device_index), "hipDeviceGet");
+
+  iree_status_t status = iree_hal_hip_driver_create_device_by_id(
+      base_driver, IREE_HIPDEVICE_TO_DEVICE_ID(device), param_count, params,
+      host_allocator, out_device);
+
+  return status;
+}
+
+// Creates a device from |device_path|, which may be empty (default device), a
+// "GPU-<uuid>" string, or an integer device index.
+static iree_status_t iree_hal_hip_driver_create_device_by_path(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    iree_string_view_t device_path, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(params);
+  IREE_ASSERT_ARGUMENT(out_device);
+
+  if (iree_string_view_is_empty(device_path)) {
+    return iree_hal_hip_driver_create_device_by_id(
+        base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params,
+        host_allocator, out_device);
+  }
+
+  if (iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"))) {
+    // UUID as returned by hipDeviceGetUuid.
+    hipUUID device_uuid;
+    if (!iree_string_view_parse_hex_bytes(device_path,
+                                          IREE_ARRAYSIZE(device_uuid.bytes),
+                                          (uint8_t*)device_uuid.bytes)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "invalid GPU UUID: '%.*s'", (int)device_path.size,
+                              device_path.data);
+    }
+    return iree_hal_hip_driver_create_device_by_uuid(
+        base_driver, driver_name, &device_uuid, param_count, params,
+        host_allocator, out_device);
+  }
+
+  // Try to parse as a device index.
+  int device_index = 0;
+  if (iree_string_view_atoi_int32(device_path, &device_index)) {
+    return iree_hal_hip_driver_create_device_by_index(
+        base_driver, driver_name, device_index, param_count, params,
+        host_allocator, out_device);
+  }
+
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path");
+}
+
+static const iree_hal_driver_vtable_t iree_hal_hip_driver_vtable = {
+    .destroy = iree_hal_hip_driver_destroy,
+    .query_available_devices = iree_hal_hip_driver_query_available_devices,
+    .dump_device_info = iree_hal_hip_driver_dump_device_info,
+    .create_device_by_id = iree_hal_hip_driver_create_device_by_id,
+    .create_device_by_path = iree_hal_hip_driver_create_device_by_path,
+};
diff --git a/experimental/hip/hip_headers.h b/experimental/hip/hip_headers.h
new file mode 100644
index 000000000000..f3e9c2707bcb
--- /dev/null
+++ b/experimental/hip/hip_headers.h
@@ -0,0 +1,16 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HIP_HIP_HEADERS_H_
+#define IREE_EXPERIMENTAL_HIP_HIP_HEADERS_H_
+
+#if defined(IREE_PTR_SIZE_32)
+#error 32-bit not supported on ROCm
+#endif  // defined(IREE_PTR_SIZE_32)
+
+#include "hip/hip_runtime.h"  // IWYU pragma: export
+
+#endif  // IREE_EXPERIMENTAL_HIP_HIP_HEADERS_H_
diff --git a/experimental/hip/registration/CMakeLists.txt b/experimental/hip/registration/CMakeLists.txt
new file mode 100644
index 000000000000..48a3069f9a46
--- /dev/null
+++ b/experimental/hip/registration/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+  NAME
+    registration
+  HDRS
+    "driver_module.h"
+  SRCS
+    "driver_module.c"
+  DEPS
+    iree::base
+    iree::base::core_headers
+    iree::experimental::hip
+    iree::hal
+  DEFINES
+    "IREE_HAVE_HAL_HIP_DRIVER_MODULE=1"
+  PUBLIC
+)
diff --git a/experimental/hip/registration/driver_module.c b/experimental/hip/registration/driver_module.c
new file mode 100644
index 000000000000..b7ebc57558b9
--- /dev/null
+++ b/experimental/hip/registration/driver_module.c
@@ -0,0 +1,112 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/registration/driver_module.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "experimental/hip/api.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+
+IREE_FLAG(int32_t, hip_default_index, 0,
+          "Specifies the index of the default HIP device to use");
+
+IREE_FLAG(bool, hip_default_index_from_mpi, true,
+          "Infers the default HIP device index from the PMI_RANK or\n"
+          "OMPI_COMM_WORLD_LOCAL_RANK environment variables when set");
+
+// Parses the environment variable |var_name| as an int32 into |out_value|.
+// Returns false when the variable is unset or empty, or when
+// iree_string_view_atoi_int32 rejects its value.
+static bool iree_try_parse_env_i32(const char* var_name, int32_t* out_value) {
+  const char* var_value = getenv(var_name);
+  if (!var_value || strlen(var_value) == 0) return false;
+  return iree_string_view_atoi_int32(iree_make_cstring_view(var_value),
+                                     out_value);
+}
+
+// Tries to infer the device index using the local MPI rank from environment
+// variables; otherwise returns |default_index|.
+//
+// This makes it easy to use N devices on a single system when running via
+// `mpiexec`.
+static int32_t iree_hal_hip_infer_device_index_from_env(
+    int32_t default_index) {
+  // TODO: try more env vars from other implementations. This covers Intel/MS
+  // and OpenMPI today.
+  int32_t result = 0;
+  if (iree_try_parse_env_i32("PMI_RANK", &result) ||
+      iree_try_parse_env_i32("OMPI_COMM_WORLD_LOCAL_RANK", &result)) {
+    return result;
+  }
+  return default_index;
+}
+
+// Reports the single "hip" driver provided by this factory.
+static iree_status_t iree_hal_hip_driver_factory_enumerate(
+    void* self, iree_host_size_t* out_driver_info_count,
+    const iree_hal_driver_info_t** out_driver_infos) {
+  IREE_ASSERT_ARGUMENT(out_driver_info_count);
+  IREE_ASSERT_ARGUMENT(out_driver_infos);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  static const iree_hal_driver_info_t driver_infos[1] = {{
+      .driver_name = IREE_SVL("hip"),
+      .full_name = IREE_SVL("HIP HAL driver (via dylib)"),
+  }};
+  *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
+  *out_driver_infos = driver_infos;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Creates the HIP driver when |driver_name| is "hip". The default device index
+// comes from the hip_default_index flag and, if enabled, the MPI environment.
+static iree_status_t iree_hal_hip_driver_factory_try_create(
+    void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
+    iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(out_driver);
+
+  if (!iree_string_view_equal(driver_name, IREE_SV("hip"))) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "no driver '%.*s' is provided by this factory",
+                            (int)driver_name.size, driver_name.data);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hip_driver_options_t driver_options;
+  iree_hal_hip_driver_options_initialize(&driver_options);
+
+  driver_options.default_device_index = FLAG_hip_default_index;
+  if (FLAG_hip_default_index_from_mpi) {
+    driver_options.default_device_index =
+        iree_hal_hip_infer_device_index_from_env(
+            driver_options.default_device_index);
+  }
+
+  iree_status_t status = iree_hal_hip_driver_create(
+      driver_name, &driver_options, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_hip_driver_module_register(iree_hal_driver_registry_t* registry) {
+  static const iree_hal_driver_factory_t factory = {
+      .self = NULL,
+      .enumerate = iree_hal_hip_driver_factory_enumerate,
+      .try_create = iree_hal_hip_driver_factory_try_create,
+  };
+  return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/experimental/hip/registration/driver_module.h b/experimental/hip/registration/driver_module.h
new file mode 100644
index 000000000000..1c3a0078f445
--- /dev/null
+++ b/experimental/hip/registration/driver_module.h
@@ -0,0 +1,25 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HIP_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_EXPERIMENTAL_HIP_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Registers the HIP HAL driver to the given |registry|.
+IREE_API_EXPORT iree_status_t
+iree_hal_hip_driver_module_register(iree_hal_driver_registry_t* registry);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HIP_REGISTRATION_DRIVER_MODULE_H_
diff --git a/experimental/hip/status_util.c b/experimental/hip/status_util.c
new file mode 100644
index 000000000000..43d0622b6e72
--- /dev/null
+++ b/experimental/hip/status_util.c
@@ -0,0 +1,43 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/status_util.h"
+
+#include
+
+#include "experimental/hip/dynamic_symbols.h"
+#include "iree/base/status.h"
+
+// TODO: Map HIP error strings with their corresponding IREE error state
+// classification.
+
+// Converts HIP |error_name| to the corresponding IREE status code.
+static iree_status_code_t iree_hal_hip_error_name_to_status_code( + const char* error_name) { + return IREE_STATUS_UNKNOWN; +} + +iree_status_t iree_hal_hip_result_to_status( + const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, + const char* file, uint32_t line) { + if (IREE_LIKELY(result == hipSuccess)) { + return iree_ok_status(); + } + + const char *error_name = syms->hipGetErrorName(result); + if (result == hipErrorUnknown) { + error_name = "UNKNOWN"; + } + + const char *error_string = syms->hipGetErrorString(result); + if (result == hipErrorUnknown) { + error_string = "Unknown error."; + } + + return iree_make_status_with_location( + file, line, iree_hal_hip_error_name_to_status_code(error_name), + "HIP driver error '%s' (%d): %s", error_name, result, error_string); +} diff --git a/experimental/hip/status_util.h b/experimental/hip/status_util.h new file mode 100644 index 000000000000..6138210437f8 --- /dev/null +++ b/experimental/hip/status_util.h @@ -0,0 +1,72 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_EXPERIMENTAL_HIP_STATUS_UTIL_H_ +#define IREE_EXPERIMENTAL_HIP_STATUS_UTIL_H_ + +#include + +#include "experimental/hip/dynamic_symbols.h" +#include "iree/base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// HIP result macros +//===----------------------------------------------------------------------===// + +// Converts a hipError_t to an iree_status_t. +// +// Usage: +// iree_status_t status = IREE_HIP_RESULT_TO_STATUS(hip_symbols, +// hipDoThing(...)); +#define IREE_HIP_RESULT_TO_STATUS(syms, expr, ...) 
\ + iree_hal_hip_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__) + +// IREE_RETURN_IF_ERROR but implicitly converts the hipError_t return value to +// an iree_status_t. +// +// Usage: +// IREE_HIP_RETURN_IF_ERROR(hip_symbols, hipDoThing(...), "message"); +#define IREE_HIP_RETURN_IF_ERROR(syms, expr, ...) \ + IREE_RETURN_IF_ERROR(iree_hal_hip_result_to_status((syms), ((syms)->expr), \ + __FILE__, __LINE__), \ + __VA_ARGS__) + +// IREE_RETURN_IF_ERROR but ends the current zone and implicitly converts the +// hipError_t return value to an iree_status_t. +// +// Usage: +// IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(zone_id, hip_symbols, +// hipDoThing(...), "message"); +#define IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(zone_id, syms, expr, ...) \ + IREE_RETURN_AND_END_ZONE_IF_ERROR( \ + zone_id, \ + iree_hal_hip_result_to_status((syms), ((syms)->expr), __FILE__, \ + __LINE__), \ + __VA_ARGS__) + +// IREE_IGNORE_ERROR but implicitly converts the hipError_t return value to an +// iree_status_t. +// +// Usage: +// IREE_HIP_IGNORE_ERROR(hip_symbols, hipDoThing(...)); +#define IREE_HIP_IGNORE_ERROR(syms, expr) \ + IREE_IGNORE_ERROR(iree_hal_hip_result_to_status((syms), ((syms)->expr), \ + __FILE__, __LINE__)) + +// Converts a hipError_t to an iree_status_t object. +iree_status_t iree_hal_hip_result_to_status( + const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, + const char* file, uint32_t line); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_EXPERIMENTAL_HIP_STATUS_UTIL_H_