22import inspect
33import os
44import traceback
5+ import warnings
56from typing import Callable , Optional , Tuple
67
78import torch
89
910from tensorrt_llm ._torch .models .checkpoints .base_checkpoint_loader import (
1011 AutoCheckpointMapper , BaseCheckpointLoader )
1112from tensorrt_llm ._utils import str_dtype_to_torch
12- from tensorrt_llm .llmapi .llm_args import TorchLlmArgs
13+ from tensorrt_llm .llmapi .llm_args import ExecutorMemoryType , TorchLlmArgs
1314from tensorrt_llm .llmapi .llm_utils import apply_model_defaults_to_llm_args
1415from tensorrt_llm .logger import logger
1516from tensorrt_llm .lora_helper import LoraConfig
2526 timing )
2627from ..modules .fused_moe .moe_load_balancer import (
2728 MoeLoadBalancer , maybe_create_moe_load_balancer )
29+ from ..virtual_memory import RestoreMode
30+ from ..virtual_memory import scope as virtual_memory_scope
2831
2932_KV_CACHE_MAP = {
3033 "fp8" : QuantAlgo .FP8 .value ,
@@ -182,6 +185,15 @@ def _construct_checkpoint_loader(
182185 return checkpoint_loader
183186
184187
def _apply_to_buffers_only(model: torch.nn.Module, fn):
    """Replace each registered buffer of *model* with ``fn(buffer)``.

    Walks ``model.modules()`` and, for every module visited, rewrites the
    entries of its ``_buffers`` mapping in place. Parameters are never
    touched, and buffer slots whose value is ``None`` are left as ``None``
    without calling *fn*.

    Args:
        model: Module whose buffers (and those of the modules yielded by
            ``model.modules()``) are transformed.
        fn: Callable that receives a buffer and returns its replacement.
    """
    for submodule in model.modules():
        registered = submodule._buffers
        # Only existing keys are reassigned, so the mapping never changes
        # size while it is being iterated.
        for name in registered:
            current = registered[name]
            if current is None:
                continue
            registered[name] = fn(current)
195+
196+
185197class ModelLoader :
186198 """
187199 Handles the loading, configuration, and weight initialization of a PyTorch model.
@@ -195,7 +207,9 @@ def __init__(self,
195207 sparse_attention_config : Optional ["SparseAttentionConfig" ],
196208 max_num_tokens : int ,
197209 max_seq_len : Optional [int ],
198- lora_config : Optional [LoraConfig ] = None ):
210+ lora_config : Optional [LoraConfig ] = None ,
211+ model_weights_memory_tag : Optional [ExecutorMemoryType ] = None ,
212+ model_weights_restore_mode : Optional [RestoreMode ] = None ):
199213 """
200214 Initializes the ModelLoader.
201215
@@ -206,6 +220,11 @@ def __init__(self,
206220 max_num_tokens: The maximum number of tokens the engine will handle.
207221 max_seq_len: The maximum sequence length.
208222 lora_config: Configuration for LoRA.
223+ model_weights_memory_tag: When set, parameter allocations during
224+ ``load()`` are placed under a separate virtual-memory tag so
225+ they can be released/materialized independently of buffers.
226+ model_weights_restore_mode: RestoreMode for the model weights
227+ virtual-memory scope.
209228 """
210229 self .llm_args = llm_args
211230 self .mapping = mapping
@@ -214,6 +233,9 @@ def __init__(self,
214233 self .max_num_tokens = max_num_tokens
215234 self .max_seq_len = max_seq_len
216235 self .lora_config = lora_config
236+ self .model_weights_memory_tag = model_weights_memory_tag
237+ self .model_weights_restore_mode = model_weights_restore_mode
238+ self ._weight_pool_proxy = None
217239
218240 @staticmethod
219241 def load_config_and_apply_defaults (
@@ -275,29 +297,81 @@ def load(
275297 config_copy = copy .deepcopy (config )
276298 with MetaInitMode ():
277299 model = AutoModelForCausalLM .from_config (config_copy )
300+ config = config_copy
301+ is_meta_init = True
302+ except Exception :
303+ logger .info (
304+ f"Fallback to regular model init: { traceback .format_exc (limit = 10 )} "
305+ )
306+ model = AutoModelForCausalLM .from_config (config )
307+ is_meta_init = False
308+
309+ memo = dict ()
310+
311+ if self .model_weights_memory_tag is not None :
312+ # Allocate buffers to the outer virtual_memory_scope,
313+ # but parameters (weights) to the dedicated inner virtual_memory_scope.
314+
315+ def allocate_buffer_on_cuda (t : torch .Tensor ):
316+ if t not in memo :
317+ if t .device == torch .device ('meta' ):
318+ cuda_t = torch .empty_like (t , device = 'cuda' )
319+ else :
320+ cuda_t = t .cuda ()
321+ memo [t ] = cuda_t
322+ memo [cuda_t ] = cuda_t
323+ return memo [t ]
278324
279- memo = dict ()
325+ _apply_to_buffers_only (model , allocate_buffer_on_cuda )
326+
327+ need_initialized_weights = load_format not in (LoadFormat .AUTO ,
328+ LoadFormat .DUMMY )
329+
330+ def allocate_weights_on_cuda (t : torch .Tensor ):
331+ if t not in memo :
332+ cuda_t = torch .empty_like (t , device = 'cuda' )
333+ if t .device != torch .device ('meta' ) and (
334+ need_initialized_weights or is_meta_init ):
335+ if t .is_cuda :
336+ memory_type_map = {
337+ ExecutorMemoryType .MODEL_WEIGHTS_MAIN :
338+ ExecutorMemoryType .MODEL_ENGINE_MAIN ,
339+ ExecutorMemoryType .MODEL_WEIGHTS_DRAFT :
340+ ExecutorMemoryType .MODEL_ENGINE_DRAFT ,
341+ }
342+
343+ warnings .warn (
344+ f"A weight tensor of shape { t .shape } is already allocated on CUDA device before "
345+ f"the weight allocation stage. This will cause extra CUDA memory usage in the "
346+ f"'{ memory_type_map [self .model_weights_memory_tag ]} ' scope."
347+ )
348+ cuda_t .copy_ (t )
349+ memo [t ] = cuda_t
350+ memo [cuda_t ] = cuda_t
351+ return memo [t ]
352+
353+ with virtual_memory_scope (
354+ self .model_weights_memory_tag ,
355+ self .model_weights_restore_mode ) as pool :
356+ model ._apply (allocate_weights_on_cuda )
357+ self ._weight_pool_proxy = pool
358+ elif is_meta_init :
280359
281360 def init_meta_tensor (t : torch .Tensor ):
282361 if t .device != torch .device ('meta' ):
283362 return t
363+
284364 if t not in memo :
285365 memo [t ] = torch .empty_like (t , device = 'cuda' )
286366 return memo [t ]
287367
288368 model ._apply (init_meta_tensor )
289- config = config_copy
290-
291- except Exception :
292- logger .info (
293- f"Fallback to regular model init: { traceback .format_exc (limit = 10 )} \n "
294- )
295- model = AutoModelForCausalLM .from_config (config )
296- finally :
297- if 'memo' in locals ():
298- del memo
299369
370+ # Ensure everything is at least on CUDA
371+ # No-op if worked as expected
300372 model .to ("cuda" )
373+ del memo
374+
301375 rank_model_storage = get_rank_model_storage (model )
302376 logger .info (
303377 f"Use { rank_model_storage / (1024 ** 3 ):.2f} GB for model weights."
0 commit comments