From 01d2f4ddb318c560fca71d78167c5e6707642b17 Mon Sep 17 00:00:00 2001
From: Kyle Lucke
Date: Wed, 9 Oct 2024 11:02:02 -0700
Subject: [PATCH] Move GpuDriver::GetDeviceMemoryInfo into CudaExecutor.

PiperOrigin-RevId: 684100021
---
 xla/stream_executor/cuda/cuda_dnn.cc      |  2 +-
 xla/stream_executor/cuda/cuda_driver.cc   | 16 ----------------
 xla/stream_executor/cuda/cuda_executor.cc | 16 ++++++++++++++--
 xla/stream_executor/cuda/cuda_executor.h  |  2 +-
 xla/stream_executor/gpu/gpu_driver.h      |  6 ------
 5 files changed, 16 insertions(+), 26 deletions(-)

diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 67d51a13ce5b29..6f201a76585fd3 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -534,7 +534,7 @@ absl::Status CudnnSupport::Init() {
   LOG(ERROR) << "Could not create cudnn handle: "
              << CudnnStatusToString(status);
   int64_t free, total;
-  GpuDriver::GetDeviceMemoryInfo(parent_->gpu_context(), &free, &total);
+  parent_->DeviceMemoryUsage(&free, &total);
   LOG(ERROR) << "Memory usage: " << free << " bytes free, " << total
              << " bytes total.";
 
diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc
index b7234973dae9c9..0eedec58756e87 100644
--- a/xla/stream_executor/cuda/cuda_driver.cc
+++ b/xla/stream_executor/cuda/cuda_driver.cc
@@ -1255,22 +1255,6 @@ bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) {
   return true;
 }
 
-bool GpuDriver::GetDeviceMemoryInfo(Context* context, int64_t* free_out,
-                                    int64_t* total_out) {
-  ScopedActivateContext activation(context);
-  size_t free = 0;
-  size_t total = 0;
-  auto status = cuda::ToStatus(cuMemGetInfo(&free, &total));
-  if (!status.ok()) {
-    LOG(ERROR) << "failed to query device memory info: " << status;
-    return false;
-  }
-
-  *free_out = free;
-  *total_out = total;
-  return true;
-}
-
 bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) {
   size_t value{};
   auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device));
diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc
index 44a37358374f2b..b24035799ed851 100644
--- a/xla/stream_executor/cuda/cuda_executor.cc
+++ b/xla/stream_executor/cuda/cuda_executor.cc
@@ -753,8 +753,20 @@ absl::Status CudaExecutor::EnablePeerAccessTo(StreamExecutor* other) {
   return GpuDriver::EnablePeerAccess(gpu_context(), cuda_other->gpu_context());
 }
 
-bool CudaExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
-  return GpuDriver::GetDeviceMemoryInfo(gpu_context(), free, total);
+bool CudaExecutor::DeviceMemoryUsage(int64_t* free_out,
+                                     int64_t* total_out) const {
+  ScopedActivateContext activation(gpu_context());
+  size_t free = 0;
+  size_t total = 0;
+  auto status = cuda::ToStatus(cuMemGetInfo(&free, &total));
+  if (!status.ok()) {
+    LOG(ERROR) << "failed to query device memory info: " << status;
+    return false;
+  }
+
+  *free_out = free;
+  *total_out = total;
+  return true;
 }
 
 absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetSymbol(
diff --git a/xla/stream_executor/cuda/cuda_executor.h b/xla/stream_executor/cuda/cuda_executor.h
index 97276c650abe61..46905a0ccb4cef 100644
--- a/xla/stream_executor/cuda/cuda_executor.h
+++ b/xla/stream_executor/cuda/cuda_executor.h
@@ -89,7 +89,7 @@ class CudaExecutor : public GpuExecutor {
   absl::Status BlockHostUntilDone(Stream* stream) override;
   absl::Status EnablePeerAccessTo(StreamExecutor* other) override;
   bool CanEnablePeerAccessTo(StreamExecutor* other) override;
-  bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override;
+  bool DeviceMemoryUsage(int64_t* free_out, int64_t* total_out) const override;
   absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
       const MultiKernelLoaderSpec& spec) override;
   void UnloadKernel(const Kernel* kernel) override;
diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h
index 5546b99341ec5c..e8d1a0958f1fce 100644
--- a/xla/stream_executor/gpu/gpu_driver.h
+++ b/xla/stream_executor/gpu/gpu_driver.h
@@ -597,12 +597,6 @@ class GpuDriver {
   // context, in bytes, via cuDeviceTotalMem.
   static bool GetDeviceTotalMemory(GpuDeviceHandle device, uint64_t* result);
 
-  // Returns the free amount of memory and total amount of memory, as reported
-  // by cuMemGetInfo.
-  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0
-  static bool GetDeviceMemoryInfo(Context* context, int64_t* free,
-                                  int64_t* total);
-
   // Returns a PCI bus id string for the device.
   // [domain]:[bus]:[device].[function]
   // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc
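
Note (not part of the patch): a minimal caller-side sketch of how device memory is queried after this change, assuming the usual XLA include paths; the helper LogDeviceMemory and the parameter name executor are illustrative, not code from this repository. Callers go through the StreamExecutor::DeviceMemoryUsage interface, whose CUDA implementation (cuMemGetInfo) now lives in CudaExecutor rather than GpuDriver.

#include <cstdint>

#include "absl/log/log.h"
#include "xla/stream_executor/stream_executor.h"

// Queries free/total device memory through the StreamExecutor interface.
// With this patch applied, the CUDA backend answers this directly in
// CudaExecutor::DeviceMemoryUsage instead of delegating to
// GpuDriver::GetDeviceMemoryInfo.
void LogDeviceMemory(stream_executor::StreamExecutor* executor) {
  int64_t free = 0;
  int64_t total = 0;
  if (executor->DeviceMemoryUsage(&free, &total)) {
    LOG(INFO) << "Device memory: " << free << " bytes free, " << total
              << " bytes total.";
  } else {
    LOG(WARNING) << "Device memory query failed.";
  }
}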