Commit de4008e
[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage (vllm-project#9352)
Signed-off-by: Joe Runde <[email protected]>
1 parent 48138a8 commit de4008e

File tree: 4 files changed, +122 -17 lines changed

  tests/entrypoints/llm/test_lazy_outlines.py
  tests/entrypoints/offline_mode/test_offline_mode.py
  tests/worker/test_profile.py
  vllm/worker/worker.py
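In outline, the commit swaps the old "free memory before minus free memory after" measurement for the peak reported by torch's own allocator, then adds back anything allocated outside of torch. A minimal standalone sketch of that accounting, distilled from the worker.py diff below (not the exact vLLM code; the dummy forward pass is elided):

import torch

def available_kv_cache_memory(gpu_memory_utilization: float) -> int:
    """Sketch of the new profiling flow: torch peak + non-torch residue."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    free_before, total = torch.cuda.mem_get_info()

    # ... a dummy forward pass would run here to exercise peak activations ...
    torch.cuda.synchronize()

    # Peak bytes the torch caching allocator handed out during the pass.
    peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

    # Whatever is still missing from free memory after emptying the torch
    # cache was allocated outside torch (NCCL buffers, for example).
    torch.cuda.empty_cache()
    non_torch = free_before - torch.cuda.mem_get_info()[0]
    if non_torch > 0:
        peak += non_torch

    return int(total * gpu_memory_utilization - peak)

The practical effect is that allocations from other processes or libraries no longer get mistaken for the model's own peak usage, while genuinely external allocations made during profiling are still charged against the memory budget.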

tests/entrypoints/llm/test_lazy_outlines.py (+3 -1)

@@ -26,10 +26,12 @@ def test_lazy_outlines(sample_regex):
     # make sure outlines is not imported
     assert 'outlines' not in sys.modules
 
+    # The second LLM needs to request a higher gpu_memory_utilization because
+    # the first LLM has already allocated a full 30% of the gpu memory.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
-              gpu_memory_utilization=0.3)
+              gpu_memory_utilization=0.6)
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     outputs = llm.generate(
         prompts=[
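For context, gpu_memory_utilization is a fraction of total GPU memory, and with the new accounting the first LLM's footprint shows up as a non-torch allocation when the second LLM profiles, so a 0.3 target leaves essentially nothing for the second instance. A rough illustration with a hypothetical 80 GiB device (numbers are made up, not taken from the test):

# Hypothetical 80 GiB GPU; the first LLM was created with
# gpu_memory_utilization=0.3, so roughly 30% of the device is already in use.
total_gib = 80.0
held_by_first_llm = 0.3 * total_gib   # ~24 GiB already allocated

# The second LLM's budget is a fraction of *total* memory, and the first
# LLM's usage counts against it as a non-torch allocation during profiling.
print(0.3 * total_gib - held_by_first_llm)   # 0.0 GiB left at a 0.3 target
print(0.6 * total_gib - held_by_first_llm)   # 24.0 GiB left at a 0.6 target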

tests/entrypoints/offline_mode/test_offline_mode.py (+1 -1)

@@ -44,7 +44,7 @@ def test_offline_mode(llm: LLM, monkeypatch):
         LLM(model=MODEL_NAME,
             max_num_batched_tokens=4096,
             tensor_parallel_size=1,
-            gpu_memory_utilization=0.10,
+            gpu_memory_utilization=0.20,
             enforce_eager=True)
     finally:
         # Reset the environment after the test

tests/worker/test_profile.py (new file, +69)

@@ -0,0 +1,69 @@
+import torch
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.worker import Worker
+
+
+def test_gpu_memory_profiling():
+    # Tests the gpu profiling that happens in order to determine the number of
+    # KV cache blocks that we can allocate on the GPU.
+    # This test mocks the maximum available gpu memory so that it can run on
+    # any gpu setup.
+
+    # Set up engine args to build a worker.
+    engine_args = EngineArgs(model="facebook/opt-125m",
+                             dtype="half",
+                             load_format="dummy")
+    engine_config = engine_args.create_engine_config()
+    engine_config.cache_config.num_gpu_blocks = 1000
+    engine_config.cache_config.num_cpu_blocks = 1000
+
+    # Create the worker.
+    distributed_init_method = get_distributed_init_method(
+        get_ip(), get_open_port())
+    worker = Worker(
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
+        cache_config=engine_config.cache_config,
+        load_config=engine_config.load_config,
+        local_rank=0,
+        rank=0,
+        distributed_init_method=distributed_init_method,
+        is_driver_worker=True,
+    )
+
+    # Load the model so we can profile it
+    worker.init_device()
+    worker.load_model()
+
+    # Set 10GiB as the total gpu ram to be device-agnostic
+    def mock_mem_info():
+        current_usage = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        mock_total_bytes = 10 * 1024**3
+        free = mock_total_bytes - current_usage
+
+        return (free, mock_total_bytes)
+
+    from unittest.mock import patch
+    with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
+        gpu_blocks, _ = worker.determine_num_available_blocks()
+
+    # Peak vram usage by torch should be 0.7077 GiB
+    # Non-torch allocations should be 0.0079 GiB
+    # 9.0 GiB should be the utilization target
+    # 8.2843 GiB should be available for the KV cache
+    block_size = CacheEngine.get_cache_block_size(
+        engine_config.cache_config, engine_config.model_config,
+        engine_config.parallel_config)
+
+    expected_blocks = (8.2843 * 1024**3) // block_size
+
+    # Check within a small tolerance for portability
+    # Hardware, kernel, or dependency changes could all affect memory
+    # utilization
+    assert abs(gpu_blocks - expected_blocks) < 5
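The expected block count follows from the mocked 10 GiB device and the default gpu_memory_utilization of 0.9 implied by the 9.0 GiB target above; a quick back-of-the-envelope check (figures taken from the comments in the test):

total_gib = 10.0                 # mocked total device memory
target_gib = 0.9 * total_gib     # 9.0 GiB utilization target
peak_torch_gib = 0.7077          # peak torch allocations during profiling
non_torch_gib = 0.0079           # allocations made outside torch

kv_cache_gib = target_gib - peak_torch_gib - non_torch_gib
print(kv_cache_gib)              # ~8.284 GiB, matching the 8.2843 figure
                                 # above up to rounding of the inputs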

vllm/worker/worker.py (+49 -15)

@@ -217,42 +217,76 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()
+        torch.cuda.synchronize()
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # gpu outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass
+        torch.cuda.empty_cache()
+        # After emptying the torch cache, any other increase in gpu ram should
+        # be from non-torch allocations.
+        non_torch_allocations = free_memory_pre_profile - \
+            torch.cuda.mem_get_info()[0]
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
         cache_block_size = self.get_cache_block_size_bytes()
         if cache_block_size == 0:
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+
+        logger.info(
+            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
+            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
+            (peak_memory - non_torch_allocations) / (1024**3),
+            non_torch_allocations / (1024**3),
+            available_kv_cache_memory / (1024**3),
+            self.cache_config.gpu_memory_utilization)
+
+        # Final cleanup
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()
         gc.collect()
-        torch.cuda.empty_cache()
+
         return num_gpu_blocks, num_cpu_blocks
 
+    def _assert_memory_footprint_increased_during_profiling(self):
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        assert self.init_gpu_memory - free_gpu_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.
