From e8e08746b88291bc2a3842cf51cdea2d94df46f5 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 10 Oct 2024 16:09:29 -0700 Subject: [PATCH] Move GpuDriver::GetDeviceMemoryInfo into CudaExecutor. PiperOrigin-RevId: 684611306 --- xla/stream_executor/cuda/cuda_driver.cc | 16 ---------------- xla/stream_executor/cuda/cuda_executor.cc | 16 ++++++++++++++-- xla/stream_executor/cuda/cuda_executor.h | 2 +- xla/stream_executor/gpu/gpu_driver.h | 6 ------ 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index b7234973dae9c..0eedec58756e8 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -1255,22 +1255,6 @@ bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) { return true; } -bool GpuDriver::GetDeviceMemoryInfo(Context* context, int64_t* free_out, - int64_t* total_out) { - ScopedActivateContext activation(context); - size_t free = 0; - size_t total = 0; - auto status = cuda::ToStatus(cuMemGetInfo(&free, &total)); - if (!status.ok()) { - LOG(ERROR) << "failed to query device memory info: " << status; - return false; - } - - *free_out = free; - *total_out = total; - return true; -} - bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) { size_t value{}; auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device)); diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 44a37358374f2..b24035799ed85 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -753,8 +753,20 @@ absl::Status CudaExecutor::EnablePeerAccessTo(StreamExecutor* other) { return GpuDriver::EnablePeerAccess(gpu_context(), cuda_other->gpu_context()); } -bool CudaExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - return GpuDriver::GetDeviceMemoryInfo(gpu_context(), free, total); +bool 
CudaExecutor::DeviceMemoryUsage(int64_t* free_out, + int64_t* total_out) const { + ScopedActivateContext activation(gpu_context()); + size_t free = 0; + size_t total = 0; + auto status = cuda::ToStatus(cuMemGetInfo(&free, &total)); + if (!status.ok()) { + LOG(ERROR) << "failed to query device memory info: " << status; + return false; + } + + *free_out = free; + *total_out = total; + return true; } absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetSymbol( diff --git a/xla/stream_executor/cuda/cuda_executor.h b/xla/stream_executor/cuda/cuda_executor.h index 97276c650abe6..46905a0ccb4ce 100644 --- a/xla/stream_executor/cuda/cuda_executor.h +++ b/xla/stream_executor/cuda/cuda_executor.h @@ -89,7 +89,7 @@ class CudaExecutor : public GpuExecutor { absl::Status BlockHostUntilDone(Stream* stream) override; absl::Status EnablePeerAccessTo(StreamExecutor* other) override; bool CanEnablePeerAccessTo(StreamExecutor* other) override; - bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; + bool DeviceMemoryUsage(int64_t* free_out, int64_t* total_out) const override; absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel( const MultiKernelLoaderSpec& spec) override; void UnloadKernel(const Kernel* kernel) override; diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h index 5546b99341ec5..e8d1a0958f1fc 100644 --- a/xla/stream_executor/gpu/gpu_driver.h +++ b/xla/stream_executor/gpu/gpu_driver.h @@ -597,12 +597,6 @@ class GpuDriver { // context, in bytes, via cuDeviceTotalMem. static bool GetDeviceTotalMemory(GpuDeviceHandle device, uint64_t* result); - // Returns the free amount of memory and total amount of memory, as reported - // by cuMemGetInfo. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0 - static bool GetDeviceMemoryInfo(Context* context, int64_t* free, - int64_t* total); - // Returns a PCI bus id string for the device. 
// [domain]:[bus]:[device].[function] // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc