Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "adding new ROCM-6.2 features" #16

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ limitations under the License.

// needed for rocblas_gemm_ex_get_solutions* functionality
#define ROCBLAS_BETA_FEATURES_API

#include "rocm/rocm_config.h"
#include "rocm/include/rocblas/rocblas.h"
#include "xla/stream_executor/gpu/gpu_activation.h"
#include "xla/stream_executor/platform/dso_loader.h"
Expand Down Expand Up @@ -276,21 +274,10 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
__macro(rocblas_get_stream) \
__macro(rocblas_set_stream) \
__macro(rocblas_set_atomics_mode)
// clang-format on

FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

#if TF_ROCM_VERSION >= 60200

// clang-format off
#define FOREACH_ROCBLAS_API_62(__macro) \
__macro(rocblas_get_version_string_size) \
__macro(rocblas_get_version_string)
// clang-format on

FOREACH_ROCBLAS_API_62(ROCBLAS_API_WRAPPER)

#endif // TF_ROCM_VERSION >= 60200
FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)

} // namespace wrap
} // namespace stream_executor
Expand Down
8 changes: 4 additions & 4 deletions xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1212,17 +1212,17 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched)
}
}

absl::Status ROCMBlas::GetVersion(std::string *version) {
#if TF_ROCM_VERSION >= 60200
absl::Status ROCMBlas::GetVersion(string *version) {
#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1
absl::MutexLock lock{&mu_};
size_t len = 0;
if (auto res = wrap::rocblas_get_version_string_size(&len);
if (auto res = rocblas_get_version_string_size(&len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
}
std::vector<char> buf(len + 1);
if (auto res = wrap::rocblas_get_version_string(buf.data(), len);
if (auto res = rocblas_get_version_string(buf.data(), len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
Expand Down
37 changes: 4 additions & 33 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"

#include "xla/stream_executor/gpu/gpu_diagnostics.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/platform/port.h"
Expand Down Expand Up @@ -111,31 +110,6 @@ string ToString(hipError_t result) {
OSTREAM_ROCM_ERROR(ContextAlreadyInUse)
OSTREAM_ROCM_ERROR(PeerAccessUnsupported)
OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM.
#if TF_ROCM_VERSION >= 60200
OSTREAM_ROCM_ERROR(LaunchTimeOut)
OSTREAM_ROCM_ERROR(PeerAccessAlreadyEnabled)
OSTREAM_ROCM_ERROR(PeerAccessNotEnabled)
OSTREAM_ROCM_ERROR(SetOnActiveProcess)
OSTREAM_ROCM_ERROR(ContextIsDestroyed)
OSTREAM_ROCM_ERROR(Assert)
OSTREAM_ROCM_ERROR(HostMemoryAlreadyRegistered)
OSTREAM_ROCM_ERROR(HostMemoryNotRegistered)
OSTREAM_ROCM_ERROR(LaunchFailure)
OSTREAM_ROCM_ERROR(CooperativeLaunchTooLarge)
OSTREAM_ROCM_ERROR(NotSupported)
OSTREAM_ROCM_ERROR(StreamCaptureUnsupported)
OSTREAM_ROCM_ERROR(StreamCaptureInvalidated)
OSTREAM_ROCM_ERROR(StreamCaptureMerge)
OSTREAM_ROCM_ERROR(StreamCaptureUnmatched)
OSTREAM_ROCM_ERROR(StreamCaptureUnjoined)
OSTREAM_ROCM_ERROR(StreamCaptureIsolation)
OSTREAM_ROCM_ERROR(StreamCaptureImplicit)
OSTREAM_ROCM_ERROR(CapturedEvent)
OSTREAM_ROCM_ERROR(StreamCaptureWrongThread)
OSTREAM_ROCM_ERROR(GraphExecUpdateFailure)
OSTREAM_ROCM_ERROR(RuntimeMemory)
OSTREAM_ROCM_ERROR(RuntimeOther)
#endif // TF_ROCM_VERSION >= 60200
default:
return absl::StrCat("hipError_t(", static_cast<int>(result), ")");
}
Expand Down Expand Up @@ -1120,22 +1094,19 @@ struct BitPatternToValue {
VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes
<<" func: " << (const void*)function;
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;

auto res = hipSuccess;
#if TF_ROCM_VERSION < 60200
// for in-process kernel this function returns mangled kernel function name,
// and null otherwise
auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream);

auto res = hipSuccess;
if (name != nullptr) {
res = wrap::hipLaunchKernel((const void*)function,
dim3(grid_dim_x, grid_dim_y, grid_dim_z),
dim3(block_dim_x, block_dim_y, block_dim_z),
kernel_params, shared_mem_bytes, stream);
} else
#endif // TF_ROCM_VERSION < 60200
{
} else {
res = wrap::hipModuleLaunchKernel(
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
Expand Down
14 changes: 0 additions & 14 deletions xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.

#define __HIP_DISABLE_CPP_FUNCTIONS__

#include "rocm/rocm_config.h"
#include "rocm/include/hip/hip_runtime.h"
#include "xla/stream_executor/platform/dso_loader.h"
#include "xla/stream_executor/platform/port.h"
Expand Down Expand Up @@ -175,19 +174,6 @@ namespace wrap {

HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)

#if TF_ROCM_VERSION >= 60200

// clang-format off
#define HIP_ROUTINE_EACH_62(__macro) \
__macro(hipGetFuncBySymbol) \
__macro(hipStreamBeginCaptureToGraph)
// clang-format on

HIP_ROUTINE_EACH_62(STREAM_EXECUTOR_HIP_WRAP)

#undef HIP_ROUTINE_EACH_62
#endif // TF_ROCM_VERSION >= 60200

#undef HIP_ROUTINE_EACH
#undef STREAM_EXECUTOR_HIP_WRAP
#undef TO_STR
Expand Down
9 changes: 0 additions & 9 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
Expand Down Expand Up @@ -287,16 +286,8 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
VLOG(1) << "Resolve ROCM kernel " << *kernel_name
<< " from symbol pointer: " << symbol;

#if TF_ROCM_VERSION >= 60200
TF_ASSIGN_OR_RETURN(
GpuFunctionHandle function,
GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol()));
*rocm_kernel->gpu_function_ptr() = function;
#else
*rocm_kernel->gpu_function_ptr() =
static_cast<hipFunction_t>(spec.in_process_symbol().symbol());
#endif // TF_ROCM_VERSION >= 60200

} else {
return absl::InternalError("No method of loading ROCM kernel provided");
}
Expand Down
7 changes: 0 additions & 7 deletions xla/stream_executor/rocm/rocm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@ namespace gpu {

absl::StatusOr<GpuFunctionHandle> GpuRuntime::GetFuncBySymbol(void* symbol) {
VLOG(2) << "Get ROCM function from a symbol: " << symbol;
#if TF_ROCM_VERSION >= 60200
hipFunction_t func;
RETURN_IF_ROCM_ERROR(wrap::hipGetFuncBySymbol(&func, symbol),
"Failed call to hipGetFuncBySymbol");
return func;
#else
return absl::UnimplementedError("GetFuncBySymbol is not implemented");
#endif // TF_ROCM_VERSION >= 60200
}

absl::StatusOr<int32_t> GpuRuntime::GetRuntimeVersion() {
Expand Down
Loading