Skip to content

Commit

Permalink
Move GpuDriver::GetDeviceMemoryInfo into CudaExecutor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684100021
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 10, 2024
1 parent cae9085 commit 01d2f4d
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 26 deletions.
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_dnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ absl::Status CudnnSupport::Init() {
LOG(ERROR) << "Could not create cudnn handle: "
<< CudnnStatusToString(status);
int64_t free, total;
GpuDriver::GetDeviceMemoryInfo(parent_->gpu_context(), &free, &total);
parent_->DeviceMemoryUsage(&free, &total);
LOG(ERROR) << "Memory usage: " << free << " bytes free, " << total
<< " bytes total.";

Expand Down
16 changes: 0 additions & 16 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1255,22 +1255,6 @@ bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) {
return true;
}

// Queries the CUDA driver (cuMemGetInfo) for the free and total amount of
// device memory, in bytes, visible under `context`.
// Returns true on success. On failure, logs the driver error and returns
// false, leaving *free_out and *total_out unmodified.
bool GpuDriver::GetDeviceMemoryInfo(Context* context, int64_t* free_out,
                                    int64_t* total_out) {
  // The driver call is context-sensitive; make `context` current first.
  ScopedActivateContext activation(context);
  size_t free_bytes = 0;
  size_t total_bytes = 0;
  if (auto status = cuda::ToStatus(cuMemGetInfo(&free_bytes, &total_bytes));
      !status.ok()) {
    LOG(ERROR) << "failed to query device memory info: " << status;
    return false;
  }
  *free_out = static_cast<int64_t>(free_bytes);
  *total_out = static_cast<int64_t>(total_bytes);
  return true;
}

bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) {
size_t value{};
auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device));
Expand Down
16 changes: 14 additions & 2 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -753,8 +753,20 @@ absl::Status CudaExecutor::EnablePeerAccessTo(StreamExecutor* other) {
return GpuDriver::EnablePeerAccess(gpu_context(), cuda_other->gpu_context());
}

bool CudaExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
return GpuDriver::GetDeviceMemoryInfo(gpu_context(), free, total);
// Reports the free and total device memory, in bytes, as seen by this
// executor's context (via cuMemGetInfo).
// Returns true on success. On driver failure, logs the error and returns
// false; the output parameters are not written in that case.
bool CudaExecutor::DeviceMemoryUsage(int64_t* free_out,
                                     int64_t* total_out) const {
  // cuMemGetInfo operates on the current context, so activate ours first.
  ScopedActivateContext activation(gpu_context());
  size_t free_bytes = 0;
  size_t total_bytes = 0;
  if (auto status = cuda::ToStatus(cuMemGetInfo(&free_bytes, &total_bytes));
      !status.ok()) {
    LOG(ERROR) << "failed to query device memory info: " << status;
    return false;
  }
  *free_out = static_cast<int64_t>(free_bytes);
  *total_out = static_cast<int64_t>(total_bytes);
  return true;
}

absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetSymbol(
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CudaExecutor : public GpuExecutor {
absl::Status BlockHostUntilDone(Stream* stream) override;
absl::Status EnablePeerAccessTo(StreamExecutor* other) override;
bool CanEnablePeerAccessTo(StreamExecutor* other) override;
bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override;
bool DeviceMemoryUsage(int64_t* free_out, int64_t* total_out) const override;
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;
void UnloadKernel(const Kernel* kernel) override;
Expand Down
6 changes: 0 additions & 6 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,6 @@ class GpuDriver {
// context, in bytes, via cuDeviceTotalMem.
static bool GetDeviceTotalMemory(GpuDeviceHandle device, uint64_t* result);

// Returns the free amount of memory and total amount of memory, as reported
// by cuMemGetInfo.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0
static bool GetDeviceMemoryInfo(Context* context, int64_t* free,
int64_t* total);

// Returns a PCI bus id string for the device.
// [domain]:[bus]:[device].[function]
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc
Expand Down

0 comments on commit 01d2f4d

Please sign in to comment.