Skip to content

Commit

Permalink
Move GpuDriver::GetDevice to the proper Executor classes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684576866
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 12, 2024
1 parent 3c7ad93 commit 9671b37
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 31 deletions.
5 changes: 0 additions & 5 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,6 @@ absl::Status GpuDriver::Init() {
return *init_retval;
}

// Writes the CUDA device handle for `device_ordinal` into `device` by calling
// cuDeviceGet, and converts the driver result into an absl::Status.
absl::Status GpuDriver::GetDevice(int device_ordinal, CUdevice* device) {
  const auto driver_result = cuDeviceGet(device, device_ordinal);
  return cuda::ToStatus(driver_result, "Failed call to cuDeviceGet");
}

absl::Status GpuDriver::CreateGraph(CUgraph* graph) {
VLOG(2) << "Create new CUDA graph";
TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0),
Expand Down
13 changes: 10 additions & 3 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ absl::Status GetGridLimits(int* x, int* y, int* z, CUdevice device) {
*z = value;
return absl::OkStatus();
}

// Resolves `device_ordinal` to its CUdevice handle via cuDeviceGet. A failing
// driver call is propagated to the caller as a non-OK status.
absl::StatusOr<CUdevice> GetDevice(int device_ordinal) {
  CUdevice handle;
  TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGet(&handle, device_ordinal),
                                    "Failed call to cuDeviceGet"));
  return handle;
}
} // namespace

// Given const GPU memory, returns a libcuda device pointer datatype, suitable
Expand All @@ -377,7 +385,7 @@ CudaExecutor::~CudaExecutor() {

absl::Status CudaExecutor::Init() {
TF_RETURN_IF_ERROR(GpuDriver::Init());
TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal(), &device_));
TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal()));
TF_ASSIGN_OR_RETURN(Context * context,
CudaContext::Create(device_ordinal(), device_));
set_context(context);
Expand Down Expand Up @@ -930,8 +938,7 @@ absl::Status CudaExecutor::TrimGraphMemory() {

absl::StatusOr<std::unique_ptr<DeviceDescription>>
CudaExecutor::CreateDeviceDescription(int device_ordinal) {
GpuDeviceHandle device;
TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal, &device));
TF_ASSIGN_OR_RETURN(GpuDeviceHandle device, GetDevice(device_ordinal));

int cc_major;
int cc_minor;
Expand Down
7 changes: 0 additions & 7 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,6 @@ class GpuDriver {
// previously registered.
static bool HostUnregister(Context* context, void* location);

// Given a device ordinal, stores the matching device handle in the `device`
// out-parameter, which must not be null. The returned status reports whether
// the underlying driver lookup succeeded.
//
// N.B. these device handles do not have a corresponding destroy function in
// the CUDA/HIP driver API.
static absl::Status GetDevice(int device_ordinal, GpuDeviceHandle* device);

// Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control
Expand Down
10 changes: 0 additions & 10 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,6 @@ absl::Status GpuDriver::Init() {
return *init_retval;
}

// Writes the HIP device handle for `device_ordinal` into `device`. A failing
// hipDeviceGet call is reported as an internal error carrying the HIP error
// string.
absl::Status GpuDriver::GetDevice(int device_ordinal, hipDevice_t* device) {
  const hipError_t result = wrap::hipDeviceGet(device, device_ordinal);
  if (result != hipSuccess) {
    return absl::InternalError(
        absl::StrCat("failed call to hipDeviceGet: ", ToString(result)));
  }
  return absl::OkStatus();
}

absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) {
VLOG(2) << "Create new HIP graph";
TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(graph, /*flags=*/0),
Expand Down
19 changes: 13 additions & 6 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,17 @@ absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {
return absl::OkStatus();
}

// Resolves `device_ordinal` to its hipDevice_t handle via hipDeviceGet. On
// failure, returns an internal error that includes the HIP error string.
absl::StatusOr<hipDevice_t> GetDevice(int device_ordinal) {
  hipDevice_t handle;
  const hipError_t result = wrap::hipDeviceGet(&handle, device_ordinal);
  if (result != hipSuccess) {
    return absl::InternalError(
        absl::StrCat("failed call to hipDeviceGet: ", ToString(result)));
  }
  return handle;
}
} // namespace

RocmExecutor::~RocmExecutor() {
Expand Down Expand Up @@ -443,7 +454,7 @@ void RocmExecutor::UnloadKernel(const Kernel* kernel) {
absl::Status RocmExecutor::Init() {
TF_RETURN_IF_ERROR(GpuDriver::Init());

TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal(), &device_));
TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal()));

TF_ASSIGN_OR_RETURN(rocm_context_,
RocmContext::Create(device_ordinal(), device_));
Expand Down Expand Up @@ -775,11 +786,7 @@ absl::Status RocmExecutor::TrimGraphMemory() {

absl::StatusOr<std::unique_ptr<DeviceDescription>>
RocmExecutor::CreateDeviceDescription(int device_ordinal) {
GpuDeviceHandle device;
auto status = GpuDriver::GetDevice(device_ordinal, &device);
if (!status.ok()) {
return status;
}
TF_ASSIGN_OR_RETURN(GpuDeviceHandle device, GetDevice(device_ordinal));

TF_ASSIGN_OR_RETURN(std::string gcn_arch_name, GetGpuGCNArchName(device));

Expand Down

0 comments on commit 9671b37

Please sign in to comment.