
Commit 7cd32b2

[TRTLLM-11851][feat] Add MX-only P2P checkpoint loading support
Introduce the first PR slice from the MX/GMS prototype: checkpoint_format="MX" support using the upstream modelexpress MxLiveWeightLoader and publish_model_params, while intentionally excluding GMS/load_format changes.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
1 parent be1f6f5 commit 7cd32b2

19 files changed: 1819 additions & 20 deletions
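
Before the per-file diffs, a minimal usage sketch of the new code path. This is hypothetical and not part of the commit: it assumes the PyTorch-backend LLM API exposes a checkpoint_format argument that reaches the loader registry touched below, and that modelexpress has been installed manually (see the setup.py note).

# Hypothetical usage sketch (not part of this commit). Assumes the LLM API
# accepts checkpoint_format and routes "MX" through the registry entries
# added below; requires: pip install "modelexpress>=0.3.0,<0.4.0"
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/hf/checkpoint",  # HF on-disk layout; MX falls back to it
    checkpoint_format="MX",          # resolves to the "MX" registry entries
)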


setup.py

Lines changed: 15 additions & 0 deletions
@@ -427,6 +427,21 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
     scripts=['tensorrt_llm/llmapi/trtllm-llmapi-launch'],
     extras_require={
         "devel": devel_deps,
+        # NOTE: The MX (modelexpress) Python package used by
+        # tensorrt_llm._torch.models.checkpoints.mx is intentionally NOT
+        # declared as an ``[mx]`` extra while this integration is at
+        # prototype status. ``modelexpress`` v0.3.0 is on PyPI
+        # (Apache-2.0) but is brand-new (Beta status, single release) and
+        # still needs onboarding into NVIDIA's OSS package allowlist.
+        #
+        # Until allowlisting is complete, users who want to exercise the
+        # MX code path install the dependency manually:
+        #
+        #     pip install "modelexpress>=0.3.0,<0.4.0"
+        #
+        # Restoring one-line ``pip install tensorrt_llm[mx]`` ergonomics
+        # is a single revert of this hunk once the OSS-allowlist step is
+        # complete (tracked in §15 of the design doc as MX-7).
     },
     zip_safe=True,
     install_requires=required_deps,
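
Because the dependency is deliberately not a declared extra, any import site has to tolerate its absence. A sketch of the conventional guard pattern follows; it is illustrative only and not taken from this commit:

# Illustrative import guard (not this PR's code): the usual pattern for an
# optional dependency that is not declared as a packaging extra.
try:
    import modelexpress  # installed manually, per the setup.py note above
except ImportError as exc:
    raise ImportError(
        'checkpoint_format="MX" requires modelexpress; install it with: '
        'pip install "modelexpress>=0.3.0,<0.4.0"') from exc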

tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from .mistral.config_loader import MistralConfigLoader
 from .mistral.weight_mapper import (MistralLarge3WeightMapper,
                                     MistralWeightMapper)
+from .mx.checkpoint_loader import MXCheckpointLoader

 __all__ = [
     "HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "MistralConfigLoader",
@@ -28,5 +29,5 @@
     "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
     "Qwen3_5MoeHfWeightMapper", "Qwen3NextHfWeightMapper",
     "LlavaNextHfWeightMapper", "MistralLarge3CheckpointLoader",
-    "MistralLarge3WeightMapper", "Qwen3VLHfWeightMapper"
+    "MistralLarge3WeightMapper", "MXCheckpointLoader", "Qwen3VLHfWeightMapper"
 ]

tensorrt_llm/_torch/models/checkpoints/auto_mapper.py

Lines changed: 8 additions & 0 deletions
@@ -11,6 +11,14 @@ def get(format: str, name: Optional[str] = None) -> "BaseWeightMapper":
             try:
                 return MODEL_CLASS_MAPPER_MAPPING[f'{name}_{format}']()
             except KeyError:  # no mapper for this model architecture, resort to default
+                if format == "MX":
+                    # MX uses HF on-disk checkpoint format for fallback, so
+                    # an architecture-specific HF mapper is closer than the
+                    # generic MX/HF default mapper.
+                    try:
+                        return MODEL_CLASS_MAPPER_MAPPING[f'{name}_HF']()
+                    except KeyError:
+                        pass
                 # TODO smor- a potential bug here, if the class isn't added to __init__, it will return the default mapper
                 return MODEL_CLASS_MAPPER_MAPPING[format]()
         else:
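
The effective mapper lookup order for MX therefore becomes {name}_MX, then {name}_HF, then the generic "MX" default. A self-contained toy replication of that order, with REGISTRY and resolve() standing in for MODEL_CLASS_MAPPER_MAPPING and the get() shown in the hunk:

# Toy replication of the fallback order implemented above; the names here
# are illustrative stand-ins, not the real registry or get().
REGISTRY = {"LlamaForCausalLM_HF": "llama-hf-mapper", "MX": "mx-default-mapper"}


def resolve(name: str, format: str) -> str:
    candidates = [f"{name}_{format}"]
    if format == "MX":
        candidates.append(f"{name}_HF")  # arch-specific HF beats the default
    candidates.append(format)
    for key in candidates:
        if key in REGISTRY:
            return REGISTRY[key]
    raise KeyError(f"no mapper for {name}/{format}")


assert resolve("LlamaForCausalLM", "MX") == "llama-hf-mapper"
assert resolve("UnknownArch", "MX") == "mx-default-mapper"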

tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py

Lines changed: 7 additions & 3 deletions
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
 import threading
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Iterator, Tuple, Union
@@ -91,15 +94,16 @@ def mark_consumed(self, prefix: str) -> int:
 class BaseWeightLoader(ABC):

     @abstractmethod
-    def load_weights(
-            self, checkpoint_dir: str,
-            mapping: Mapping) -> Union[Dict[str, Any], ConsumableWeightsDict]:
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
+                     **kwargs) -> Union[Dict[str, Any], ConsumableWeightsDict]:
         """
         Loads weights from a checkpoint directory.

         Args:
             checkpoint_dir: A path to the checkpoint directory.
             mapping: A mapping object containing the distributed configuration.
+            **kwargs: Optional format-specific loader arguments. Generic HF
+                loaders ignore these; MX uses ``model`` for direct P2P writes.

         Returns:
             A dictionary (or ConsumableWeightsDict) where keys are tensor names
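
The widened signature lets format-specific loaders receive extra arguments without breaking existing HF loaders. A hypothetical subclass sketch of the contract the docstring describes (this is not the PR's MXCheckpointLoader, whose implementation is not shown here):

# Hypothetical subclass sketch, not this PR's MXCheckpointLoader: shows the
# widened contract in which a format-specific loader pulls the destination
# model out of **kwargs for direct P2P writes.
from typing import Any, Dict

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    BaseWeightLoader
from tensorrt_llm.mapping import Mapping


class P2PWeightLoader(BaseWeightLoader):

    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
                     **kwargs) -> Dict[str, Any]:
        model = kwargs.get("model")  # destination module for in-place writes
        if model is None:
            raise ValueError("P2P loading requires `model` in kwargs")
        # A real loader would stream published tensors straight into the
        # model's parameters here and return only the leftover weights.
        return {}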

tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py

Lines changed: 4 additions & 0 deletions
@@ -1,9 +1,13 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.base_config_loader import \
     BaseConfigLoader
 from tensorrt_llm._torch.models.modeling_utils import register_config_loader


+@register_config_loader("MX")
 @register_config_loader("HF")
 class HfConfigLoader(BaseConfigLoader):


tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

Lines changed: 3 additions & 2 deletions
@@ -34,6 +34,7 @@
 from tensorrt_llm.mapping import Mapping


+@register_checkpoint_weight_loader("MX")
 @register_checkpoint_weight_loader("mistral")
 @register_checkpoint_weight_loader("HF")
 class HfWeightLoader(BaseWeightLoader):
@@ -59,8 +60,8 @@ def _get_local_available_host_memory() -> int:
                                 op=_MPI.MIN)
         return available_host_memory

-    def load_weights(self, checkpoint_dir: str,
-                     mapping: Mapping) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
+                     **kwargs) -> dict[str, Any]:
         weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
         # Some model checkpoint directories contain not only the sharded safetensors, but one
         # consolidated tensor. In the presence of both, we favor the former, as there really is no need
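
Stacking the registration decorators is what makes "MX" and "HF" resolve to the same loader class here. A toy sketch of the decorator-registry pattern these hunks rely on (LOADERS and register_loader are illustrative stand-ins; the real registries live in modeling_utils):

# Toy sketch of the decorator-registry pattern used above.
LOADERS = {}


def register_loader(fmt):
    def wrap(cls):
        LOADERS[fmt] = cls
        return cls
    return wrap


@register_loader("MX")  # MX reuses the HF loader for its on-disk fallback
@register_loader("HF")
class HfLikeLoader:
    pass


assert LOADERS["MX"] is LOADERS["HF"] is HfLikeLoader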

tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from ..base_weight_mapper import BaseWeightMapper


+@register_mapper("MX")
 @register_mapper("HF")
 class HfWeightMapper(BaseWeightMapper):

tensorrt_llm/_torch/models/checkpoints/mx/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .checkpoint_loader import MXCheckpointLoader
+
+__all__ = ["MXCheckpointLoader"]
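
Combined with the checkpoints/__init__.py re-export earlier in the commit, the loader is importable from either package level. A quick check, assuming modelexpress is installed so the module imports cleanly:

# Both import paths resolve to the same class, per the re-exports above
# (assumes modelexpress is installed so checkpoint_loader imports cleanly).
from tensorrt_llm._torch.models.checkpoints import MXCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.mx import \
    MXCheckpointLoader as MXCheckpointLoaderViaMx

assert MXCheckpointLoader is MXCheckpointLoaderViaMx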
