From e2b17535e548efa86e707179096e55bb98b84da6 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Tue, 25 Jun 2024 14:04:49 -0700 Subject: [PATCH] Add wrap interface for calling into rocblas shared library --- third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h | 4 +++- third_party/xla/xla/stream_executor/rocm/rocm_blas.cc | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h index a637d68428d5e3..3992cf9ab48a1c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h @@ -270,7 +270,9 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_destroy_handle) \ __macro(rocblas_get_stream) \ __macro(rocblas_set_stream) \ - __macro(rocblas_set_atomics_mode) + __macro(rocblas_set_atomics_mode) \ + __macro(rocblas_get_version_string) \ + __macro(rocblas_get_version_string_size) // clang-format on diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc index 1507af51e4e459..f3b21504599f88 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc @@ -1351,16 +1351,16 @@ absl::Status ROCMBlas::DoBlasGemmStridedBatched( } absl::Status ROCMBlas::GetVersion(string *version) { -#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1 +#if TF_ROCM_VERSION > 60100 // Not available in ROCM-6.1 absl::MutexLock lock{&mu_}; size_t len = 0; - if (auto res = rocblas_get_version_string_size(&len); + if (auto res = wrap::rocblas_get_version_string_size(&len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res))); } std::vector buf(len + 1); - if (auto res = rocblas_get_version_string(buf.data(), len); + if (auto res = wrap::rocblas_get_version_string(buf.data(), len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res)));