Skip to content

Commit

Permalink
Add wrap interface for calling into rocblas shared library
Browse files Browse the repository at this point in the history
  • Loading branch information
hsharsha committed Jun 26, 2024
1 parent edc830c commit e2b1753
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
__macro(rocblas_destroy_handle) \
__macro(rocblas_get_stream) \
__macro(rocblas_set_stream) \
__macro(rocblas_set_atomics_mode)
__macro(rocblas_set_atomics_mode) \
__macro(rocblas_get_version_string) \
__macro(rocblas_get_version_string_size)

// clang-format on

Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1351,16 +1351,16 @@ absl::Status ROCMBlas::DoBlasGemmStridedBatched(
}

absl::Status ROCMBlas::GetVersion(string *version) {
#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1
#if TF_ROCM_VERSION > 60100 // Not available in ROCM-6.1
absl::MutexLock lock{&mu_};
size_t len = 0;
if (auto res = rocblas_get_version_string_size(&len);
if (auto res = wrap::rocblas_get_version_string_size(&len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
}
std::vector<char> buf(len + 1);
if (auto res = rocblas_get_version_string(buf.data(), len);
if (auto res = wrap::rocblas_get_version_string(buf.data(), len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
Expand Down

0 comments on commit e2b1753

Please sign in to comment.