Skip to content

Commit fa31ce0

Browse files
[TRTLLM-11366][feat] Add dedicated virtual memory tag for model weights, configurable restore mode (NVIDIA#11889)
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
1 parent e03e361 commit fa31ce0

File tree

15 files changed

+577
-143
lines changed

15 files changed

+577
-143
lines changed

cpp/include/tensorrt_llm/runtime/virtualMemory.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ class CudaVirtualMemoryAllocator
473473
bool mBackground{};
474474

475475
friend class CudaVirtualMemoryAllocator;
476-
friend void setVirtualMemoryAllocator(
476+
friend void pushVirtualMemoryAllocator(
477477
std::string const& tag, RestoreMode mode, std::shared_ptr<CudaStream> backStream);
478478

479479
public:
@@ -566,8 +566,8 @@ namespace tensorrt_llm::runtime
566566
{
567567
CudaVirtualMemoryManager& getVirtualMemoryManager();
568568
CudaVirtualMemoryAllocator getVirtualMemoryAllocator();
569-
void setVirtualMemoryAllocator(
569+
void pushVirtualMemoryAllocator(
570570
std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr<CudaStream> backStream);
571-
void clearVirtualMemoryAllocator();
571+
void popVirtualMemoryAllocator();
572572

573573
} // namespace tensorrt_llm::runtime

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -343,20 +343,18 @@ void initBindings(nb::module_& m)
343343
nb::rv_policy::reference);
344344

345345
m.def(
346-
"set_virtual_memory_allocator",
346+
"push_virtual_memory_allocator",
347347
[](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream)
348348
{
349349
static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t));
350-
tr::setVirtualMemoryAllocator(tag, mode,
350+
tr::pushVirtualMemoryAllocator(tag, mode,
351351
std::make_shared<tr::CudaStream>(
352352
reinterpret_cast<cudaStream_t>(stream), tensorrt_llm::common::getDevice(), false));
353353
},
354-
"Set the virtual memory allocator and start allocating virtual memory for CUDA allocations",
355-
nb::call_guard<nb::gil_scoped_release>());
354+
"Push a virtual memory allocator onto the allocator stack.", nb::call_guard<nb::gil_scoped_release>());
356355

357-
m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator,
358-
"Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations",
359-
nb::call_guard<nb::gil_scoped_release>());
356+
m.def("pop_virtual_memory_allocator", &tr::popVirtualMemoryAllocator,
357+
"Pop the top virtual memory allocator from the allocator stack", nb::call_guard<nb::gil_scoped_release>());
360358

361359
nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
362360
.def(nb::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), nb::arg("buf_size"),

cpp/tensorrt_llm/runtime/virtualMemory.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -402,32 +402,30 @@ using AllocConf = CudaVirtualMemoryAllocator::Configuration;
402402

403403
AllocConf AllocConf::backgroundConfiguration{getVirtualMemoryManager(), "", NONE, nullptr, true};
404404

405-
static const std::shared_ptr<AllocConf> bgConf{std::shared_ptr<AllocConf>{}, &AllocConf::backgroundConfiguration};
406-
407-
static std::shared_mutex currentConfMutex;
408-
static std::shared_ptr<AllocConf> currentConf = bgConf;
405+
static std::shared_mutex sConfMutex;
406+
static std::shared_ptr<AllocConf> sCurrentConf{std::shared_ptr<AllocConf>{}, &AllocConf::backgroundConfiguration};
407+
static std::vector<std::shared_ptr<AllocConf>> sConfStack;
409408

410409
CudaVirtualMemoryAllocator getVirtualMemoryAllocator()
411410
{
412-
std::shared_lock lock(currentConfMutex);
413-
return CudaVirtualMemoryAllocator{currentConf};
411+
std::shared_lock lock(sConfMutex);
412+
return CudaVirtualMemoryAllocator{sCurrentConf};
414413
}
415414

416-
void setVirtualMemoryAllocator(
415+
void pushVirtualMemoryAllocator(
417416
std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr<CudaStream> backStream)
418417
{
419-
std::unique_lock lock(currentConfMutex);
420-
421-
TLLM_CHECK_WITH_INFO(currentConf == bgConf,
422-
"An active virtual memory allocator (tag: %s, mode: %d, stream: %p) is already present",
423-
currentConf->mTag.c_str(), currentConf->mMode, currentConf->mBackStream.get());
424-
currentConf = std::make_shared<AllocConf>(getVirtualMemoryManager(), tag, mode, backStream);
418+
std::unique_lock lock(sConfMutex);
419+
sCurrentConf.swap(
420+
sConfStack.emplace_back(std::make_shared<AllocConf>(getVirtualMemoryManager(), tag, mode, backStream)));
425421
}
426422

427-
void clearVirtualMemoryAllocator()
423+
void popVirtualMemoryAllocator()
428424
{
429-
std::unique_lock lock(currentConfMutex);
430-
currentConf = bgConf;
425+
std::unique_lock lock(sConfMutex);
426+
TLLM_CHECK_WITH_INFO(!sConfStack.empty(), "popVirtualMemoryAllocator called with empty stack");
427+
sCurrentConf.swap(sConfStack.back());
428+
sConfStack.pop_back();
431429
}
432430

433431
} // namespace tensorrt_llm::runtime

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def __init__(
146146
torch.nn.Module]] = None,
147147
model: Optional[torch.nn.Module] = None,
148148
checkpoint_loader: Optional[BaseCheckpointLoader] = None,
149+
model_weights_memory_tag: Optional[str] = None,
150+
model_weights_restore_mode=None,
149151
):
150152
self.forward_pass_callable = None
151153
self.ub_buffers = None
@@ -212,6 +214,8 @@ def __init__(
212214
max_num_tokens=self.max_num_tokens,
213215
max_seq_len=self.max_seq_len,
214216
lora_config=lora_config,
217+
model_weights_memory_tag=model_weights_memory_tag,
218+
model_weights_restore_mode=model_weights_restore_mode,
215219
)
216220
self.model, moe_load_balancer = self.model_loader.load(
217221
checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader)

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import inspect
33
import os
44
import traceback
5+
import warnings
56
from typing import Callable, Optional, Tuple
67

78
import torch
89

910
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import (
1011
AutoCheckpointMapper, BaseCheckpointLoader)
1112
from tensorrt_llm._utils import str_dtype_to_torch
12-
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
13+
from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType, TorchLlmArgs
1314
from tensorrt_llm.llmapi.llm_utils import apply_model_defaults_to_llm_args
1415
from tensorrt_llm.logger import logger
1516
from tensorrt_llm.lora_helper import LoraConfig
@@ -25,6 +26,8 @@
2526
timing)
2627
from ..modules.fused_moe.moe_load_balancer import (
2728
MoeLoadBalancer, maybe_create_moe_load_balancer)
29+
from ..virtual_memory import RestoreMode
30+
from ..virtual_memory import scope as virtual_memory_scope
2831

2932
_KV_CACHE_MAP = {
3033
"fp8": QuantAlgo.FP8.value,
@@ -182,6 +185,15 @@ def _construct_checkpoint_loader(
182185
return checkpoint_loader
183186

184187

188+
def _apply_to_buffers_only(model: torch.nn.Module, fn):
189+
"""Apply *fn* to every buffer in *model*, skipping parameters.
190+
"""
191+
for module in model.modules():
192+
for key, buf in module._buffers.items():
193+
if buf is not None:
194+
module._buffers[key] = fn(buf)
195+
196+
185197
class ModelLoader:
186198
"""
187199
Handles the loading, configuration, and weight initialization of a PyTorch model.
@@ -195,7 +207,9 @@ def __init__(self,
195207
sparse_attention_config: Optional["SparseAttentionConfig"],
196208
max_num_tokens: int,
197209
max_seq_len: Optional[int],
198-
lora_config: Optional[LoraConfig] = None):
210+
lora_config: Optional[LoraConfig] = None,
211+
model_weights_memory_tag: Optional[ExecutorMemoryType] = None,
212+
model_weights_restore_mode: Optional[RestoreMode] = None):
199213
"""
200214
Initializes the ModelLoader.
201215
@@ -206,6 +220,11 @@ def __init__(self,
206220
max_num_tokens: The maximum number of tokens the engine will handle.
207221
max_seq_len: The maximum sequence length.
208222
lora_config: Configuration for LoRA.
223+
model_weights_memory_tag: When set, parameter allocations during
224+
``load()`` are placed under a separate virtual-memory tag so
225+
they can be released/materialized independently of buffers.
226+
model_weights_restore_mode: RestoreMode for the model weights
227+
virtual-memory scope.
209228
"""
210229
self.llm_args = llm_args
211230
self.mapping = mapping
@@ -214,6 +233,9 @@ def __init__(self,
214233
self.max_num_tokens = max_num_tokens
215234
self.max_seq_len = max_seq_len
216235
self.lora_config = lora_config
236+
self.model_weights_memory_tag = model_weights_memory_tag
237+
self.model_weights_restore_mode = model_weights_restore_mode
238+
self._weight_pool_proxy = None
217239

218240
@staticmethod
219241
def load_config_and_apply_defaults(
@@ -275,29 +297,81 @@ def load(
275297
config_copy = copy.deepcopy(config)
276298
with MetaInitMode():
277299
model = AutoModelForCausalLM.from_config(config_copy)
300+
config = config_copy
301+
is_meta_init = True
302+
except Exception:
303+
logger.info(
304+
f"Fallback to regular model init: {traceback.format_exc(limit=10)}"
305+
)
306+
model = AutoModelForCausalLM.from_config(config)
307+
is_meta_init = False
308+
309+
memo = dict()
310+
311+
if self.model_weights_memory_tag is not None:
312+
# Allocate buffers to the outer virtual_memory_scope,
313+
# but parameters (weights) to the dedicated inner virtual_memory_scope.
314+
315+
def allocate_buffer_on_cuda(t: torch.Tensor):
316+
if t not in memo:
317+
if t.device == torch.device('meta'):
318+
cuda_t = torch.empty_like(t, device='cuda')
319+
else:
320+
cuda_t = t.cuda()
321+
memo[t] = cuda_t
322+
memo[cuda_t] = cuda_t
323+
return memo[t]
278324

279-
memo = dict()
325+
_apply_to_buffers_only(model, allocate_buffer_on_cuda)
326+
327+
need_initialized_weights = load_format not in (LoadFormat.AUTO,
328+
LoadFormat.DUMMY)
329+
330+
def allocate_weights_on_cuda(t: torch.Tensor):
331+
if t not in memo:
332+
cuda_t = torch.empty_like(t, device='cuda')
333+
if t.device != torch.device('meta') and (
334+
need_initialized_weights or is_meta_init):
335+
if t.is_cuda:
336+
memory_type_map = {
337+
ExecutorMemoryType.MODEL_WEIGHTS_MAIN:
338+
ExecutorMemoryType.MODEL_ENGINE_MAIN,
339+
ExecutorMemoryType.MODEL_WEIGHTS_DRAFT:
340+
ExecutorMemoryType.MODEL_ENGINE_DRAFT,
341+
}
342+
343+
warnings.warn(
344+
f"A weight tensor of shape {t.shape} is already allocated on CUDA device before "
345+
f"the weight allocation stage. This will cause extra CUDA memory usage in the "
346+
f"'{memory_type_map[self.model_weights_memory_tag]}' scope."
347+
)
348+
cuda_t.copy_(t)
349+
memo[t] = cuda_t
350+
memo[cuda_t] = cuda_t
351+
return memo[t]
352+
353+
with virtual_memory_scope(
354+
self.model_weights_memory_tag,
355+
self.model_weights_restore_mode) as pool:
356+
model._apply(allocate_weights_on_cuda)
357+
self._weight_pool_proxy = pool
358+
elif is_meta_init:
280359

281360
def init_meta_tensor(t: torch.Tensor):
282361
if t.device != torch.device('meta'):
283362
return t
363+
284364
if t not in memo:
285365
memo[t] = torch.empty_like(t, device='cuda')
286366
return memo[t]
287367

288368
model._apply(init_meta_tensor)
289-
config = config_copy
290-
291-
except Exception:
292-
logger.info(
293-
f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
294-
)
295-
model = AutoModelForCausalLM.from_config(config)
296-
finally:
297-
if 'memo' in locals():
298-
del memo
299369

370+
# Ensure everything is at least on CUDA
371+
# No-op if worked as expected
300372
model.to("cuda")
373+
del memo
374+
301375
rank_model_storage = get_rank_model_storage(model)
302376
logger.info(
303377
f"Use {rank_model_storage / (1024**3):.2f} GB for model weights."

0 commit comments

Comments
 (0)