@@ -217,42 +217,76 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()
+        torch.cuda.synchronize()
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # gpu outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass
+        torch.cuda.empty_cache()
+        # After emptying the torch cache, any other increase in gpu ram should
+        # be from non-torch allocations.
+        non_torch_allocations = free_memory_pre_profile - \
+            torch.cuda.mem_get_info()[0]
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
         cache_block_size = self.get_cache_block_size_bytes()
         if cache_block_size == 0:
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+
+        logger.info(
+            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
+            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
+            (peak_memory - non_torch_allocations) / (1024**3),
+            non_torch_allocations / (1024**3),
+            available_kv_cache_memory / (1024**3),
+            self.cache_config.gpu_memory_utilization)
+
+        # Final cleanup
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()
         gc.collect()
-        torch.cuda.empty_cache()
+
         return num_gpu_blocks, num_cpu_blocks
 
+    def _assert_memory_footprint_increased_during_profiling(self):
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        assert self.init_gpu_memory - free_gpu_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.