Skip to content

Commit

Permalink
[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory u…
Browse files Browse the repository at this point in the history
…sage (vllm-project#9352)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
  • Loading branch information
joerunde authored Oct 18, 2024
1 parent 48138a8 commit de4008e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 17 deletions.
4 changes: 3 additions & 1 deletion tests/entrypoints/llm/test_lazy_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
assert 'outlines' not in sys.modules

# The second LLM needs to request a higher gpu_memory_utilization because
# the first LLM has already allocated a full 30% of the gpu memory.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
Expand Down
2 changes: 1 addition & 1 deletion tests/entrypoints/offline_mode/test_offline_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_offline_mode(llm: LLM, monkeypatch):
LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
gpu_memory_utilization=0.20,
enforce_eager=True)
finally:
# Reset the environment after the test
Expand Down
69 changes: 69 additions & 0 deletions tests/worker/test_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker


def test_gpu_memory_profiling():
# Tests the gpu profiling that happens in order to determine the number of
# KV cache blocks that we can allocate on the GPU.
# This test mocks the maximum available gpu memory so that it can run on
# any gpu setup.

# Set up engine args to build a worker.
engine_args = EngineArgs(model="facebook/opt-125m",
dtype="half",
load_format="dummy")
engine_config = engine_args.create_engine_config()
engine_config.cache_config.num_gpu_blocks = 1000
engine_config.cache_config.num_cpu_blocks = 1000

# Create the worker.
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
)

# Load the model so we can profile it
worker.init_device()
worker.load_model()

# Set 10GiB as the total gpu ram to be device-agnostic
def mock_mem_info():
current_usage = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
mock_total_bytes = 10 * 1024**3
free = mock_total_bytes - current_usage

return (free, mock_total_bytes)

from unittest.mock import patch
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
gpu_blocks, _ = worker.determine_num_available_blocks()

# Peak vram usage by torch should be 0.7077 GiB
# Non-torch allocations should be 0.0079 GiB
# 9.0 GiB should be the utilization target
# 8.2843 GiB should be available for the KV cache
block_size = CacheEngine.get_cache_block_size(
engine_config.cache_config, engine_config.model_config,
engine_config.parallel_config)

expected_blocks = (8.2843 * 1024**3) // block_size

# Check within a small tolerance for portability
# Hardware, kernel, or dependency changes could all affect memory
# utilization
assert abs(gpu_blocks - expected_blocks) < 5
64 changes: 49 additions & 15 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,42 +217,76 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()

# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
torch.cuda.synchronize()

self._assert_memory_footprint_increased_during_profiling()

# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch.cuda.empty_cache()
# After emptying the torch cache, any other increase in gpu ram should
# be from non-torch allocations.
non_torch_allocations = free_memory_pre_profile - \
torch.cuda.mem_get_info()[0]
if non_torch_allocations > 0:
peak_memory += non_torch_allocations

available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)

# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)

logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
" gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3),
non_torch_allocations / (1024**3),
available_kv_cache_memory / (1024**3),
self.cache_config.gpu_memory_utilization)

# Final cleanup
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()

return num_gpu_blocks, num_cpu_blocks

def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, _ = torch.cuda.mem_get_info()
assert self.init_gpu_memory - free_gpu_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
Expand Down

0 comments on commit de4008e

Please sign in to comment.