[TRTLLM-11851][feat] Add MX-only P2P checkpoint loading support for TRTLLM #13531
chienchunhung wants to merge 1 commit into NVIDIA:main
Conversation
Force-pushed 05ce987 to d6f0384 (Compare)
/bot run --disable-fail-fast
PR_Github #45846 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This PR introduces MX (ModelExpress) peer-to-peer weight transfer support for checkpoint loading.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client/ModelLoader
    participant Loader as MXCheckpointLoader
    participant MX as ModelExpress<br/>(P2P)
    participant HF as HuggingFace<br/>(Disk)
    participant Model as Model Instance
    Client->>Loader: load_weights(checkpoint_dir,<br/>mapping, model=...)
    alt MX Server & Model Reference Available
        Loader->>MX: MxLiveWeightLoader.transfer_weights()
        alt Transfer Success
            MX-->>Loader: weights_dict (empty or partial)
            Loader->>Model: Direct parameter writes<br/>(P2P succeeded = true)
            alt Partial Transfer (non-empty dict)
                Loader->>HF: Full disk load fallback<br/>(P2P succeeded = false)
                HF-->>Loader: complete weights
                Loader-->>Client: merged weights
            else Complete Transfer (empty dict)
                Loader-->>Client: P2P weights only
            end
        else Transfer Fails
            MX--xLoader: Exception
            Loader->>HF: Fallback to disk load<br/>(P2P succeeded = false)
            HF-->>Loader: weights from disk
            Loader-->>Client: disk weights
        end
    else Missing Config or modelexpress
        Loader->>HF: Fallback to disk load<br/>(P2P succeeded = false)
        HF-->>Loader: weights from disk
        Loader-->>Client: disk weights
    end
```
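The branching in the walkthrough can be condensed into a small, runnable Python sketch. `transfer_weights` and `load_from_disk` are hypothetical stand-ins for the `MxLiveWeightLoader` call and the inherited HuggingFace disk loader; the merge semantics for a partial transfer follow one plausible reading of the diagram, not the exact implementation.

```python
# Hypothetical stand-in for the MX-first loading strategy sketched above.
# `transfer_weights` models the P2P transfer and `load_from_disk` models
# the HF disk loader; both names are assumptions for illustration.

def load_checkpoint(transfer_weights, load_from_disk):
    """Return (weights, p2p_succeeded) following the MX-first strategy."""
    try:
        # P2P writes most parameters directly into the model; the returned
        # dict holds only tensors that still need the standard pipeline.
        fallback = transfer_weights()
    except Exception:
        # Transfer failed entirely: fall back to a full disk load.
        return load_from_disk(), False

    if not fallback:
        # Empty dict: every tensor arrived via P2P, skip weight mapping.
        return {}, True

    # Partial transfer: do a full disk load, then let the returned fallback
    # tensors take precedence (one plausible reading of "merged weights").
    merged = {**load_from_disk(), **fallback}
    return merged, False


disk = lambda: {"a": 1, "b": 2}
full = load_checkpoint(lambda: {}, disk)           # full P2P success
partial = load_checkpoint(lambda: {"b": 9}, disk)  # partial fallback

def failing_transfer():
    raise RuntimeError("no MX source found")

failed = load_checkpoint(failing_transfer, disk)   # disk-only fallback
print(full, partial, failed)
```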
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/linear.py (1)
1-1: ⚠️ Potential issue | 🟠 Major: Add required NVIDIA copyright/SPDX header to this modified Python source file.

This file was modified but still lacks the required header block at the top.

Proposed fix:

```diff
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
 from __future__ import annotations
```

As per coding guidelines, "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification" and "Include NVIDIA copyright header on all new files; update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/linear.py` at line 1, add the required NVIDIA copyright/SPDX header block at the very top of tensorrt_llm/_torch/modules/linear.py (before the existing "from __future__ import annotations" line); the header must include the NVIDIA copyright line with the year of latest meaningful modification and the SPDX-License-Identifier (e.g., SPDX-License-Identifier: Apache-2.0) as used across the repo so the file complies with project coding guidelines.

setup.py (1)
1-1: ⚠️ Potential issue | 🟠 Major: Update SPDX copyright year for this modified file.

`setup.py` was changed in 2026, but the header still ends at 2025.

🔧 Proposed fix:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification" and "update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@setup.py` at line 1, Update the SPDX copyright header line that currently reads "SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved." to reflect the latest modification year (2026); locate the header by the unique string "SPDX-FileCopyrightText" in setup.py and change the year range to "2022-2026" (or to a single year "2026" if preferred by project convention) so the file header matches the most recent modification.
🧹 Nitpick comments (1)
tests/unittest/_torch/pyexecutor/test_model_loader_mx.py (1)
97-157: Add regressions for MX success + `reload()` and non-default preshard strategy.

These tests only exercise the `"per_module"` happy path. The production branch also owns `self.weight_mapper` setup and strategy-specific skip logic, so a pure MX load followed by `reload()` or `mx_preshard_strategy="global"` can regress without this suite noticing. QA list updates look unnecessary here because this is unit-only coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/pyexecutor/test_model_loader_mx.py` around lines 97 - 157, Tests only cover the "per_module" MX preshard path and miss regressions for a subsequent reload() call and for the "global" mx_preshard_strategy; extend unit tests (e.g., add cases alongside test_mx_success_marks_main_linears_and_skips_weight_mapping and test_mx_fallback_runs_standard_weight_mapping) to simulate: (1) calling loader.reload(...) after a successful MX load to ensure loader.weight_mapper and skip logic still behave, and (2) running loader.load with mx_preshard_strategy="global" (or by configuring loader.weight_mapper to use global strategy) to assert preshard marking/skipping behaves as expected for main modules vs draft_model modules; reuse _make_loader and checkpoint_loader mocks (set checkpoint_loader.p2p_succeeded True/False and checkpoint_loader.load_weights return values) and assert loader._call_load_weights counts, model.*_weights_presharded flags, and event order just like the existing tests.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 421-445: When mx_p2p_succeeded is true the code marks Linear
modules presharded but never initializes self.weight_mapper nor validates
mx_preshard_strategy, yet reload() later expects self.weight_mapper; update the
mx_p2p_succeeded branch to always set self.weight_mapper via
checkpoint_loader.get_initialized_weight_mapper(model, config) (same as the
non-fast path) and validate config.mx_preshard_strategy (e.g., raise or handle
if it's not "per_module") before marking modules so non-"per_module" strategies
fail fast; keep using model.load_weights with self.weight_mapper so reload() can
safely consume it.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b240e53f-4bf3-42cf-965b-c5c28cfd7c1f
📒 Files selected for processing (19)
- setup.py
- tensorrt_llm/_torch/models/checkpoints/__init__.py
- tensorrt_llm/_torch/models/checkpoints/auto_mapper.py
- tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
- tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
- tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/executor/base_worker.py
- tensorrt_llm/llmapi/llm_args.py
- tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
- tests/unittest/_torch/pyexecutor/test_model_loader_mx.py
- tests/unittest/api_stability/references/llm.yaml
- tests/unittest/llmapi/test_mx_args.py
PR_Github #45846 [ run ] completed with state
Force-pushed d6f0384 to 4b00f9d (Compare)
PR LGTM - lack approval privileges
tburt-nv left a comment:
No problem with the setup.py comments
Force-pushed 4b00f9d to 7cd32b2 (Compare)
brb-nv left a comment:
Minor comments. Changes LGTM.
Introduce the first PR slice from the MX/GMS prototype: `checkpoint_format="MX"` support using upstream modelexpress `MxLiveWeightLoader` and `publish_model_params`, while intentionally excluding GMS/load_format changes.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
Force-pushed 7cd32b2 to 49faefc (Compare)
venkywonka left a comment:
This PR seems to have no documentation update for this awesome new feature addition. If that is aimed at a follow-up PR then no worries, but if not, I'd recommend adding some docs. Here are some places codex suggested:

- `docs/source/features/checkpoint-loading.md`
- Add a dedicated `features/model-express-p2p-checkpoint-loading.md`, wire it into `docs/source/index.rst`
- Add small pointers/examples in `overview.md`, `trtllm-serve.rst`, and optionally `quickstart_advanced.py`
Also, if you would like this to be tracked in telemetry as a feature usage, you might want to update
Summary by CodeRabbit

Release Notes

- `mx_server_url` for MX server endpoint and `mx_preshard_strategy` to control weight sharding behavior.

Description
Summary
This PR is the MX-only first slice split out from PR #13045.
It adds `checkpoint_format="MX"` support to TRT-LLM's PyTorch backend using upstream `modelexpress.trtllm_live_transfer.MxLiveWeightLoader` and `publish_model_params`. GMS is intentionally excluded so reviewers can validate MX first.

Follow-up PRs will add:

- GMS (`LoadFormat.GMS`, `GMSBackend`, GMS args/tests)

What This PR Adds
MX Checkpoint Loader
Adds `MXCheckpointLoader` under `tensorrt_llm/_torch/models/checkpoints/mx/`.

Behavior:

- Extends `HfCheckpointLoader`, so HF disk fallback is inherited.
- Calls `MxLiveWeightLoader(mx_server=url).load_weights(checkpoint_dir, mapping=..., model=...)`.
- Calls `publish_model_params(model)` before `post_load_weights()` for source workers.
- Exposes `p2p_succeeded` so `ModelLoader` can skip normal weight mapping on full P2P success.
- On partial `fallback_weights`, keep P2P-delivered tensors and merge only the returned fallback tensors through the standard disk pipeline.

MX Config
Adds MX-only prototype fields:
- `mx_server_url`
- `mx_server_query_timeout_s`
- `mx_preshard_strategy`

Behavior:

- `MODEL_EXPRESS_URL` is used as fallback for `mx_server_url` when `checkpoint_format="MX"`.
- `mx_server_query_timeout_s` lets deployments size source-discovery wait time.
- `MX_SOURCE_QUERY_TIMEOUT=30` for fast disk fallback.
- `mx_preshard_strategy="global"` fails fast until `LoadFormat.PRESHARDED` exists upstream.

ModelLoader Integration
Updates only the existing `LoadFormat.AUTO` path:

- Passes `model=model` to checkpoint loaders. Generic HF loaders ignore it; MX uses it for direct P2P writes.
- Always initializes `self.weight_mapper`, including the MX fast path, so `reload()` remains safe.
- Marks `Linear` modules as `_weights_presharded=True` after MX success.
- Routes `fallback_weights` through the standard weight-loading path for partial MX fallback.

Linear Marker
Adds `_weights_presharded = False` to `Linear`.

This PR does not route presharded tensors back through `load_weight_shard()` helpers; the marker is set after successful MX direct writes and is kept for the current ModelLoader path plus future `LoadFormat.PRESHARDED` work.

What This PR Excludes
This PR intentionally does not include:

- `LoadFormat.GMS`
- `GMSBackend`
- `gms_socket_path`, `gms_mode`, `gms_tag`
- `[gms]` / `[dynamo]` packaging extras

Packaging Note
This PR does not add a `[mx]` extra yet.

For prototype testing:

```
pip install "modelexpress>=0.3.0,<0.4.0"
```

`modelexpress` is on PyPI but still needs NVIDIA OSS allowlist onboarding (tracked as MX-7). Once complete, restoring `pip install tensorrt_llm[mx]` is a small `setup.py` change.

Running MX
Optional timeout override:
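The override snippet itself was elided in this rendering; it presumably amounts to exporting the environment variables this PR describes. A hedged sketch, with a placeholder endpoint value:

```shell
# MODEL_EXPRESS_URL and MX_SOURCE_QUERY_TIMEOUT are the variable names
# described in this PR; the endpoint value is a placeholder.
export MODEL_EXPRESS_URL="mx-server:9000"
export MX_SOURCE_QUERY_TIMEOUT=30   # give up on source discovery after 30 s
echo "timeout=${MX_SOURCE_QUERY_TIMEOUT}"
```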
Python API:
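The Python snippet was also elided; below is a hedged sketch of what it likely looks like, using the MX fields this PR adds to the LLM args. The model path and endpoint are placeholders, and constructing a real `tensorrt_llm.LLM` needs a model and GPU, so only the keyword arguments are built here:

```python
# Hedged sketch: field names follow this PR's llm_args additions; values
# are placeholders for illustration.
mx_kwargs = dict(
    checkpoint_format="MX",
    mx_server_url="localhost:9000",     # omit to fall back to $MODEL_EXPRESS_URL
    mx_server_query_timeout_s=30,       # bound source-discovery wait time
    mx_preshard_strategy="per_module",  # "global" fails fast for now
)

# from tensorrt_llm import LLM
# llm = LLM(model="/path/to/model", **mx_kwargs)
print(sorted(mx_kwargs))
```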
Test Coverage
Added MX-only unit tests:
- `tests/unittest/llmapi/test_mx_args.py`
- `tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py`
- `tests/unittest/_torch/pyexecutor/test_model_loader_mx.py`

Coverage includes:

- `MODEL_EXPRESS_URL` fallback
- `mx_server_query_timeout_s`
- `mx_preshard_strategy` validation
- `MODEL_NAME` resolution

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.