Skip to content

Commit

Permalink
Merge pull request #16 from ROCm/qa-0428-rocm62-revert
Browse files: browse the repository at this point in the history
Revert "adding new ROCM-6.2 features"
  • Loading branch information
i-chaochen authored Jun 19, 2024
2 parents c9d02b5 + 8245435 commit 0955566
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 81 deletions.
15 changes: 1 addition & 14 deletions xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ limitations under the License.

// needed for rocblas_gemm_ex_get_solutions* functionality
#define ROCBLAS_BETA_FEATURES_API

#include "rocm/rocm_config.h"
#include "rocm/include/rocblas/rocblas.h"
#include "xla/stream_executor/gpu/gpu_activation.h"
#include "xla/stream_executor/platform/dso_loader.h"
Expand Down Expand Up @@ -276,21 +274,10 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
__macro(rocblas_get_stream) \
__macro(rocblas_set_stream) \
__macro(rocblas_set_atomics_mode)
// clang-format on

FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

#if TF_ROCM_VERSION >= 60200

// clang-format off
#define FOREACH_ROCBLAS_API_62(__macro) \
__macro(rocblas_get_version_string_size) \
__macro(rocblas_get_version_string)
// clang-format on

FOREACH_ROCBLAS_API_62(ROCBLAS_API_WRAPPER)

#endif // TF_ROCM_VERSION >= 60200
FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

} // namespace wrap
} // namespace stream_executor
Expand Down
8 changes: 4 additions & 4 deletions xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1212,17 +1212,17 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched)
}
}

absl::Status ROCMBlas::GetVersion(std::string *version) {
#if TF_ROCM_VERSION >= 60200
absl::Status ROCMBlas::GetVersion(string *version) {
#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1
absl::MutexLock lock{&mu_};
size_t len = 0;
if (auto res = wrap::rocblas_get_version_string_size(&len);
if (auto res = rocblas_get_version_string_size(&len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
}
std::vector<char> buf(len + 1);
if (auto res = wrap::rocblas_get_version_string(buf.data(), len);
if (auto res = rocblas_get_version_string(buf.data(), len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
Expand Down
37 changes: 4 additions & 33 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"

#include "xla/stream_executor/gpu/gpu_diagnostics.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/platform/port.h"
Expand Down Expand Up @@ -111,31 +110,6 @@ string ToString(hipError_t result) {
OSTREAM_ROCM_ERROR(ContextAlreadyInUse)
OSTREAM_ROCM_ERROR(PeerAccessUnsupported)
OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM.
#if TF_ROCM_VERSION >= 60200
OSTREAM_ROCM_ERROR(LaunchTimeOut)
OSTREAM_ROCM_ERROR(PeerAccessAlreadyEnabled)
OSTREAM_ROCM_ERROR(PeerAccessNotEnabled)
OSTREAM_ROCM_ERROR(SetOnActiveProcess)
OSTREAM_ROCM_ERROR(ContextIsDestroyed)
OSTREAM_ROCM_ERROR(Assert)
OSTREAM_ROCM_ERROR(HostMemoryAlreadyRegistered)
OSTREAM_ROCM_ERROR(HostMemoryNotRegistered)
OSTREAM_ROCM_ERROR(LaunchFailure)
OSTREAM_ROCM_ERROR(CooperativeLaunchTooLarge)
OSTREAM_ROCM_ERROR(NotSupported)
OSTREAM_ROCM_ERROR(StreamCaptureUnsupported)
OSTREAM_ROCM_ERROR(StreamCaptureInvalidated)
OSTREAM_ROCM_ERROR(StreamCaptureMerge)
OSTREAM_ROCM_ERROR(StreamCaptureUnmatched)
OSTREAM_ROCM_ERROR(StreamCaptureUnjoined)
OSTREAM_ROCM_ERROR(StreamCaptureIsolation)
OSTREAM_ROCM_ERROR(StreamCaptureImplicit)
OSTREAM_ROCM_ERROR(CapturedEvent)
OSTREAM_ROCM_ERROR(StreamCaptureWrongThread)
OSTREAM_ROCM_ERROR(GraphExecUpdateFailure)
OSTREAM_ROCM_ERROR(RuntimeMemory)
OSTREAM_ROCM_ERROR(RuntimeOther)
#endif // TF_ROCM_VERSION >= 60200
default:
return absl::StrCat("hipError_t(", static_cast<int>(result), ")");
}
Expand Down Expand Up @@ -1120,22 +1094,19 @@ struct BitPatternToValue {
VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes
<<" func: " << (const void*)function;
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;

auto res = hipSuccess;
#if TF_ROCM_VERSION < 60200
// for in-process kernel this function returns mangled kernel function name,
// and null otherwise
auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream);

auto res = hipSuccess;
if (name != nullptr) {
res = wrap::hipLaunchKernel((const void*)function,
dim3(grid_dim_x, grid_dim_y, grid_dim_z),
dim3(block_dim_x, block_dim_y, block_dim_z),
kernel_params, shared_mem_bytes, stream);
} else
#endif // TF_ROCM_VERSION < 60200
{
} else {
res = wrap::hipModuleLaunchKernel(
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
Expand Down
14 changes: 0 additions & 14 deletions xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.

#define __HIP_DISABLE_CPP_FUNCTIONS__

#include "rocm/rocm_config.h"
#include "rocm/include/hip/hip_runtime.h"
#include "xla/stream_executor/platform/dso_loader.h"
#include "xla/stream_executor/platform/port.h"
Expand Down Expand Up @@ -175,19 +174,6 @@ namespace wrap {

HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)

#if TF_ROCM_VERSION >= 60200

// clang-format off
#define HIP_ROUTINE_EACH_62(__macro) \
__macro(hipGetFuncBySymbol) \
__macro(hipStreamBeginCaptureToGraph)
// clang-format on

HIP_ROUTINE_EACH_62(STREAM_EXECUTOR_HIP_WRAP)

#undef HIP_ROUTINE_EACH_62
#endif // TF_ROCM_VERSION >= 60200

#undef HIP_ROUTINE_EACH
#undef STREAM_EXECUTOR_HIP_WRAP
#undef TO_STR
Expand Down
9 changes: 0 additions & 9 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
Expand Down Expand Up @@ -287,16 +286,8 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
VLOG(1) << "Resolve ROCM kernel " << *kernel_name
<< " from symbol pointer: " << symbol;

#if TF_ROCM_VERSION >= 60200
TF_ASSIGN_OR_RETURN(
GpuFunctionHandle function,
GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol()));
*rocm_kernel->gpu_function_ptr() = function;
#else
*rocm_kernel->gpu_function_ptr() =
static_cast<hipFunction_t>(spec.in_process_symbol().symbol());
#endif // TF_ROCM_VERSION >= 60200

} else {
return absl::InternalError("No method of loading ROCM kernel provided");
}
Expand Down
7 changes: 0 additions & 7 deletions xla/stream_executor/rocm/rocm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@ namespace gpu {

absl::StatusOr<GpuFunctionHandle> GpuRuntime::GetFuncBySymbol(void* symbol) {
VLOG(2) << "Get ROCM function from a symbol: " << symbol;
#if TF_ROCM_VERSION >= 60200
hipFunction_t func;
RETURN_IF_ROCM_ERROR(wrap::hipGetFuncBySymbol(&func, symbol),
"Failed call to hipGetFuncBySymbol");
return func;
#else
return absl::UnimplementedError("GetFuncBySymbol is not implemented");
#endif // TF_ROCM_VERSION >= 60200
}

absl::StatusOr<int32_t> GpuRuntime::GetRuntimeVersion() {
Expand Down

0 comments on commit 0955566

Please sign in to comment.