Skip to content

Commit d04b750

Browse files
committed
feat(torch): add GMS weight-loading prototype
1 parent d41949b commit d04b750

8 files changed

Lines changed: 826 additions & 33 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from .gpu_memory_backend import GMSBackend, GPUMemoryBackend
5+
6+
__all__ = ["GPUMemoryBackend", "GMSBackend"]
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from contextlib import contextmanager
5+
from typing import Iterator, Optional, Protocol, runtime_checkable
6+
7+
import torch
8+
from torch import nn
9+
10+
from tensorrt_llm.logger import logger
11+
from tensorrt_llm.mapping import Mapping
12+
13+
_MODE_ALIASES = ("rw", "ro", "auto")
14+
15+
16+
@runtime_checkable
17+
class GPUMemoryBackend(Protocol):
18+
def connect(self) -> bool:
19+
...
20+
21+
@property
22+
def is_rw(self) -> Optional[bool]:
23+
...
24+
25+
def has_committed_weights(self) -> bool:
26+
...
27+
28+
def mem_pool_scope(self, device: Optional[torch.device] = None) -> Iterator[None]:
29+
...
30+
31+
def materialize_module(self, model: nn.Module) -> None:
32+
...
33+
34+
def finalize_write(self, model: nn.Module) -> int:
35+
...
36+
37+
def move_untracked_params(self, model: nn.Module) -> None:
38+
...
39+
40+
def cleanup(self) -> None:
41+
...
42+
43+
44+
class GMSBackend:
    """Weight-loading backend backed by the external GPU Memory Service (GMS).

    A session is either RW (this process writes the weights into the service
    and then commits them, downgrading itself to RO) or RO (weights committed
    by a prior writer are materialized directly into the model). The granted
    mode is decided by the service at connect() time when ``mode="auto"``.

    All ``gpu_memory_service`` imports are deferred into the methods so that
    the package remains an optional dependency.
    """

    # Tag namespace under which weight allocations are registered with the service.
    DEFAULT_TAG = "weights"

    def __init__(
        self,
        socket_path: Optional[str],
        mapping: Mapping,
        mode: str = "auto",
        tag: str = DEFAULT_TAG,
    ) -> None:
        """Record connection parameters; no I/O happens until connect().

        Args:
            socket_path: UNIX socket of the GMS daemon, or None to derive it
                from the current CUDA device index and *tag*.
            mapping: TP/PP mapping of this rank (used for logging in
                materialize_module).
            mode: one of ``_MODE_ALIASES`` — "rw", "ro", or "auto".
            tag: allocation tag shared by writer and readers.

        Raises:
            ValueError: if *mode* is not a recognized alias.
        """
        if mode not in _MODE_ALIASES:
            raise ValueError(
                f"GMS mode must be one of {_MODE_ALIASES}, got {mode!r}")

        self._socket_path = socket_path
        self._mapping = mapping
        self._mode = mode
        self._tag = tag
        # Bound at construction time: the backend is pinned to whatever CUDA
        # device is current when the object is created.
        self._device_index = torch.cuda.current_device()
        # GMS client memory manager; populated by connect(), None when detached.
        self._client = None
        # None until connected; then True if the service granted an RW lock.
        self._is_rw: Optional[bool] = None

    def connect(self) -> bool:
        """Connect to the GMS daemon and record the granted lock type.

        Returns True on success. Returns False (never raises) when the
        ``gpu_memory_service`` package is missing or the connection attempt
        fails, so callers can fall back to another load path.
        """
        try:
            from gpu_memory_service.client.torch.allocator import (
                get_or_create_gms_client_memory_manager,
            )
            from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
            from gpu_memory_service.common.utils import get_socket_path
            from gpu_memory_service.integrations.common.patches import patch_empty_cache
        except ImportError:
            logger.warning(
                "gpu_memory_service is not installed; LoadFormat.GMS is unavailable.")
            return False

        # Translate our string mode into the service's lock-request enum;
        # "auto" lets the service pick RW for the first writer, RO afterwards.
        mode_map = {
            "rw": RequestedLockType.RW,
            "ro": RequestedLockType.RO,
            "auto": RequestedLockType.RW_OR_RO,
        }

        socket_path = self._socket_path
        if socket_path is None:
            # Derive the per-device, per-tag default socket and remember it
            # so later log messages report the resolved path.
            socket_path = get_socket_path(self._device_index, self._tag)
            self._socket_path = socket_path

        try:
            self._client = get_or_create_gms_client_memory_manager(
                socket_path,
                self._device_index,
                mode=mode_map[self._mode],
                tag=self._tag,
            )
        except Exception as e:
            logger.warning(
                "Failed to connect to GMS at %s (mode=%s, tag=%s): %s",
                socket_path,
                self._mode,
                self._tag,
                e,
            )
            self._client = None
            return False

        self._is_rw = self._client.granted_lock_type == GrantedLockType.RW
        # Best-effort: patch torch.cuda.empty_cache so it cooperates with the
        # GMS pool. Failure is tolerated — presumably only an optimization;
        # TODO(review): confirm against gpu_memory_service docs.
        try:
            patch_empty_cache()
        except Exception as e:
            logger.debug("GMS patch_empty_cache failed (non-fatal): %s", e)

        logger.info(
            "Connected to GMS at %s (mode=%s, granted=%s, tag=%s)",
            socket_path,
            self._mode,
            "RW" if self._is_rw else "RO",
            self._tag,
        )
        return True

    @property
    def is_rw(self) -> Optional[bool]:
        """True if this session holds the RW lock; None before connect()."""
        return self._is_rw

    def has_committed_weights(self) -> bool:
        """Return True when the service granted RO, i.e. a writer already committed.

        False when not connected; import/attribute errors are swallowed so this
        stays a safe capability probe.
        """
        if self._client is None:
            return False
        try:
            from gpu_memory_service.common.locks import GrantedLockType

            return self._client.granted_lock_type == GrantedLockType.RO
        except Exception:
            return False

    @contextmanager
    def mem_pool_scope(
        self,
        device: Optional[torch.device] = None,
    ) -> Iterator[None]:
        """Context manager routing CUDA allocations into the GMS memory pool.

        Only valid while the session still holds the RW lock (writing weights).

        Raises:
            RuntimeError: if not connected, or if the session is RO.
        """
        if self._client is None:
            raise RuntimeError("GMS client not connected. Call connect() first.")
        if self._is_rw is False:
            raise RuntimeError(
                "GMS mem_pool_scope() is only valid in RW mode (this client was granted RO)."
            )

        from gpu_memory_service.client.torch.allocator import gms_use_mem_pool

        target_device = device
        if target_device is None:
            target_device = torch.device("cuda", self._device_index)

        with gms_use_mem_pool(self._tag, target_device):
            yield

    def move_untracked_params(self, model: nn.Module) -> None:
        """Copy CUDA parameters that were allocated outside GMS into GMS mappings.

        Parameters whose data pointer already falls inside a GMS mapping are
        left alone; others get a fresh GMS mapping, a copy of their data, and
        their ``.data`` rebound to the GMS-backed tensor.

        NOTE(review): dedup is by storage pointer, so if two parameters share
        one storage only the first encountered is migrated — verify this is
        the intended aliasing behavior.

        Raises:
            RuntimeError: if not connected.
        """
        if self._client is None:
            raise RuntimeError("GMS client not connected. Call connect() first.")

        from gpu_memory_service.client.torch.module import _iter_module_tensors
        from gpu_memory_service.client.torch.tensor import _tensor_from_pointer

        gms_client = self._client
        # Storage base pointers already visited, to avoid re-migrating aliases.
        seen: set[int] = set()

        with torch.no_grad():
            for _name, tensor, tensor_type in _iter_module_tensors(model):
                # Only CUDA parameters are migrated; buffers/CPU tensors skipped.
                if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
                    continue

                storage_ptr = tensor.untyped_storage().data_ptr()
                if storage_ptr in seen:
                    continue
                seen.add(storage_ptr)

                # Already inside a GMS mapping — nothing to do.
                if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
                    continue

                # Allocate a same-sized mapping in GMS and view it with the
                # original shape/stride/dtype, then copy and rebind.
                nbytes = _storage_nbytes(tensor)
                base_va = gms_client.create_mapping(size=nbytes, tag=self._tag)
                replacement = _tensor_from_pointer(
                    int(base_va),
                    list(tensor.shape),
                    list(tensor.stride()),
                    tensor.dtype,
                    self._device_index,
                )
                replacement.copy_(tensor)
                tensor.data = replacement

    def finalize_write(self, model: nn.Module) -> int:
        """Register *model*'s tensors with GMS and commit, downgrading RW -> RO.

        Returns:
            Total bytes committed to the service.

        Raises:
            RuntimeError: if not connected or the session is not RW.
        """
        if self._client is None:
            raise RuntimeError("GMS client not connected. Call connect() first.")
        if self._is_rw is False:
            raise RuntimeError("GMS finalize_write() is only valid in RW mode.")

        from gpu_memory_service.client.torch.module import register_module_tensors
        from gpu_memory_service.integrations.common.utils import finalize_gms_write

        register_module_tensors(self._client, model)
        bytes_committed = int(self._client.total_bytes)
        # Ensure all pending weight copies landed before committing.
        torch.cuda.synchronize()
        finalize_gms_write(self._client)
        # After commit this session behaves as a reader.
        self._is_rw = False
        logger.info(
            "GMS RW->RO: committed %.2f GiB at %s (tag=%s)",
            bytes_committed / (1 << 30),
            self._socket_path,
            self._tag,
        )
        return bytes_committed

    def materialize_module(self, model: nn.Module) -> None:
        """Bind *model*'s parameters to weights previously committed in GMS.

        Also marks every ``Linear`` submodule as presharded, since the
        committed weights were already TP-sharded by the original writer.

        Raises:
            RuntimeError: if not connected.
        """
        if self._client is None:
            raise RuntimeError("GMS client not connected. Call connect() first.")

        from gpu_memory_service.client.torch.module import materialize_module_from_gms
        from tensorrt_llm._torch.modules.linear import Linear

        materialize_module_from_gms(
            self._client,
            model,
            device_index=self._device_index,
        )

        # Tell Linear.load_weights* helpers to skip TP re-sharding.
        for module in model.modules():
            if isinstance(module, Linear):
                module._weights_presharded = True

        logger.info(
            "GMS RO: materialized weights from %s (tag=%s, tp_rank=%d/%d, total_bytes=%.2f GiB)",
            self._socket_path,
            self._tag,
            self._mapping.tp_rank,
            self._mapping.tp_size,
            int(self._client.total_bytes) / (1 << 30),
        )

    def cleanup(self) -> None:
        """Disconnect from GMS, tolerating errors; idempotent when already detached."""
        if self._client is None:
            return

        try:
            from gpu_memory_service.client.torch.allocator import (
                evict_gms_client_memory_manager,
            )

            client = self._client
            # Close first, then evict from the process-wide manager cache;
            # a failed close must not prevent eviction.
            try:
                client.close()
            except Exception:
                pass
            evict_gms_client_memory_manager(client)
            logger.info("GMS: disconnected from %s", self._socket_path)
        except Exception as e:
            logger.warning("GMS cleanup error: %s", e)
        finally:
            # Always drop the reference so a later cleanup() is a no-op.
            self._client = None
261+
262+
263+
def _ptr_in_gms(gms_client, ptr: int) -> bool:
264+
mappings = getattr(gms_client, "mappings", None)
265+
if not mappings:
266+
mappings = getattr(gms_client, "_mappings", None)
267+
if not mappings:
268+
return False
269+
270+
for mapping in mappings.values():
271+
base = int(getattr(mapping, "va", 0))
272+
size = int(getattr(mapping, "aligned_size", getattr(mapping, "size", 0)))
273+
if base and size and base <= ptr < base + size:
274+
return True
275+
return False
276+
277+
278+
def _storage_nbytes(tensor: torch.Tensor) -> int:
279+
return int(tensor.untyped_storage().nbytes())

tensorrt_llm/_torch/modules/linear.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,11 @@ def load_weights_vanilla_helper(module: Linear,
183183
if module.bias is not None:
184184
assert "bias" in weights[0]
185185
device = torch.device('cuda')
186+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
187+
tp_rank = 0 if getattr(module, '_weights_presharded', False) else module.tp_rank
186188

187-
weight = load_weight_shard(weights[0]['weight'], module.tp_size,
188-
module.tp_rank, module.tp_mode,
189+
weight = load_weight_shard(weights[0]['weight'], tp_size,
190+
tp_rank, module.tp_mode,
189191
device) if "weight" in weights[0] else None
190192

191193
if weight is not None:
@@ -201,8 +203,8 @@ def load_weights_vanilla_helper(module: Linear,
201203
copy_weight(module.weight, weight_transform(weight))
202204

203205
if module.bias is not None:
204-
bias = load_weight_shard(weights[0]['bias'], module.tp_size,
205-
module.tp_rank, module.tp_mode,
206+
bias = load_weight_shard(weights[0]['bias'], tp_size,
207+
tp_rank, module.tp_mode,
206208
device) if "bias" in weights[0] else None
207209
if bias is not None:
208210
copy_weight(module.bias, bias_transform(bias))
@@ -224,26 +226,28 @@ def load_weights_fused_qkv_helper(
224226
module, "fused_weight_shard_indices_mapping", None
225227
) is not None, "Fused weight shard indices mapping is required in partial loading"
226228
device = torch.device('cuda')
229+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
230+
tp_rank = 0 if getattr(module, '_weights_presharded', False) else module.tp_rank
227231

228-
q_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
229-
module.tp_rank, module.tp_mode,
232+
q_weight = load_weight_shard(weights[0]['weight'], tp_size,
233+
tp_rank, module.tp_mode,
230234
device) if "weight" in weights[0] else None
231-
k_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
232-
module.tp_rank, module.tp_mode,
235+
k_weight = load_weight_shard(weights[1]['weight'], tp_size,
236+
tp_rank, module.tp_mode,
233237
device) if "weight" in weights[1] else None
234-
v_weight = load_weight_shard(weights[2]['weight'], module.tp_size,
235-
module.tp_rank, module.tp_mode,
238+
v_weight = load_weight_shard(weights[2]['weight'], tp_size,
239+
tp_rank, module.tp_mode,
236240
device) if "weight" in weights[2] else None
237241

238242
if module.bias is not None:
239-
q_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
240-
module.tp_rank, module.tp_mode,
243+
q_bias = load_weight_shard(weights[0]['bias'], tp_size,
244+
tp_rank, module.tp_mode,
241245
device) if "bias" in weights[0] else None
242-
k_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
243-
module.tp_rank, module.tp_mode,
246+
k_bias = load_weight_shard(weights[1]['bias'], tp_size,
247+
tp_rank, module.tp_mode,
244248
device) if "bias" in weights[1] else None
245-
v_bias = load_weight_shard(weights[2]['bias'], module.tp_size,
246-
module.tp_rank, module.tp_mode,
249+
v_bias = load_weight_shard(weights[2]['bias'], tp_size,
250+
tp_rank, module.tp_mode,
247251
device) if "bias" in weights[2] else None
248252
if not allow_partial_loading:
249253
copy_weight(module.bias,
@@ -277,19 +281,21 @@ def load_weights_fused_gate_up_helper(
277281
module, "fused_weight_shard_indices_mapping", None
278282
) is not None, "Fused weight shard indices mapping is required in partial loading"
279283
device = torch.device('cuda')
284+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
285+
tp_rank = 0 if getattr(module, '_weights_presharded', False) else module.tp_rank
280286

281-
gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
282-
module.tp_rank, module.tp_mode,
287+
gate_weight = load_weight_shard(weights[0]['weight'], tp_size,
288+
tp_rank, module.tp_mode,
283289
device) if "weight" in weights[0] else None
284-
up_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
285-
module.tp_rank, module.tp_mode,
290+
up_weight = load_weight_shard(weights[1]['weight'], tp_size,
291+
tp_rank, module.tp_mode,
286292
device) if "weight" in weights[1] else None
287293
if module.bias is not None:
288-
gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
289-
module.tp_rank, module.tp_mode,
294+
gate_bias = load_weight_shard(weights[0]['bias'], tp_size,
295+
tp_rank, module.tp_mode,
290296
device) if "bias" in weights[0] else None
291-
up_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
292-
module.tp_rank, module.tp_mode,
297+
up_bias = load_weight_shard(weights[1]['bias'], tp_size,
298+
tp_rank, module.tp_mode,
293299
device) if "bias" in weights[1] else None
294300
if not allow_partial_loading:
295301
copy_weight(module.bias,
@@ -2502,6 +2508,7 @@ def __init__(
25022508
self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
25032509
self.disable_deep_gemm = disable_deep_gemm
25042510
self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping
2511+
self._weights_presharded = False
25052512

25062513
# Store NVFP4 GEMM allowed backends configuration
25072514
# Read from model_extra_attrs if not explicitly provided (allows config via llm_api_options)

0 commit comments

Comments
 (0)