[TRTLLM-11851][feat] MX and GMS integration MVP for Dynamo weight sharing #13045
chienchunhung wants to merge 3 commits into NVIDIA:main from
Conversation
…th implementation

The protocol listing in Section 5 (challenge 8) had stale method signatures (commit/upgrade_lock) that diverged from the canonical implementation in Section 4 and the prototype PR NVIDIA#13045. Updated to match: finalize_write takes tag parameter, cleanup replaces commit+upgrade_lock.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
…13045

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
Update Section 15 (Prototype Validation Plan) with the actual progress of integrating PR NVIDIA#13045 with the §10/§11 benchmark infrastructure:

- Replace the speculative Branch Strategy with concrete branch info: dynamo/proto-rebased (prototype on current upstream/main, 0 rebase conflicts) and dynamo/proto-bench-integration-v2 (= proto-rebased + 7 bench commits cherry-picked). Both pushed to fork; the two original branches (dynamo-integration-prototype and dynamo/startup-profiling) are untouched.
- Document all 5 conflict resolutions (py_executor_creator.py + model_loader.py) on the foundational bench commit. Strategy: keep prototype semantics, add §10 timers without changing control flow.
- Add an Execution Status section capturing:
  * Phase A done (branch integration + smoke verification)
  * Smoke confirms M1 (AUTO/HF) baseline path is unaffected and the bench profiler captures the full hierarchy on the integrated branch (Qwen 7B TP=1 warm NFS: 67.6s server / 36.3s worker).
  * Phase B blocker: the prototype's GMSBackend uses a high-level API (gms_client.connect, get_mem_pool, materialize_module_from_gms as top-level functions) that does NOT exist in the actually-merged ai-dynamo/dynamo PR NVIDIA#7575. Merged GMS exposes a class-based API (GMSClientMemoryManager + gms_use_mem_pool context manager) plus an official setup_gms() monkey-patch integration at gpu_memory_service.integrations.trtllm. The prototype's "Align GMS backend with merged GMS library API (PR NVIDIA#7575)" commit (2026-04-16) was written against an earlier, un-merged iteration of the API. Three resolution paths laid out for decision.
  * MX install not feasible on current node (no Rust, no Docker); ai-dynamo/modelexpress would need separate node prep.
- Recommend resolving the GMS API mismatch first (cheaper, exercises ~half the prototype's new code) before MX node prep.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
… (DONE)

Update Section 15 to reflect the resolution of the GMS API mismatch blocker that was open as of the previous Section 15 commit. PR NVIDIA#13045 has been:

1. Rebased onto current upstream/main (7b84136 — 95 commits ahead of the original prototype base 4a848cc). Zero rebase conflicts.
2. Refactored to align both adapters with what was actually merged upstream:
   - GMS (gpu_memory_service v0.9.0): use the class-based API (GMSClientMemoryManager + gms_use_mem_pool context manager + materialize_module_from_gms + finalize_gms_write) instead of the non-existent gms_client.connect() / get_mem_pool() / commit() / disconnect() convenience functions. Default tag changes from 'model_weights' to 'weights' to match the GMS convention.
   - MX (modelexpress v0.3.0): delegate the actual NIXL transfer to MxLiveWeightLoader and publishing to publish_model_params instead of calling non-existent mx_client.connect() / list_sources() / receive() / register_source() helpers. Old SourceIdentity dict schema replaced by structured fields.
3. Added an mx_preshard_strategy config knob (per_module | global) defaulting to per_module. 'global' raises NotImplementedError until LoadFormat.PRESHARDED lands upstream.
4. Declared modelexpress and gpu-memory-service as setuptools optional extras with pinned upper bounds.

The new live working branch is dynamo-integration-prototype-rebased (tip: 62ac40f). It will be force-pushed to dynamo-integration-prototype once the alignment work is reviewed. Pre-alignment branches (dynamo/proto-rebased, dynamo/proto-bench-integration-v2) remain as historical snapshots and are slated for cleanup after Phase B validation completes.
Doc changes:

- Status header updated from "Phase B blocked on GMS API alignment" to "Phase B ready once MX server + GMS daemon are reachable on a test node"
- Branch Strategy table adds the new dynamo-integration-prototype-rebased row and clarifies that the pre-alignment branches are historical
- ASCII tree updated to show the new branch on top of current upstream/main
- "Phase B blocked" subsection replaced with "API Alignment (DONE)" documenting what was wrong, what the alignment commit does, and the three-layer separation principle (TRT-LLM owns Layer 3; we call upstream Layer 2 primitives; we never duplicate Layer 1 wire protocol mechanics)
- Recommended Next Step now reflects that the API blocker is resolved and the next gates are environment-side (force-push the PR, GMS-only end-to-end test, MX setup on a separate node)

No changes to the Test Matrix, Verification Tests, Critical Diagnostic, Service Setup Reference, or Benchmark Script Changes sections.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
Force-pushed 84dfb2a to 62ac40f
Update Sections 3, 4, 5, 6 to reflect the actually-merged GMS and MX Python APIs that PR NVIDIA#13045 (commit 62ac40f) now uses. Section 15 already documents the alignment work itself; these edits bring the rest of the design doc into sync so cross-references match the code.

Section 3 (Architecture):
- New replica startup diagram: torch.cuda.use_mem_pool(gms_pool) -> gms_backend.mem_pool_scope(device); add move_untracked_params() to the RW commit chain.
- Component roles diagram: GMS RW node now mentions mem_pool_scope(device) and move_untracked_params(); MX publish node shows it delegates to publish_model_params().
- Root-cause callout for the MX-in-GMS-RW limitation now points at MX-5 in Section 15 as the upstream alignment item.

Section 4 (Implementation & API Design):
- MXCheckpointLoader code block rewritten against the current MX API: delegates load to MxLiveWeightLoader, delegates publish to publish_model_params (with the MODEL_EXPRESS_URL env-var dance for per-instance URLs), drops the non-existent mx_client.connect / list_sources / receive / register_source convenience calls and the proto.SourceIdentity(extra_params=dict) schema. Mixed-success conservative behavior documented.
- Identity schema section explains the new structured p2p_pb2.SourceIdentity fields and forward-references MX-2 / MX-3.
- GMS code block in ModelLoader.load() updated: gms_pool / use_mem_pool(gms_pool) -> mem_pool_scope(device); add move_untracked_params() before finalize_write(); add empty_cache() drain inside the scope (mirrors upstream RW reference); default tag changes to "weights" via GMSBackend.DEFAULT_TAG.
- GPUMemoryBackend protocol listing rewritten: get_mem_pool() -> MemPool replaced by mem_pool_scope(device) -> ContextManager; finalize_write() now returns int (bytes committed) and takes no tag arg; release() removed (cleanup() handles teardown); has_committed_weights() takes no arg.
- New "Layer-2 mapping" table shows each protocol method's upstream call so reviewers can see where API drift hits.
- "What TRT-LLM Implements vs GMS Library" table refreshed with new method names + the optional-extras install hint.
- Configuration section: added mx_preshard_strategy field with per_module/global semantics; gms_tag default changes to "weights" with comment pointing to GMS_TAGS upstream; gms_socket_path default comment points at upstream get_socket_path().
- Validators section updated to mention mx_preshard_strategy.
- Explicitly notes we do NOT use upstream setup_gms() monkey-patch and links to GMS-1 in Section 15.

Section 5 (Challenges):
- Challenge 3 (TP rank matching): code example rewritten to use the current MxClient.list_sources + get_metadata flow; adds a callout for MX-3 (per-rank addressing in identity/metadata schema).
- Challenge 6 (CUDA VMM Integration): use_mem_pool(gms_pool) -> mem_pool_scope(device); update method names to match the new GPUMemoryBackend protocol; clarify cleanup() responsibilities.
- Challenge 8 (GMS API Stability): full rewrite. Reflects that GMS has stabilized as of PR NVIDIA#7575 and our adapter calls the merged Layer 2 primitives. New protocol listing matches code. Layer-2 mapping table added. Explicitly explains why we do NOT use the upstream setup_gms() monkey-patch entry point. References GMS-2 (move_untracked_params upstream promotion).
- Challenge 9 (Transfer backend): client.receive() and client.register_source() examples replaced with current MxLiveWeightLoader.load_weights() and publish_model_params().

Section 6 (Executor & Failover):
- Section intro: "gms_client.upgrade_lock()" replaced with the actual primitive sequence (unmap_all_vas + abort + reconnect-RW + remap_all_vas).
- _shadow_loop pseudocode: drop the gms_client.heartbeat() call (the unix socket is the lock; OS keepalives suffice). Note that GMS-3 (peek RPC) would give a cheaper shadow-health signal here.
- _activate_from_shadow pseudocode: gms_client.upgrade_lock() -> gms_backend.upgrade_to_rw() with a comment that GMSBackend should grow this method as a one-call wrapper around the primitive sequence.
- _init_shadow_with_gms pseudocode: gms_client= -> gms_backend=, documents that this delegates to upstream materialize_module_from_gms(mgr, model, device_index=N).
- Sleep/wake mapping table: tag="model_weights" -> tag="weights", and each "GMS Library Call" column entry updated to show the actual current API call. Adds a note explaining the "weights" / "kv_cache" naming convention from GMS_TAGS.

Sections 1, 2, 7-11 had no stale GMS/MX API references and are unchanged.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
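The one-call `upgrade_to_rw()` wrapper suggested above can be sketched with stubs. Everything here is invented for illustration (a recording client standing in for the real GMS client); the four primitive names follow the sequence named in the commit message, not a verified upstream signature:

```python
class RecordingGmsClient:
    """Stub standing in for the real GMS client: records the call sequence."""
    def __init__(self):
        self.ops = []
    def unmap_all_vas(self): self.ops.append("unmap_all_vas")
    def abort(self): self.ops.append("abort")
    def connect_rw(self): self.ops.append("connect_rw")
    def remap_all_vas(self): self.ops.append("remap_all_vas")

class GmsBackendStub:
    """Sketch of the GMSBackend wrapper the commit proposes."""
    def __init__(self, client):
        self._client = client
    def upgrade_to_rw(self) -> None:
        # Shadow -> active promotion: drop RO mappings, abort the RO
        # session, reconnect in RW mode, then remap the same VAs.
        c = self._client
        c.unmap_all_vas()
        c.abort()
        c.connect_rw()
        c.remap_all_vas()
```

The point of the wrapper is that failover code calls one method instead of re-deriving the primitive ordering at each call site.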
…ounds

Add a dedicated "Upstream Alignment Requests" section to §15 with the full MX-1..MX-6 / GMS-1..GMS-5 request inventory. The summary table has the priority + workaround status + merge-blocking flag for each item. Cross-references in §3, §4, §5, §6 (~8 of them) have been updated to point to the new section's anchor.

Updates two of the existing entries to reflect new workarounds added to PR NVIDIA#13045 in this same series of follow-up commits:

MX-2 (promote _build_trtllm_identity to public API): expand the workaround note to mention that `publish_as_source()` now sets BOTH MODEL_EXPRESS_URL and MODEL_NAME env vars temporarily (we used to only set MODEL_EXPRESS_URL). The two env-var dances would collapse into one direct call if MX-2 lands. The MODEL_NAME resolution itself is now plumbed cleanly through `llm_args.model → MXCheckpointLoader(model_name=...)` with HF snapshot path unmangling for the basename-fallback case.

MX-4 (non-blocking source-query API): add the `MX_SOURCE_QUERY_TIMEOUT=30` defensive setdefault as the current workaround. `MXCheckpointLoader.__init__` calls `os.environ.setdefault(...)` whenever an MX URL is configured so first-replica cold-cluster startup degrades from 1 hour (upstream default) to 30s. setdefault semantics preserve any explicit user value.

The new section also captures items that were previously only in chat context (MX-1, MX-3, MX-5, MX-6, GMS-1..5) so the cross-references from §3-§6 actually resolve to a stable in-doc location. No changes to the test matrix, verification tests, conflict-resolution table, or the rest of §15.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
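The two workarounds above can be illustrated in isolation. This is a sketch, not the PR's code: `temp_env` is a name invented here to show the set-then-restore-in-finally pattern, and the `setdefault` line shows why an explicit user value survives:

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(**overrides):
    """Set env vars for the duration of a call, restoring the previous
    value (or absence) of each var in a finally block."""
    saved = {k: os.environ.get(k) for k in overrides}
    try:
        os.environ.update({k: str(v) for k, v in overrides.items()})
        yield
    finally:
        for k, prev in saved.items():
            if prev is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = prev

# Defensive default: only caps the source-query wait when the user has
# not already set an explicit value (setdefault never overwrites).
os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")
```

Usage mirrors the publish path: wrap the publish call in `temp_env(MODEL_EXPRESS_URL=..., MODEL_NAME=...)` so both vars are visible only while the upstream call runs.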
… timeout, model_name plumbing

Three discrete improvements to the MX side of PR NVIDIA#13045, driven by review feedback from the MX team's downstream PR (chienchunhung/TensorRT-LLM #1) — three orchestration ergonomics fixes landed as one focused commit so reviewers see them as a clean slice on top of the prototype.

(1) MODEL_EXPRESS_URL env-var fallback — at validator level

TorchLlmArgs.validate_mx_config now honors the upstream ``MODEL_EXPRESS_URL`` env var when ``checkpoint_format='MX'`` and ``mx_server_url`` is unset. Resolution happens at validator time so the value ends up on ``llm_args.mx_server_url`` (visible to logging, /startup_metrics, downstream code) instead of being silently re-read from env by the loader. Lets orchestrators (Dynamo) configure MX via the environment without plumbing every CLI knob, while keeping resolution in one place. Explicit ``mx_server_url=`` always wins. The env-var fallback only fires when MX is the active checkpoint format (so HF-only configs aren't surprised by an unrelated env var). An empty string in env is treated as unset.

(2) MX_SOURCE_QUERY_TIMEOUT defensive default

MXCheckpointLoader.__init__ calls ``os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")`` whenever an MX server URL is configured. Caps cold-cluster first-replica startup at 30 s instead of upstream's 1-hour default (the polling in MxLiveWeightLoader._query_source). setdefault semantics preserve any explicit user value. HF-only loads (no MX URL) don't touch the env at all. The proper upstream-side fix is a non-blocking source-query API (tracked as MX-4 in §15 of the design doc); this defensive default caps the worst case until that lands.

(3) model_name plumbing with HF-snapshot-aware resolver

Plumbs ``llm_args.model → MXCheckpointLoader(model_name=...)`` so upstream's ``publish_model_params()`` publishes under the user-supplied Hub ID (e.g. "Qwen/Qwen2.5-72B-Instruct") instead of the "unknown" sentinel.

- MXCheckpointLoader takes a new optional ``model_name`` constructor arg (Union[str, Path]). Coerced to str at construction time.
- publish_as_source() now sets BOTH MODEL_EXPRESS_URL and MODEL_NAME env vars (resolving identity via the priority order below) and restores both env vars in finally. publish_model_params() reads them via env, as documented.
- Identity resolution order: explicit constructor arg → MODEL_NAME env → checkpoint_dir basename (with HF-snapshot path unmangling) → "unknown".
- HF cache layout (".../models--<org>--<name>/snapshots/<sha>/") is unmangled back to "<org>/<name>" instead of returning the commit hash.
- _construct_checkpoint_loader plumbs ``mx_model_name`` through; py_executor_creator.py extracts it from llm_args.model.

Both env-var dances (MODEL_EXPRESS_URL + MODEL_NAME) collapse into one direct call when MX-2 (public build_identity) lands upstream. Tests for these three additions are in the next commit.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
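The resolution order and HF-snapshot unmangling described above can be sketched roughly like this. `resolve_model_name` is an illustrative name (the real resolver lives inside MXCheckpointLoader); the `models--<org>--<name>` decoding follows the standard huggingface_hub cache convention of encoding `/` as `--`:

```python
import os
from pathlib import Path

def resolve_model_name(explicit=None, checkpoint_dir=None) -> str:
    """Resolve the identity to publish under:
    explicit arg -> MODEL_NAME env -> checkpoint_dir basename
    (with HF-snapshot unmangling) -> "unknown"."""
    if explicit:
        return str(explicit)
    env_name = os.environ.get("MODEL_NAME", "").strip()
    if env_name:
        return env_name
    if checkpoint_dir:
        parts = Path(checkpoint_dir).parts
        for i, part in enumerate(parts):
            # HF cache layout: .../models--<org>--<name>/snapshots/<sha>/
            # so a plain basename would return the commit hash, not the ID.
            if (part.startswith("models--")
                    and i + 1 < len(parts) and parts[i + 1] == "snapshots"):
                return part[len("models--"):].replace("--", "/")
        return Path(checkpoint_dir).name
    return "unknown"
```

Single hyphens inside repo names (e.g. `Qwen2.5-7B-Instruct`) are untouched; only the `--` separator encodes the org/name split.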
Force-pushed b78dafc to edb4f0f
/bot run --disable-fail-fast
Force-pushed edb4f0f to 0a2a950
/bot run --disable-fail-fast
PR_Github #44810 [ run ] triggered by Bot.
```python
with gms_use_mem_pool(self._tag, target_device):
    yield
```

```python
def move_untracked_params(self, model: nn.Module) -> None:
```
are we sure this won't cause an OOM? typically when you migrate something into GMS you will have to request a new allocation for the same piece of memory, then copy it in. simply remapping will not store it in GMS.
Good catch. To clarify: GMSBackend.move_untracked_params is not a remap — it walks the model's parameters, calls gms_client.create_mapping(size=nbytes, tag=self._tag) to allocate fresh GMS-pool memory, then replacement.copy_(tensor) and reassigns tensor.data = replacement. Mirrors upstream gpu_memory_service.integrations.trtllm.model_loader._move_untracked_params 1:1. So the bytes do end up in GMS.
The transient peak you're worried about is real though: for the duration of each per-tensor replacement.copy_(tensor), both the old (out-of-pool) buffer and the new (in-pool) buffer exist before the old one is GC'd. In practice this only fires for stray params that landed outside mem_pool_scope (post-load transforms, etc.), which is a small fraction of total weights — but for very large models on tight budgets it could spike.
Also, in this round I moved move_untracked_params inside mem_pool_scope (commit c6c6e57) to match the upstream _load_rw reference path exactly. That doesn't change the per-tensor peak, but it does ensure any intermediate buffers tensor.copy_() may allocate via PyTorch's caching allocator end up in the GMS pool and get drained by the subsequent torch.cuda.empty_cache().
Worth tracking as an upstream ask: a streaming/in-place migrator that rebinds an existing physical alloc to a GMS-mapped VA without a copy. What do you think?
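As a toy model of that allocate-copy-rebind flow: `FakeTensor` and `FakeGmsClient` are stubs invented here (the real code operates on CUDA tensors via `gms_client.create_mapping`), but the shape of the loop and the transient old+new coexistence it implies are the same:

```python
class FakeTensor:
    """Stand-in for a CUDA tensor: tracks a buffer and which pool owns it."""
    def __init__(self, data: bytes, pool: str):
        self.data, self.pool = data, pool
    @property
    def nbytes(self):
        return len(self.data)

class FakeGmsClient:
    """Stand-in for the GMS client: create_mapping allocates in-pool memory."""
    def __init__(self):
        self.allocated = 0
    def create_mapping(self, size: int, tag: str) -> FakeTensor:
        self.allocated += size
        return FakeTensor(b"\x00" * size, pool="gms")

def move_untracked_params(params: dict, client: FakeGmsClient, tag: str) -> int:
    """Allocate a fresh in-pool buffer per out-of-pool param, copy, rebind.
    Not a remap: the bytes genuinely land in the GMS pool, at the cost of a
    transient per-tensor peak while old and new buffers coexist."""
    moved = 0
    for name, t in params.items():
        if t.pool == "gms":
            continue  # already tracked by the pool
        replacement = client.create_mapping(size=t.nbytes, tag=tag)
        replacement.data = t.data   # stands in for replacement.copy_(tensor)
        params[name] = replacement  # stands in for tensor.data = replacement
        moved += t.nbytes
    return moved
```

Only the stray out-of-pool params pay the copy, which is why the transient peak is usually small relative to total weights.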
```python
from gpu_memory_service.integrations.common.utils import finalize_gms_write
```

```python
bytes_committed = int(finalize_gms_write(self._client, model))
```
finalize_gms_write is not quite the right abstraction if we are going to want to support speculative decoding models in this path as well, if I remember correctly. We would need to defer commit until draft model weights are loaded into GMS.
Right — the original RW branch only loaded the main model and committed immediately, which would have left model.draft_model either uncommitted or out-of-pool. Just pushed c6c6e57 which loads draft weights inside the same mem_pool_scope (mirroring the AUTO branch's if self.spec_config is not None and ...need_load_draft_weights(): flow) so they land in the GMS pool. finalize_gms_write(model) then picks them up via its existing model-tree walk (finalize_gms_write() walks named_parameters, so anything reachable from model — including model.draft_model — gets registered in the same commit).
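The ordering constraint in that fix can be sketched with stubs. `StubGmsBackend` and `load_with_gms_rw` are invented here to show the invariant (draft weights enter the pool before the single commit); the exact placement of the commit relative to scope exit in the real code may differ:

```python
from contextlib import contextmanager

class StubGmsBackend:
    """Stub standing in for GMSBackend: records call order only."""
    def __init__(self):
        self.calls = []
    @contextmanager
    def mem_pool_scope(self, device):
        self.calls.append("enter_scope")
        yield
        self.calls.append("exit_scope")
    def finalize_write(self, model) -> int:
        # The real finalize walks parameters reachable from `model`, so
        # anything attached (including a draft model) joins the same commit.
        self.calls.append("finalize_write")
        return 0

def load_with_gms_rw(backend, model, spec_config, load_main, load_draft):
    with backend.mem_pool_scope(device=0):
        load_main(model)
        # Draft weights must land inside the same scope, before the single
        # commit; otherwise they stay uncommitted or out-of-pool.
        if spec_config is not None and spec_config.need_load_draft_weights:
            load_draft(model)
    return backend.finalize_write(model)
```

With no spec_config the draft step is skipped and the flow degenerates to the plain main-model load.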
📝 Walkthrough

Adds GMS GPU-memory backend and MX (ModelExpress) P2P checkpoint loader, integrates MX/GMS options into model-loading paths, adds Linear presharding flag, and extends CLI/config args and tests to support MX/GMS workflows.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant Client
    participant MXLoader as MXCheckpointLoader
    participant Modelexpress
    participant HFDisk as HF Checkpoint Disk
    Client->>MXLoader: load_weights(checkpoint_dir, mapping, model=...)
    MXLoader->>MXLoader: validate MX url & model arg
    alt MX available
        MXLoader->>Modelexpress: request P2P transfer
        Modelexpress->>MXLoader: response (fallback_weights?)
        alt no fallback
            MXLoader->>Client: return {} (p2p_succeeded=True)
        else partial/fallback
            MXLoader->>HFDisk: load from disk
            HFDisk-->>MXLoader: weights dict
            MXLoader->>Client: return disk-loaded weights (p2p_succeeded=False)
        end
    else MX unavailable/exception
        MXLoader->>HFDisk: load from disk
        HFDisk-->>MXLoader: weights dict
        MXLoader->>Client: return disk-loaded weights (p2p_succeeded=False)
    end
```

```mermaid
sequenceDiagram
    participant Client
    participant Loader as ModelLoader
    participant GMS as GMSBackend
    participant Daemon as GMS Daemon
    participant Model
    Note over Client,Model: GMS RW path (loading)
    Client->>Loader: load(format=GMS, mode=RW)
    Loader->>GMS: connect()
    GMS->>Daemon: request RW lock & mappings
    Daemon-->>GMS: connected + mappings
    Loader->>GMS: enter mem_pool_scope()
    Loader->>Model: load weights (allocations routed to GMS pool)
    Loader->>GMS: move_untracked_params(model)
    GMS-->>Model: migrate params into GMS mappings
    Loader->>GMS: finalize_write(model)
    GMS->>Daemon: commit weights -> transition to RO
    Note over Client,Model: GMS RO path (inference)
    Client->>Loader: load(format=GMS, mode=RO)
    Loader->>GMS: connect()
    GMS->>Daemon: request RO lock
    Daemon-->>GMS: connected + mappings
    Loader->>GMS: materialize_module(model)
    GMS-->>Model: zero-copy materialize weights, mark presharded
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
setup.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor — Update the copyright year for this modified file.
This file was meaningfully modified in 2026, but the header still ends at 2025.
Suggested fix
```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "Add NVIDIA copyright header on ALL new files and update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@setup.py` at line 1, Update the SPDX header year range in the setup.py file from "2022-2025" to include 2026 (e.g., "2022-2026") so the copyright line reflects the file was modified in 2026; locate and edit the existing header string at the top of setup.py accordingly.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/memory/gpu_memory_backend.py (1)
395-415: Log the exception in `client.close()` instead of silently passing. Static analysis flags the try-except-pass pattern (S110). Even for best-effort cleanup, logging at debug level helps diagnose shutdown issues without cluttering normal output.
🔧 Proposed fix to log the close() exception
```diff
 try:
     client.close()
-except Exception:
+except Exception as e:
     # Best-effort: even if close() fails, evict so a future
     # connect() in the same process can re-establish state.
-    pass
+    logger.debug("GMS client.close() failed (best-effort): %s", e)
 evict_gms_client_memory_manager(client)
```
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 395 - 415, In cleanup(), when calling client.close() inside the inner try/except, do not silently pass on exceptions; catch the Exception and log it (at debug level) so shutdown issues are visible. Locate the cleanup method and the inner try around client.close() (variable client and function evict_gms_client_memory_manager) and replace the empty except with a logger.debug call that includes the exception details (e.g., logger.debug("GMS client.close() error: %s", e, exc_info=True)) before continuing to evict and clear self._client.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py`:
- Around line 7-8: Add the required SPDX/NVIDIA copyright/header block at the
top of tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py so the file
includes the repository-standard NVIDIA copyright header with the correct year
of latest meaningful modification and SPDX identifier; place it above the
existing decorators (e.g., above `@register_config_loader`("MX") /
`@register_config_loader`("HF")) and ensure the header matches other source files
in format and content.
In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py`:
- Around line 142-231: The logger.warning in load_weights currently passes
len(fallback_weights) twice to a message that implies two different values;
update the log text in the fallback handling inside load_weights (in
tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader) to avoid the
misleading "delivered %d out of N" phrasing—either compute and pass an actual
delivered count if available or reword the message to only report the fallback
count (len(fallback_weights)) and the MX server URL or total tensors if you can
obtain them; ensure the call no longer supplies duplicate format args and
reference MxLiveWeightLoader/fallback_weights/_fallback_to_disk while making
this change.
- Around line 247-318: The publish_as_source method uses implicit Optional
defaults; update its signature to use explicit union types: change mapping:
Mapping = None to mapping: Mapping | None and checkpoint_dir: str = None to
checkpoint_dir: str | None (keep model's type as-is or add/qualify with
torch.nn.Module if you want an explicit model type); ensure Python 3.10+ union
syntax is supported in this project or use typing.Optional if not, and add any
missing imports (e.g., from torch import nn if you change model to nn.Module) or
use a forward/string annotation to avoid import cycles.
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 187-195: Compute the effective TP coordinates once (e.g.,
effective_tp_size = 1 if getattr(module, '_weights_presharded', False) else
module.tp_size and effective_tp_rank = 0 if getattr(module,
'_weights_presharded', False) else module.tp_rank) and reuse those variables for
every load_weight_shard(...) call in this module (including the calls that load
quantization metadata such as weight_scale, input_scale, pre_quant_scale and any
other quantized tensors), replacing direct uses of module.tp_size/module.tp_rank
so that both weights and their quantization metadata are sliced consistently
based on module._weights_presharded and module.tp_mode.
- Line 2533: The file is missing the repo-standard SPDX/NVIDIA copyright header;
add the required NVIDIA copyright/SPDX header block at the very top of the
source file (using the year of the latest meaningful modification) so all source
files include the standard notice; update the header in the module that defines
the Linear class (the file containing the assignment self._weights_presharded =
False) to match the repository's canonical NVIDIA header format.
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 502-575: GMS branches skip loading draft-model weights so
model.draft_model remains uninitialized; after the GMS RW/RO handling (but
before storing self._gms_backend) replicate the existing speculative/draft load
flow: check self.spec_config.spec_dec_mode.need_load_draft_weights() and if true
load weights from self.spec_config.speculative_model using the same
checkpoint_loader pattern (use load_weights with mapping=self.mapping, call
checkpoint_loader.get_initialized_weight_mapper to set up a mapper, then invoke
self._call_load_weights to apply weights), assign the resulting module to
model.draft_model (or set via the same initializer used in other branches) and
ensure any GMS-specific memory scope is respected (for RW use
gms_backend.mem_pool_scope as needed) so draft weights are loaded outside/inside
GMS consistently.
- Around line 526-529: The GMS branch is using the hard-coded checkpoint_dir
variable which ignores a model's redirected checkpoint source; inside the
gms_backend.mem_pool_scope block (around load_weights_kwargs and
checkpoint_loader.load_weights) change the call to pass the model's redirected
path when present by resolving checkpoint_source = getattr(self, "model", None)
and using getattr(self.model, "llm_checkpoint_dir", checkpoint_dir) (or
equivalent) so that checkpoint_loader.load_weights(...) uses
model.llm_checkpoint_dir when available, falling back to checkpoint_dir
otherwise.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3437-3438: This file (tensorrt_llm/llmapi/llm_args.py) was
modified but is missing the required NVIDIA copyright header; add the
repository-mandated NVIDIA file header (including the project name and the year
of latest meaningful modification) to the top of the file so all Python sources
comply, ensuring it appears before any code or comments (i.e., above the section
that defines GMS = 3 and the "Load weights..." comment) and update the year if
this is a modification rather than a new file.
- Around line 3716-3725: Add an explicit Pydantic field validator for gms_tag to
reject empty or whitespace-only strings so configs like gms_tag="" fail fast
instead of collapsing to the default; implement a `@field_validator`("gms_tag")
`@classmethod` named e.g. validate_gms_tag that checks v.strip() and raises
ValueError("gms_tag must be non-empty") when empty, otherwise returns v; place
it in the same class that defines the gms_tag Field so the validation runs
during model parsing.
In `@tests/unittest/_torch/memory/test_gms_backend.py`:
- Around line 39-41: Remove the class-level pytest.mark.skipif and instead make
tests run on CPU by patching torch.cuda.current_device or by instantiating
GMSBackend without calling __init__; specifically, in tests that exercise
GMSBackend (e.g., where GMSBackend.__init__, connect(), and RW/RO gating are
involved) replace the skip with a setup that monkeypatches
torch.cuda.current_device to return a valid device (e.g.,
monkeypatch.setattr(torch.cuda, "current_device", lambda: 0)) so __init__ can
run under mocks, or create the object via GMSBackend.__new__(GMSBackend) and
manually set any internal attributes needed for the mocked connect()/protocol
code paths to execute; apply this change to the blocks referenced (around lines
with the prior skip and the other ranges called out) so the mocked happy/failure
connect() and gating paths execute on CPU CI.
In `@tests/unittest/llmapi/test_mx_gms_args.py`:
- Around line 279-289: Add a regression test to ensure integer load_format
values map to the new enum member by calling TorchLlmArgs.convert_load_format
with the raw integer 3 and asserting it yields LoadFormat.GMS; update the
TestLoadFormatEnum suite (or add a new test method) to include an assertion that
TorchLlmArgs.convert_load_format(3) == LoadFormat.GMS while keeping the existing
checks for AUTO/DUMMY/VISION_ONLY unchanged.
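The model_loader.py draft-weight prompt above is mostly control flow, which is easier to check as code. Below is a stdlib-only sketch of that gating logic: every class is a toy stand-in for the real TRT-LLM objects (spec_config, checkpoint loader, model) named in the review, so treat the shapes as illustrative, not as the actual API.

```python
# Toy stand-ins for the review's spec_config / checkpoint_loader / model.
class SpecDecMode:
    def need_load_draft_weights(self) -> bool:
        return True


class SpecConfig:
    spec_dec_mode = SpecDecMode()
    speculative_model = "/ckpts/draft-model"


class Loader:
    def load_weights(self, path, mapping=None):
        # The real loader would read tensors from disk or GMS.
        return {"lm_head.weight": f"tensor-from-{path}"}


def finish_gms_load(model, spec_config, checkpoint_loader, mapping):
    # ... GMS RW/RO handling for the target model happens before this ...
    if spec_config.spec_dec_mode.need_load_draft_weights():
        draft_weights = checkpoint_loader.load_weights(
            spec_config.speculative_model, mapping=mapping)
        # Stand-in for assigning model.draft_model in the real code.
        model["draft_model"] = draft_weights
    # ... then store self._gms_backend ...
    return model


model = finish_gms_load({}, SpecConfig(), Loader(), mapping=None)
assert model["draft_model"] == {"lm_head.weight": "tensor-from-/ckpts/draft-model"}
```

The point is the ordering: draft weights are loaded after the GMS RW/RO branch but before the backend handle is stored, mirroring what the AUTO path already does.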
---
Outside diff comments:
In `@setup.py`:
- Line 1: Update the SPDX header year range in the setup.py file from
"2022-2025" to include 2026 (e.g., "2022-2026") so the copyright line reflects
the file was modified in 2026; locate and edit the existing header string at the
top of setup.py accordingly.
---
Nitpick comments:
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py`:
- Around line 395-415: In cleanup(), when calling client.close() inside the
inner try/except, do not silently pass on exceptions; catch the Exception and
log it (at debug level) so shutdown issues are visible. Locate the cleanup
method and the inner try around client.close() (variable client and function
evict_gms_client_memory_manager) and replace the empty except with a
logger.debug call that includes the exception details (e.g., logger.debug("GMS
client.close() error: %s", e, exc_info=True)) before continuing to evict and
clear self._client.
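The nitpick above amounts to replacing a silent `except: pass` with a debug-level log. A minimal runnable sketch, with `FlakyClient` as a toy stand-in for the real GMS client:

```python
import logging

logger = logging.getLogger("gms")


class FlakyClient:
    def close(self):
        # Simulates a client whose socket is already gone at shutdown.
        raise OSError("socket already closed")


def cleanup(client):
    try:
        client.close()
    except Exception as e:
        # Visible during shutdown debugging instead of a bare `pass`.
        logger.debug("GMS client.close() error: %s", e, exc_info=True)
    # ... then evict the memory manager and clear self._client as before.


cleanup(FlakyClient())  # no exception escapes; the error is logged at DEBUG
```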
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 88b1a034-9c30-4c0d-a6cd-250f1d7100ae
📒 Files selected for processing (19)
- setup.py
- tensorrt_llm/_torch/memory/__init__.py
- tensorrt_llm/_torch/memory/gpu_memory_backend.py
- tensorrt_llm/_torch/models/checkpoints/__init__.py
- tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
- tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
- tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/executor/base_worker.py
- tensorrt_llm/llmapi/llm_args.py
- tests/unittest/_torch/memory/test_gms_backend.py
- tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
- tests/unittest/api_stability/references/llm.yaml
- tests/unittest/llmapi/test_mx_gms_args.py
```python
@register_config_loader("MX")
@register_config_loader("HF")
```
Add the required SPDX/NVIDIA header to this modified source file.
The registry change is fine, but this TensorRT-LLM source file still lacks the repo-standard copyright header block.
As per coding guidelines "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py` around lines 7 -
8, Add the required SPDX/NVIDIA copyright/header block at the top of
tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py so the file includes
the repository-standard NVIDIA copyright header with the correct year of latest
meaningful modification and SPDX identifier; place it above the existing
decorators (e.g., above `@register_config_loader`("MX") /
`@register_config_loader`("HF")) and ensure the header matches other source files
in format and content.
```python
def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
    """Load weights, preferring MX P2P transfer when available.

    Delegates the actual transfer to the upstream
    ``modelexpress.trtllm_live_transfer.MxLiveWeightLoader``,
    which handles NIXL setup, source discovery, name matching,
    dtype casting, and PVC fallback for size-mismatched tensors.

    Args:
        checkpoint_dir: Path to the HF checkpoint directory.
        mapping: Distributed mapping configuration.
        **kwargs: Additional keyword arguments. When ``model`` is
            passed it is used as the target for direct P2P writes.

    Returns:
        A weights dict. Empty when MX P2P fully succeeded (weights
        already in model params); populated when falling back to
        disk loading for some or all weights.
    """
    model = kwargs.pop("model", None)
    self._p2p_succeeded = False

    if self._mx_server_url is None or model is None:
        return self._fallback_to_disk(
            checkpoint_dir,
            mapping,
            reason=(
                "no MX server URL configured"
                if self._mx_server_url is None
                else "no model reference passed (cannot do P2P writes)"
            ),
            **kwargs,
        )

    try:
        from modelexpress.trtllm_live_transfer import (  # type: ignore[import-not-found]
            MxLiveWeightLoader,
        )
    except ImportError:
        logger.warning(
            "modelexpress library not installed; cannot use MX P2P "
            "weight transfer. Install from "
            "https://github.com/ai-dynamo/modelexpress (Python client at "
            "modelexpress_client/python). Falling back to disk loading."
        )
        return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

    try:
        mx_loader = MxLiveWeightLoader(mx_server=self._mx_server_url)
        fallback_weights = mx_loader.load_weights(
            checkpoint_dir,
            mapping=mapping,
            model=model,
        )
    except Exception as e:
        logger.warning(
            "MX P2P transfer failed (%s). Falling back to disk loading.",
            e,
        )
        return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

    if fallback_weights:
        # Mixed-success case: MX delivered most weights into model
        # params via P2P, but returned a dict of size-mismatched
        # tensors that still need to be loaded via the standard
        # pipeline. Marking Linear modules as fully presharded would
        # then incorrectly skip TP slicing for those fallback
        # weights too. The conservative correct behavior is to fall
        # through to the standard disk path entirely, which loads
        # *all* weights normally and ignores the MX-written ones.
        #
        # (When LoadFormat.PRESHARDED lands upstream, the proper
        # fix is per-tensor presharded marking based on whether
        # MX delivered that tensor.)
        logger.warning(
            "MX P2P delivered %d out of N tensors but returned %d "
            "fallback weights (size mismatch). Falling back to full "
            "disk load to avoid mixing presharded and non-presharded "
            "weights.",
            len(fallback_weights),  # this is a count of fallback only
            len(fallback_weights),
        )
        return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

    self._p2p_succeeded = True
    logger.info(
        "MX P2P weight transfer succeeded from %s",
        self._mx_server_url,
    )
    return {}
```
Misleading log message format.
The log message on lines 216-222 uses len(fallback_weights) for both format arguments, but the message text implies two different values: "delivered %d out of N" (successful transfers) and "returned %d fallback weights" (failures). Since only the fallback count is available, the message should be clarified.
📝 Proposed fix for clearer logging
```diff
         logger.warning(
-            "MX P2P delivered %d out of N tensors but returned %d "
-            "fallback weights (size mismatch). Falling back to full "
+            "MX P2P returned %d fallback weights (size mismatch). "
+            "Falling back to full "
             "disk load to avoid mixing presharded and non-presharded "
             "weights.",
-            len(fallback_weights),  # this is a count of fallback only
             len(fallback_weights),
         )
```
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 196-196: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
142 - 231, The logger.warning in load_weights currently passes
len(fallback_weights) twice to a message that implies two different values;
update the log text in the fallback handling inside load_weights (in
tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader) to avoid the
misleading "delivered %d out of N" phrasing—either compute and pass an actual
delivered count if available or reword the message to only report the fallback
count (len(fallback_weights)) and the MX server URL or total tensors if you can
obtain them; ensure the call no longer supplies duplicate format args and
reference MxLiveWeightLoader/fallback_weights/_fallback_to_disk while making
this change.
```python
def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str = None) -> None:
    """Publish this instance's weights so other ranks can pull via P2P.

    Called by the integration in ``model_loader.py`` *before*
    ``post_load_weights()`` so targets receive raw loaded state and
    can apply their own post-load transforms.

    Delegates to the upstream
    ``modelexpress.trtllm_live_transfer.publish_model_params``
    helper, which handles the per-rank NIXL setup, tensor
    registration, and gRPC publish.

    Args:
        model: The model whose weights to publish.
        mapping: Distributed mapping. Currently unused — kept for
            signature symmetry with the prior prototype API and for
            forward-compat with future upstream signatures.
        checkpoint_dir: Checkpoint directory. Used as a last-resort
            fallback for resolving the ``MODEL_NAME`` identity when
            neither ``model_name`` was passed to the constructor nor
            ``MODEL_NAME`` is set in the environment.
    """
    del mapping  # currently unused; see docstring.

    if self._mx_server_url is None:
        return

    try:
        from modelexpress.trtllm_live_transfer import (  # type: ignore[import-not-found]
            publish_model_params,
        )
    except ImportError:
        logger.debug("modelexpress library not installed; skipping MX publish.")
        return

    # Upstream publish_model_params reads MODEL_EXPRESS_URL and
    # MODEL_NAME from the environment. Set both from our resolved
    # configuration so per-instance values (URL passed via
    # llm_args.mx_server_url, identity from llm_args.model) are
    # respected, then restore prior state. Tracked as MX-2 in §15
    # (the env-var dance goes away when upstream exports a public
    # ``build_identity()`` we can call directly).
    resolved_name = _resolve_mx_model_name(self._model_name, checkpoint_dir)

    env_overrides = {
        "MODEL_EXPRESS_URL": self._mx_server_url,
        "MODEL_NAME": resolved_name,
    }
    prior = {key: os.environ.get(key) for key in env_overrides}
    for key, value in env_overrides.items():
        os.environ[key] = value

    try:
        publish_model_params(model)
        logger.info(
            "Published weights to MX server at %s as model=%r",
            self._mx_server_url,
            resolved_name,
        )
    except Exception as e:
        logger.warning(
            "Failed to publish weights to MX server at %s: %s",
            self._mx_server_url,
            e,
        )
    finally:
        for key, prior_value in prior.items():
            if prior_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = prior_value
```
Use explicit | None for optional parameter types.
Per PEP 484 and coding guidelines, implicit Optional (via = None default) is prohibited. The parameters mapping and checkpoint_dir need explicit union type annotations.
🔧 Proposed fix for type annotations
```diff
-    def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str = None) -> None:
+    def publish_as_source(
+        self, model: "nn.Module", mapping: Mapping | None = None, checkpoint_dir: str | None = None
+    ) -> None:
```
Note: You'll need to add `from torch import nn` to imports if not already present, or use a string annotation for `model`.
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 247-247: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
[warning] 306-306: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
247 - 318, The publish_as_source method uses implicit Optional defaults; update
its signature to use explicit union types: change mapping: Mapping = None to
mapping: Mapping | None and checkpoint_dir: str = None to checkpoint_dir: str |
None (keep model's type as-is or add/qualify with torch.nn.Module if you want an
explicit model type); ensure Python 3.10+ union syntax is supported in this
project or use typing.Optional if not, and add any missing imports (e.g., from
torch import nn if you change model to nn.Module) or use a forward/string
annotation to avoid import cycles.
```python
    # Skip TP slicing when weights are already sharded (e.g. MX P2P transfer).
    tp_size = 1 if getattr(module, '_weights_presharded',
                           False) else module.tp_size
    tp_rank = 0 if getattr(module, '_weights_presharded',
                           False) else module.tp_rank

    weight = load_weight_shard(weights[0]['weight'], tp_size, tp_rank,
                               module.tp_mode,
                               device) if "weight" in weights[0] else None
```
Don't bypass TP slicing for weights without also bypassing it for quantization metadata.
_weights_presharded now stops extra slicing for the main weight/bias tensors, but the quantized loaders later in this file still shard weight_scale, input_scale, pre_quant_scale, and similar tensors with module.tp_size/module.tp_rank. For presharded MX/GMS loads, that leaves full local weights paired with re-sliced metadata, which will either shape-mismatch or silently produce wrong results on quantized models. Please centralize the “effective TP coords” logic and reuse it for every load_weight_shard(...) call in this file, not just these helpers.
Also applies to: 234-248, 293-304
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/linear.py` around lines 187 - 195, Compute the
effective TP coordinates once (e.g., effective_tp_size = 1 if getattr(module,
'_weights_presharded', False) else module.tp_size and effective_tp_rank = 0 if
getattr(module, '_weights_presharded', False) else module.tp_rank) and reuse
those variables for every load_weight_shard(...) call in this module (including
the calls that load quantization metadata such as weight_scale, input_scale,
pre_quant_scale and any other quantized tensors), replacing direct uses of
module.tp_size/module.tp_rank so that both weights and their quantization
metadata are sliced consistently based on module._weights_presharded and
module.tp_mode.
```python
        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
        self.disable_deep_gemm = disable_deep_gemm
        self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping
        self._weights_presharded = False
```
Add the required SPDX/NVIDIA header to this modified source file.
This file now has meaningful changes but still does not carry the repo-standard copyright header block.
As per coding guidelines "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/linear.py` at line 2533, The file is missing the
repo-standard SPDX/NVIDIA copyright header; add the required NVIDIA
copyright/SPDX header block at the very top of the source file (using the year
of the latest meaningful modification) so all source files include the standard
notice; update the header in the module that defines the Linear class (the file
containing the assignment self._weights_presharded = False) to match the
repository's canonical NVIDIA header format.
```python
        with gms_backend.mem_pool_scope(device):
            load_weights_kwargs = {"mapping": self.mapping}
            weights = checkpoint_loader.load_weights(
                checkpoint_dir, **load_weights_kwargs)
```
GMS RW drops the model.llm_checkpoint_dir redirect.
The AUTO path already loads from model.llm_checkpoint_dir when the model rewrites its checkpoint source, but the GMS RW branch hard-codes checkpoint_dir. That creates a GMS-only failure mode for the same model class.
💡 Suggested fix
```diff
-            weights = checkpoint_loader.load_weights(
-                checkpoint_dir, **load_weights_kwargs)
+            weight_source = (
+                model.llm_checkpoint_dir
+                if hasattr(model, "llm_checkpoint_dir")
+                else checkpoint_dir
+            )
+            weights = checkpoint_loader.load_weights(
+                weight_source, **load_weights_kwargs)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 526 - 529, The
GMS branch is using the hard-coded checkpoint_dir variable which ignores a
model's redirected checkpoint source; inside the gms_backend.mem_pool_scope
block (around load_weights_kwargs and checkpoint_loader.load_weights) change the
call to pass the model's redirected path when present by resolving
checkpoint_source = getattr(self, "model", None) and using getattr(self.model,
"llm_checkpoint_dir", checkpoint_dir) (or equivalent) so that
checkpoint_loader.load_weights(...) uses model.llm_checkpoint_dir when
available, falling back to checkpoint_dir otherwise.
```python
    # Load weights from GPU Memory Service (read-only GPU memory pool).
    GMS = 3
```
Add the required NVIDIA file header.
This file is modified in this PR but still lacks the repository-mandated NVIDIA copyright header / latest-modification year.
As per coding guidelines, **/*.{cpp,cc,h,hpp,c,py} must “Add NVIDIA copyright header on ALL new files and update year on modified files” and “All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/llmapi/llm_args.py` around lines 3437 - 3438, This file
(tensorrt_llm/llmapi/llm_args.py) was modified but is missing the required
NVIDIA copyright header; add the repository-mandated NVIDIA file header
(including the project name and the year of latest meaningful modification) to
the top of the file so all Python sources comply, ensuring it appears before any
code or comments (i.e., above the section that defines GMS = 3 and the "Load
weights..." comment) and update the year if this is a modification rather than a
new file.
```python
    gms_tag: str = Field(
        default="weights",
        description="Tag identifying the weight set in the GMS memory pool. "
        "Used to distinguish multiple models or versions in the same "
        "GMS instance. Defaults to 'weights' to match the GMS library "
        "convention (gpu_memory_service uses 'weights' for the model "
        "weight tag and 'kv_cache' for the KV cache tag). "
        "Only used when load_format='GMS'.",
        status="prototype",
    )
```
Reject empty gms_tag values.
gms_tag="" currently parses, and the GMS load path later collapses falsey tags to the shared default "weights". That turns a bad config into a silent pool collision instead of failing fast.
💡 Suggested fix
```diff
     gms_tag: str = Field(
         default="weights",
         description="Tag identifying the weight set in the GMS memory pool. "
         "Used to distinguish multiple models or versions in the same "
         "GMS instance. Defaults to 'weights' to match the GMS library "
         "convention (gpu_memory_service uses 'weights' for the model "
         "weight tag and 'kv_cache' for the KV cache tag). "
         "Only used when load_format='GMS'.",
         status="prototype",
     )
+
+    @field_validator("gms_tag")
+    @classmethod
+    def validate_gms_tag(cls, v: str) -> str:
+        if not v.strip():
+            raise ValueError("gms_tag must be non-empty")
+        return v
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/llmapi/llm_args.py` around lines 3716 - 3725, Add an explicit
Pydantic field validator for gms_tag to reject empty or whitespace-only strings
so configs like gms_tag="" fail fast instead of collapsing to the default;
implement a `@field_validator`("gms_tag") `@classmethod` named e.g. validate_gms_tag
that checks v.strip() and raises ValueError("gms_tag must be non-empty") when
empty, otherwise returns v; place it in the same class that defines the gms_tag
Field so the validation runs during model parsing.
```python
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="GMSBackend.__init__ calls torch.cuda.current_device()"
)
```
These skips leave the new GMS control-path tests effectively uncovered in CPU CI.
Almost everything exercised here is mocked; the only real CUDA dependency is GMSBackend.__init__ touching torch.cuda.current_device(). With the class-level skipif, the pre-connect, connect(), RW/RO gating, and protocol checks all disappear from the default CPU-only unit suite, and this file still never asserts a successful RW or RO connect() path. Please patch torch.cuda.current_device() (or construct via __new__) so the mocked happy/failure paths run without a GPU. QA list updates are unnecessary here because this is still unittest-only coverage.
As per coding guidelines "Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix."
Also applies to: 88-90, 129-148, 156-158, 268-276
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/memory/test_gms_backend.py` around lines 39 - 41,
Remove the class-level pytest.mark.skipif and instead make tests run on CPU by
patching torch.cuda.current_device or by instantiating GMSBackend without
calling __init__; specifically, in tests that exercise GMSBackend (e.g., where
GMSBackend.__init__, connect(), and RW/RO gating are involved) replace the skip
with a setup that monkeypatches torch.cuda.current_device to return a valid
device (e.g., monkeypatch.setattr(torch.cuda, "current_device", lambda: 0)) so
__init__ can run under mocks, or create the object via
GMSBackend.__new__(GMSBackend) and manually set any internal attributes needed
for the mocked connect()/protocol code paths to execute; apply this change to
the blocks referenced (around lines with the prior skip and the other ranges
called out) so the mocked happy/failure connect() and gating paths execute on
CPU CI.
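Both techniques the comment suggests — patching the device query so `__init__` runs under mocks, and bypassing `__init__` via `__new__` — can be shown with a stdlib-only toy. `Backend` and `current_device` below are stand-ins for GMSBackend and `torch.cuda.current_device`, not the real code:

```python
def current_device():
    # Stand-in for torch.cuda.current_device() — fails without a GPU.
    raise RuntimeError("no CUDA device")


class Backend:
    def __init__(self):
        self.device = current_device()
        self.connected = False

    def connect(self):
        self.connected = True
        return True


# Option 1: temporarily swap the device query (what pytest's
# monkeypatch.setattr does) so __init__ can run under mocks.
_real = current_device
current_device = lambda: 0
b = Backend()
current_device = _real
assert b.device == 0

# Option 2: skip __init__ entirely via __new__ and set only the
# attributes the mocked connect()/protocol paths need.
b2 = Backend.__new__(Backend)
b2.connected = False
assert b2.connect() is True
```

Either way, the mocked happy/failure `connect()` paths run on CPU-only CI instead of being skipped wholesale.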
```python
class TestLoadFormatEnum:
    def test_gms_enum_present(self):
        # The prototype adds LoadFormat.GMS = 3.
        assert hasattr(LoadFormat, "GMS")
        assert LoadFormat.GMS.name == "GMS"

    def test_pre_existing_enums_unchanged(self):
        # Sanity — make sure the prototype didn't reshuffle enum values.
        assert LoadFormat.AUTO.value == 0
        assert LoadFormat.DUMMY.value == 1
        assert LoadFormat.VISION_ONLY.value == 2
```
Add a regression test for integer load_format inputs.
TorchLlmArgs.convert_load_format() now accepts raw ints, but this suite only checks enum members/constants. A case like load_format=3 -> LoadFormat.GMS should be pinned down here.
🧪 Suggested test
```diff
 class TestLoadFormatEnum:
+    def test_int_value_is_converted(self):
+        args = _make_args(load_format=3)
+        assert args.load_format == LoadFormat.GMS
+
     def test_gms_enum_present(self):
         # The prototype adds LoadFormat.GMS = 3.
         assert hasattr(LoadFormat, "GMS")
         assert LoadFormat.GMS.name == "GMS"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/llmapi/test_mx_gms_args.py` around lines 279 - 289, Add a
regression test to ensure integer load_format values map to the new enum member
by calling TorchLlmArgs.convert_load_format with the raw integer 3 and asserting
it yields LoadFormat.GMS; update the TestLoadFormatEnum suite (or add a new test
method) to include an assertion that TorchLlmArgs.convert_load_format(3) ==
LoadFormat.GMS while keeping the existing checks for AUTO/DUMMY/VISION_ONLY
unchanged.
… timeout, model_name plumbing

Three discrete improvements to the MX side of PR NVIDIA#13045, driven by review feedback from the MX team's downstream PR (chienchunhung/TensorRT-LLM #1) — three orchestration ergonomics fixes landed as one focused commit so reviewers see them as a clean slice on top of the prototype.

(1) MODEL_EXPRESS_URL env-var fallback — at validator level

TorchLlmArgs.validate_mx_config now honors the upstream ``MODEL_EXPRESS_URL`` env var when ``checkpoint_format='MX'`` and ``mx_server_url`` is unset. Resolution happens at validator time so the value ends up on ``llm_args.mx_server_url`` (visible to logging, /startup_metrics, downstream code) instead of being silently re-read from env by the loader. Lets orchestrators (Dynamo) configure MX via the environment without plumbing every CLI knob, while keeping resolution in one place. Explicit ``mx_server_url=`` always wins. The env-var fallback only fires when MX is the active checkpoint format (so HF-only configs aren't surprised by an unrelated env var). Empty string in env is treated as unset.

(2) MX_SOURCE_QUERY_TIMEOUT defensive default

MXCheckpointLoader.__init__ calls ``os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")`` whenever an MX server URL is configured. Caps cold-cluster first-replica startup at 30 s instead of upstream's 1-hour default (the polling in MxLiveWeightLoader._query_source). setdefault semantics preserve any explicit user value. HF-only loads (no MX URL) don't touch the env at all. The proper upstream-side fix is a non-blocking source-query API (tracked as MX-4 in §15 of the design doc); this defensive default caps the worst case until that lands.

(3) model_name plumbing with HF-snapshot-aware resolver

Plumbs ``llm_args.model → MXCheckpointLoader(model_name=...)`` so upstream's ``publish_model_params()`` publishes under the user-supplied Hub ID (e.g. "Qwen/Qwen2.5-72B-Instruct") instead of the "unknown" sentinel.

- MXCheckpointLoader takes a new optional ``model_name`` constructor arg (Union[str, Path]), coerced to str at construction time.
- publish_as_source() now sets BOTH MODEL_EXPRESS_URL and MODEL_NAME env vars (resolving identity via the priority order below) and restores both env vars in finally. publish_model_params() reads them via env, as documented.
- Identity resolution order: explicit constructor arg → MODEL_NAME env → checkpoint_dir basename (with HF-snapshot path unmangling) → "unknown".
- HF cache layout (".../models--<org>--<name>/snapshots/<sha>/") is unmangled back to "<org>/<name>" instead of returning the commit hash.
- _construct_checkpoint_loader plumbs ``mx_model_name`` through; py_executor_creator.py extracts it from llm_args.model.

Both env-var dances (MODEL_EXPRESS_URL + MODEL_NAME) collapse into one direct call when MX-2 (public build_identity) lands upstream.

Tests for these three additions are in the next commit.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
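The HF-cache "unmangling" described above can be sketched as a small path transform. This is an illustrative fallback only (the real resolver also consults the constructor arg and the `MODEL_NAME` env var first), and the helper name below is hypothetical:

```python
from pathlib import Path


def unmangle_hf_snapshot(checkpoint_dir: str) -> str:
    """Recover "<org>/<name>" from an HF-cache snapshot path.

    E.g. ".../models--Qwen--Qwen2.5-72B-Instruct/snapshots/<sha>"
    -> "Qwen/Qwen2.5-72B-Instruct". Non-cache paths fall back to
    the directory basename.
    """
    p = Path(checkpoint_dir)
    for part in p.parts:
        if part.startswith("models--"):
            # "models--<org>--<name>" -> "<org>/<name>"
            return part[len("models--"):].replace("--", "/")
    return p.name or "unknown"


assert (
    unmangle_hf_snapshot(
        "/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/abc123"
    )
    == "Qwen/Qwen2.5-72B-Instruct"
)
assert unmangle_hf_snapshot("/models/my-local-ckpt") == "my-local-ckpt"
```

Note the simple `replace("--", "/")` assumes the org/name themselves contain no literal `--`, which holds for typical Hub IDs but is a known simplification.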
0a2a950 to 6bdcfd8 (force-push)
…ht sharing
Prototype implementation of the two-axis integration model for MX
(ModelExpress P2P transfer) and GMS (GPU Memory Service) in TRT-LLM's
PyTorch backend. Enables fast cold-start via GPU-to-GPU weight
streaming (MX) and zero-copy weight sharing with crash-resilient
failover (GMS).
The key design insight is that MX and GMS map onto two **independent
axes** in TRT-LLM's existing loading pipeline:
- ``checkpoint_format`` (weight source axis): ``"MX"`` — P2P via
upstream ``modelexpress.trtllm_live_transfer.MxLiveWeightLoader``,
with automatic HF disk fallback (inherited from HfCheckpointLoader).
- ``LoadFormat`` (memory management axis): ``GMS`` — out-of-process
GPU memory pool with RW/RO dual paths via the upstream
``gpu_memory_service`` class-based API.
These compose independently, giving four modes with no combinatorial
config explosion:
| Mode | checkpoint_format | LoadFormat | Use case |
|--------------|-------------------|------------|------------------------------------------|
| Pure TRT-LLM | "HF" (default) | AUTO | Current behavior, unchanged |
| MX only | "MX" | AUTO | Cross-node P2P weight transfer |
| GMS only | "HF" | GMS | Within-node weight sharing + failover |
| MX + GMS | "MX" | GMS | Cross-node P2P + within-node sharing |
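A schematic of how the two axes compose — the enum value and helper below are illustrative stand-ins, not the real loader code:

```python
from enum import Enum

class LoadFormat(Enum):
    AUTO = 0
    GMS = 3  # appended after the pre-existing values

def plan_load(checkpoint_format: str = "HF",
              load_format: LoadFormat = LoadFormat.AUTO) -> dict:
    """Each axis is consulted independently, so the four modes in the
    table fall out of two orthogonal checks."""
    return {
        "weight_source": "mx_p2p" if checkpoint_format == "MX" else "hf_disk",
        "memory_pool": "gms" if load_format is LoadFormat.GMS else "native",
    }
```

``plan_load()`` with defaults keeps current behavior unchanged; ``plan_load("MX", LoadFormat.GMS)`` is the full two-axis composition.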
Design intent: TRT-LLM owns the integration policy (Layer 3) and
calls the upstream libraries' stable per-call primitives (Layer 2).
We never duplicate the wire-protocol or runtime mechanics (Layer 1 —
gRPC, NIXL RDMA, CUDA VMM). We deliberately do NOT use the upstream
``gpu_memory_service.integrations.trtllm.setup_gms()`` entry point
because it monkey-patches ``ModelLoader.load`` from outside, which is
opaque at code-review time and conflicts with the two-axis design.
What's included:
MXCheckpointLoader (_torch/models/checkpoints/mx/checkpoint_loader.py)
- Subclasses HfCheckpointLoader — disk fallback inherited.
- Delegates the actual NIXL/RDMA transfer to upstream
MxLiveWeightLoader.load_weights(checkpoint_dir, mapping=, model=)
(handles agent setup, source matching, dtype-cast handling,
PVC fallback for size-mismatched tensors).
- Delegates the publish side to upstream publish_model_params(model).
- Exposes ``p2p_succeeded`` property — consumed by ModelLoader to
set ``_weights_presharded`` on Linear modules.
- Pre-post_load_weights() MX source publish timing so targets
receive raw loaded state and run their own transforms.
- Conservative mixed-success behavior: if MX returns
size-mismatched fallback weights, falls through to a full disk
load to avoid mixing presharded and non-presharded weights.
GMSBackend (_torch/memory/gpu_memory_backend.py)
- GPUMemoryBackend protocol + concrete GMSBackend implementation
wrapping upstream GMSClientMemoryManager and the gms_use_mem_pool
context manager.
- Mode resolved at connect() via
get_or_create_gms_client_memory_manager(socket, device,
mode=RequestedLockType.RW_OR_RO) and inspected via granted_lock_type.
- RW path: caller uses gms_backend.mem_pool_scope(device) context
manager (delegates to gms_use_mem_pool("weights", device)); after
load, move_untracked_params() migrates stray buffers into the
pool, then finalize_write() delegates to upstream
finalize_gms_write() (handles register + sync + commit + RO
reconnect + remap in one call).
- RO path: post_load_weights() runs *before* materialize_module()
(sets up module aliases first); guard prevents double-execution;
materialize_module_from_gms(mgr, model, device_index=N) does
the zero-copy import.
- Default tag is ``"weights"`` matching the GMS library convention
(``weights`` for model weights, ``kv_cache`` for KV cache).
- Default socket path resolves via get_socket_path(device, tag)
(UUID-based, stable across CUDA_VISIBLE_DEVICES).
- VMM safety: connect() applies patch_empty_cache() to prevent
segfaults on VMM-backed allocations.
ModelLoader changes (_torch/pyexecutor/model_loader.py)
- New LoadFormat.GMS branch with full RW/RO dual paths.
- MX source publish fires for LoadFormat.AUTO and GMS-RW modes,
before post_load_weights().
- _weights_presharded flag set per-path (MX P2P in AUTO, GMS
materialize in RO).
- Draft model modules excluded from presharded flag.
- Wires the mx_preshard_strategy knob (per_module / global).
Configuration (llmapi/llm_args.py)
- mx_server_url, mx_preshard_strategy, gms_socket_path, gms_mode,
gms_tag — all status="prototype".
- mx_preshard_strategy: per_module | global (default per_module);
global raises NotImplementedError until LoadFormat.PRESHARDED
lands upstream (tracked as MX-1 in the design doc §15).
- Validators: validate_mx_config(), validate_gms_config() reject
bad values with friendly messages.
- API stability YAML updated.
Linear module (_torch/modules/linear.py)
- _weights_presharded = False declared in __init__ — TP slicing
skipped when True (concept from the closed PR NVIDIA#12898).
Optional dependencies (setup.py)
- The MX (modelexpress) and GMS (gpu-memory-service) Python
packages are intentionally NOT declared as ``[mx]`` / ``[gms]`` /
``[dynamo]`` extras_require entries while the integration is at
prototype status. Until both packages are PyPI-published *and*
onboarded into NVIDIA's OSS allowlist, users install them
manually:
pip install "modelexpress>=0.3.0,<0.4.0"
git clone https://github.com/ai-dynamo/dynamo \
&& pip install ./dynamo/lib/gpu_memory_service
Restoring one-line ``pip install tensorrt_llm[dynamo]``
ergonomics is a single-hunk revert once upstream publish + OSS
allowlist steps are complete (tracked in §15 as MX-7 and GMS-6).
Rationale and the suppressed extras block are inline-documented
in setup.py at the extras_require site.
Design decisions:
- No monkey-patching. The upstream setup_gms() entry point patches
ModelLoader.load from outside. We deliberately do not use it:
TRT-LLM owns the integration policy, and GMSBackend is the
explicit, reviewable boundary.
- MX P2P is NOT used in GMS RW mode. Model params are meta tensors
at that point — no CUDA buffers for P2P to write into, and
MX/NIXL would allocate buffers outside the GMS pool. Disk
loading under the GMS pool is the correct behavior.
- Property-based ``checkpoint_format`` override on
MXCheckpointLoader — the parent sets _checkpoint_format = "HF",
the subclass overrides both the property and backing attribute.
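The property-override pattern, distilled — class bodies here are illustrative, not the real loaders:

```python
class HfCheckpointLoader:
    def __init__(self):
        self._checkpoint_format = "HF"  # parent sets the backing attribute

    @property
    def checkpoint_format(self) -> str:
        return self._checkpoint_format

class MXCheckpointLoader(HfCheckpointLoader):
    def __init__(self):
        super().__init__()
        self._checkpoint_format = "MX"  # override the backing attribute...

    @property
    def checkpoint_format(self) -> str:  # ...and the property itself
        return self._checkpoint_format
```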
CLI usage:
# MX only (cross-node P2P)
trtllm-serve <model> --checkpoint-format mx --mx-server-url http://mx:8001
# GMS only (within-node sharing + crash resilience)
trtllm-serve <model> --load-format gms --gms-socket-path /tmp/gms-0.sock
# MX + GMS (full two-axis composition)
trtllm-serve <model> --checkpoint-format mx --mx-server-url http://mx:8001 \
--load-format gms --gms-socket-path /tmp/gms-0.sock
Relationship to existing work:
- Adopts and adapts the _weights_presharded concept and pre-
post_load_weights() publish timing from MX team's PR NVIDIA#12898
(closed in favor of this two-axis model).
- Uses the same upstream Layer-2 primitives as dynamo PR NVIDIA#7575 (the
official TRT-LLM sleep/wake integration with GMS), but invokes
them directly rather than via setup_gms() monkey-patch.
- Full design doc at docs/design/mx-gms-integration/ on the
docs-and-plans branch.
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
…dapters
Adds 136 CPU-only unit tests covering the MX/GMS prototype's TRT-LLM-side
surface. Tests do NOT exercise the upstream ``modelexpress`` or
``gpu_memory_service`` libraries — fallback paths are exercised by
injecting None into ``sys.modules`` (forcing ImportError on ``import``),
and success paths use fake module trees. Real CUDA is gated with
``@pytest.mark.skipif(not torch.cuda.is_available())``.
Test files added:
tests/unittest/llmapi/test_mx_gms_args.py (41 tests)
- Default values for all 5 prototype config fields. Notable:
``gms_tag`` defaults to ``"weights"`` (not ``"model_weights"``)
matching the upstream GMS_TAGS convention.
- ``mx_preshard_strategy`` validation: accepts per_module/global,
rejects everything else, parametrized cross-field warnings on
(checkpoint_format, strategy).
- ``mx_server_url`` cross-field warning: parametrized over the
{explicit-with-MX, explicit-without-MX, unset} matrix.
- ``MODEL_EXPRESS_URL`` env-var fallback (validator-level): env
populates / explicit wins / no-env stays None / env ignored when
not MX / empty env is no-op / no spurious cross-field warning on
env-fallback path.
- ``gms_mode`` validator: only enforced when ``load_format=GMS``;
accepts auto/rw/ro, rejects others.
- ``gms_socket_path`` cross-field warning: parametrized over the
{GMS, AUTO} matrix.
- Two-axis composition sanity: pure TRT-LLM, MX-only, GMS-only,
MX+GMS all construct without error.
- LoadFormat enum: GMS=3 added, pre-existing values unshifted.
Uses ``unittest.mock.patch`` on
``tensorrt_llm.llmapi.llm_args.logger`` (TRT-LLM uses a custom
Singleton logger that does NOT route through stdlib logging, so
pytest's ``caplog`` does not intercept it).
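Why patching the module attribute works where ``caplog`` does not, as a minimal sketch — the namespace object stands in for the ``tensorrt_llm.llmapi.llm_args`` module, and ``validate`` is hypothetical:

```python
import types
from unittest import mock

class _SingletonLogger:
    """Stand-in for TRT-LLM's custom logger: it never calls into the
    stdlib logging machinery, so pytest's caplog cannot see it."""
    def warning(self, msg):
        pass

# Stand-in for the tensorrt_llm.llmapi.llm_args module namespace.
llm_args = types.SimpleNamespace(logger=_SingletonLogger())

def validate(value):
    if value == "bad":
        llm_args.logger.warning("bad value: %s" % value)

# Patch the module-level attribute directly instead of relying on caplog.
with mock.patch.object(llm_args, "logger") as fake_logger:
    validate("bad")
    fake_logger.warning.assert_called_once()
```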
tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py (53 tests)
- Construction: subclasses HfCheckpointLoader, exposes
mx_server_url/p2p_succeeded/checkpoint_format/model_name
properties, registered under ``checkpoint_format="MX"`` in the
BaseCheckpointLoader registry (verified via the same call shape
that ``_construct_checkpoint_loader`` uses).
- Disk fallback paths (no upstream library involved) — parametrized
on the fallback trigger (no_url / no_model / lib_unavailable /
upstream_raises). Each case asserts the same observable contract:
p2p_succeeded stays False, super().load_weights is invoked once
and its return value is propagated unchanged.
- MX path (mocked upstream):
* Empty fallback dict → p2p_succeeded=True, returns {}, plus
assertions that MxLiveWeightLoader was constructed with
mx_server=URL and load_weights was called with the right args
(pins the integration contract).
* Non-empty fallback dict (mixed-success) → falls through to
full disk load to avoid mixing presharded and non-presharded
weights (tracked as MX-1 in §15 of the design doc).
- publish_as_source: no-op when no URL or upstream missing,
delegates to publish_model_params(model), env-var dance for both
MODEL_EXPRESS_URL and MODEL_NAME (sets during call, restores
prior value), swallows upstream exceptions.
- MX_SOURCE_QUERY_TIMEOUT defensive default: unset gets "30",
explicit user value preserved (setdefault semantics), no MX URL
leaves env untouched.
- model_name plumbing: constructor accepts str/Path/None, coerced
to str at construction.
- _normalize_model_identity (8 parametrized cases): pass-through
for Hub IDs and bare names; basename for absolute, relative,
and home-expansion paths; HF-snapshot unmangling for both simple
and nested-org HF cache layouts.
- _resolve_mx_model_name (6 cases): priority order
(explicit > env > basename > "unknown"); explicit-arg
normalization.
- publish_as_source MODEL_NAME end-to-end (6 cases): explicit
constructor value wins, basename fallback, HF-snapshot
unmangling, constructor priority over env, env-only path,
"unknown" sentinel.
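The MX_SOURCE_QUERY_TIMEOUT setdefault contract exercised above can be sketched as (helper name hypothetical):

```python
import os
from typing import Optional

def apply_defensive_timeout(mx_server_url: Optional[str]) -> None:
    """Only touch the env when an MX URL is configured, and never
    clobber an explicit user value (setdefault semantics)."""
    if mx_server_url:
        os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")
```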
tests/unittest/_torch/memory/test_gms_backend.py (42 tests)
- Construction: validates mode in {rw, ro, auto}, default tag is
"weights" (matches DEFAULT_TAG class attribute), default socket
path is resolved lazily inside connect().
- Pre-connect state: is_rw is None, has_committed_weights returns
False (not raise), all RW-only methods raise RuntimeError with
"not connected" — parametrized on (method_name, invoke), cleanup
is safe.
- connect() failure: returns False (not raises) when
gpu_memory_service is not importable; returns False when the
upstream factory raises (e.g. socket missing).
- RW-only method gating: parametrized — mem_pool_scope and
finalize_write raise "only valid in RW mode" when granted lock
is RO.
- _ptr_in_gms helper: half-open interval semantics (parametrized
over inside / start / just-below-end / at-end / just-below-start
/ far-outside cases), no-mappings-attr handling, zero-va/size
sentinel guards (parametrized).
- _storage_nbytes helper: dtype-dependent sizing, view-independent
sizing on real CPU tensors.
- Protocol conformance: GMSBackend instances satisfy the
runtime-checkable GPUMemoryBackend protocol; the protocol
exposes the exact method set ``model_loader.py`` invokes (so a
different backend like a future CudaIpcBackend can be plugged
in without touching call sites).
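The pluggability claim rests on ``typing.Protocol``'s structural checks; a reduced sketch (the two-method subset is illustrative, not the real protocol surface):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class GPUMemoryBackend(Protocol):
    """Illustrative subset of the methods model_loader.py invokes."""
    def connect(self) -> bool: ...
    def cleanup(self) -> None: ...

class FakeCudaIpcBackend:
    """No inheritance required: structural conformance is enough."""
    def connect(self) -> bool:
        return True
    def cleanup(self) -> None:
        pass

# runtime_checkable lets isinstance() verify the method set is present.
assert isinstance(FakeCudaIpcBackend(), GPUMemoryBackend)
```

Note that ``runtime_checkable`` only verifies method presence, not signatures — which is exactly the "same call shape" style of check the tests above describe.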
Total runtime: ~0.6 s for all 136 tests on the development node.
Coverage gaps deliberately left for follow-ups:
- Real GMS daemon end-to-end (requires gpu-memory-service install
+ GPU). See Section 15 of the design doc for the full E2E test
plan (B2 GMS shadow memory, B6 failover floor).
- Real MX server end-to-end (requires MX server build via Rust +
Docker). See Section 15 (B1 bit-exactness, B3 P2P throughput,
B4 cold-start headline).
- LoadFormat.GMS branch in model_loader.py end-to-end (requires
real model + real GMS daemon).
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
6bdcfd8 to
06491b6
@coderabbitai review
✅ Actions performed: Review triggered.
/bot run --disable-fail-fast
PR_Github #45034 [ run ] triggered by Bot. Commit:
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
332-402: ⚠️ Potential issue | 🟠 Major — Keep buffer materialization separate from the GMS weight skip.
This now skips ``_apply_to_buffers_only(...)`` and the meta init path for all tensors in GMS mode, not just parameters. Any registered buffer that is not restored from the checkpoint will stay on ``meta``, and the later ``post_load_weights()`` / MoE finalization path can still trip over it. The GMS special-case should bypass parameter preallocation, but buffers still need to be materialized.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 332 - 402, The current GMS branch skips calling _apply_to_buffers_only and the buffer meta-init path, leaving registered buffers on meta; ensure buffers are always materialized: move the call to _apply_to_buffers_only(model, allocate_buffer_on_cuda) out of the conditional that checks self.model_weights_memory_tag and load_format != LoadFormat.GMS so it runs even when load_format == LoadFormat.GMS, and likewise allow the init_meta_tensor path to materialize buffers (i.e., don't gate buffer initialization by load_format != LoadFormat.GMS) while keeping the special-case skip for parameter preallocation (allocate_weights_on_cuda and its virtual_memory_scope) when load_format == LoadFormat.GMS; key symbols to change: _apply_to_buffers_only, allocate_buffer_on_cuda, init_meta_tensor, allocate_weights_on_cuda, self.model_weights_memory_tag, and LoadFormat.GMS.
♻️ Duplicate comments (1)
tests/unittest/_torch/memory/test_gms_backend.py (1)
87-92: ⚠️ Potential issue | 🟠 Major — Add a successful ``connect()`` path to this suite.
The CPU-CI enablement is fixed, but the tests still only cover pre-connect behavior and failure paths. The main control path is still unpinned: ``connect()`` successfully granting RW/RO, setting ``_client``/``_is_rw``, and resolving ``socket_path=None`` via the upstream helper. A fake happy-path client here would catch regressions in the path ``model_loader`` actually depends on. QA list updates are unnecessary because this is still unittest-only coverage.
As per coding guidelines, “Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix.”
Also applies to: 139-183
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/memory/test_gms_backend.py` around lines 87 - 92, Add a “happy-path” connect() flow to the test test_socket_path_none_resolved_lazily by creating a fake client and stubbing the upstream socket-path resolver so that GMSBackend.connect() succeeds, sets backend._client and backend._is_rw appropriately, and resolves socket_path when initialized with socket_path=None; specifically, in the test arrange a MagicMock/fake client that returns expected RW/RO capability, patch the get_socket_path (or equivalent helper used by GMSBackend.connect) to return a concrete path, call backend.connect(), and assert backend._client is the fake client, backend._is_rw is True/False as expected, and backend._socket_path is the resolved path so the real control path is covered.
🧹 Nitpick comments (1)
tensorrt_llm/llmapi/llm_args.py (1)
3684-3697: Use ``Literal[...]`` for ``mx_preshard_strategy``.
This is a user-facing Pydantic field with a closed value set, but typing it as ``str`` hides those constraints from schema/docs and pushes basic value validation into the model validator. ``Literal["per_module", "global"]`` would make the contract explicit while keeping the cross-field warning in ``validate_mx_config``.
As per coding guidelines, user-facing Python Pydantic fields should “Use ``Literal["value1", "value2"]`` instead of ``str`` in Python Pydantic fields when only certain values are allowed.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 3684 - 3697, Change the mx_preshard_strategy field's type from str to a Literal restricted to the allowed values (Literal["per_module", "global"]) so the Pydantic schema and docs expose the closed set; add the necessary import for Literal (from typing or typing_extensions as used in the repo), keep the default="per_module" and existing description/status, and leave any cross-field validation logic in validate_mx_config unchanged (it should continue to enforce/emit warnings or raise NotImplementedError when needed).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py`:
- Around line 296-321: The loop currently uses seen: set[int] and skips
subsequent tensors sharing the same storage_ptr, leaving aliases pointing
outside GMS; change seen to a dict mapping original storage_ptr -> replacement
tensor (or its base_va) so after you create replacement via
gms_client.create_mapping and _tensor_from_pointer (and do replacement.copy_),
you store seen[storage_ptr] = replacement and for any later tensor with the same
storage_ptr (found in _iter_module_tensors) you assign tensor.data =
seen[storage_ptr] instead of continuing; keep the existing _ptr_in_gms check
(and treat tensors already in GMS as already-mapped) and reuse _storage_nbytes,
_tensor_from_pointer, replacement.copy_, and tensor.data = ... symbols to locate
where to implement this change.
- Around line 213-216: In connect(), if patch_empty_cache() raises an exception,
log the error (keep or upgrade the current logger.debug message) and immediately
return False so the connection fails instead of returning True; locate the
try/except around patch_empty_cache() in the connect() method and add a return
False in the except block (reference: connect() and patch_empty_cache()).
In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py`:
- Around line 295-321: The env var override + publish sequence in
checkpoint_loader (the block that sets MODEL_EXPRESS_URL/MODEL_NAME, calls
publish_model_params(model), and restores os.environ) must be serialized with a
module-level lock to prevent concurrent publish_as_source()/publish_model_params
calls from interleaving; add a module-level threading.Lock (e.g.,
_env_override_lock) and wrap the entire override/publish/restore sequence in a
with _env_override_lock: so the try/except/finally and the restore logic run
while holding the lock; ensure the lock is declared at module scope so other
functions in this module can reuse it.
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 178-191: The _weights_presharded flag set by MX P2P/GMS is never
cleared, so _effective_tp_coords() can incorrectly return (1, 0) on subsequent
non-presharded loads; update Linear.load_weights (or add/ensure a call to
pre_reload_weights) to clear module._weights_presharded (or reconstruct the
Linear instance) before loading regular checkpoint weights so TP slicing is
re-enabled; specifically, ensure any reload path that calls Linear.load_weights
resets _weights_presharded (or invokes pre_reload_weights) so
_effective_tp_coords, the main weight/bias and quantization metadata are sliced
consistently.
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 512-610: Assign self._gms_backend immediately after
gms_backend.connect() returns True and then wrap the subsequent GMS load logic
(the RW/RO branches, finalize_write/materialize_module calls, and any
draft-weight handling) in a try/except that on exception clears
self._gms_backend and calls the backend’s cleanup/disconnect method (e.g.,
gms_backend.disconnect() or the appropriate close/cleanup API) before
re-raising; this ensures the live GMS handle and CUDA hooks are released if
checkpoint_loader.load_weights, finalize_write, materialize_module, or other
calls fail.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3833-3835: Reject boolean inputs before treating numbers as
LoadFormat: in the load_format validator (the block that currently does "if
isinstance(v, int): return LoadFormat(v)"), add an explicit check for
isinstance(v, bool) and raise a ValidationError/TypeError so True/False do not
get interpreted as integers; keep the existing int path to map valid integers to
LoadFormat. Add a regression test named test_load_format_rejects_bool in
test_mx_gms_args.py that passes load_format=True and asserts a validation
failure. Also change the field declaration for mx_preshard_strategy from
"mx_preshard_strategy: str = Field(...)" to "mx_preshard_strategy:
Literal['per_module','global'] = Field(...)" so Pydantic enforces the allowed
values.
---
Outside diff comments:
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 332-402: The current GMS branch skips calling
_apply_to_buffers_only and the buffer meta-init path, leaving registered buffers
on meta; ensure buffers are always materialized: move the call to
_apply_to_buffers_only(model, allocate_buffer_on_cuda) out of the conditional
that checks self.model_weights_memory_tag and load_format != LoadFormat.GMS so
it runs even when load_format == LoadFormat.GMS, and likewise allow the
init_meta_tensor path to materialize buffers (i.e., don't gate buffer
initialization by load_format != LoadFormat.GMS) while keeping the special-case
skip for parameter preallocation (allocate_weights_on_cuda and its
virtual_memory_scope) when load_format == LoadFormat.GMS; key symbols to change:
_apply_to_buffers_only, allocate_buffer_on_cuda, init_meta_tensor,
allocate_weights_on_cuda, self.model_weights_memory_tag, and LoadFormat.GMS.
---
Duplicate comments:
In `@tests/unittest/_torch/memory/test_gms_backend.py`:
- Around line 87-92: Add a “happy-path” connect() flow to the test
test_socket_path_none_resolved_lazily by creating a fake client and stubbing the
upstream socket-path resolver so that GMSBackend.connect() succeeds, sets
backend._client and backend._is_rw appropriately, and resolves socket_path when
initialized with socket_path=None; specifically, in the test arrange a
MagicMock/fake client that returns expected RW/RO capability, patch the
get_socket_path (or equivalent helper used by GMSBackend.connect) to return a
concrete path, call backend.connect(), and assert backend._client is the fake
client, backend._is_rw is True/False as expected, and backend._socket_path is
the resolved path so the real control path is covered.
---
Nitpick comments:
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3684-3697: Change the mx_preshard_strategy field's type from str
to a Literal restricted to the allowed values (Literal["per_module", "global"])
so the Pydantic schema and docs expose the closed set; add the necessary import
for Literal (from typing or typing_extensions as used in the repo), keep the
default="per_module" and existing description/status, and leave any cross-field
validation logic in validate_mx_config unchanged (it should continue to
enforce/emit warnings or raise NotImplementedError when needed).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: ca8dadee-00ef-4aaf-b997-796f3c04c71d
📒 Files selected for processing (19)
- setup.py
- tensorrt_llm/_torch/memory/__init__.py
- tensorrt_llm/_torch/memory/gpu_memory_backend.py
- tensorrt_llm/_torch/models/checkpoints/__init__.py
- tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
- tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
- tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/executor/base_worker.py
- tensorrt_llm/llmapi/llm_args.py
- tests/unittest/_torch/memory/test_gms_backend.py
- tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
- tests/unittest/api_stability/references/llm.yaml
- tests/unittest/llmapi/test_mx_gms_args.py
✅ Files skipped from review due to trivial changes (6)
- tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
- setup.py
- tensorrt_llm/_torch/models/checkpoints/__init__.py
- tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
- tests/unittest/api_stability/references/llm.yaml
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/executor/base_worker.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
The flagged hunk:

```python
try:
    patch_empty_cache()
except Exception as e:
    logger.debug("GMS patch_empty_cache failed (non-fatal): %s", e)
```
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM; outputs elided):

```shell
# First, find and read the file to understand the context
find . -name "gpu_memory_backend.py" -type f
# Read the file to understand the context around lines 213-216
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '200,230p'
# Read more context including move_untracked_params() and the full connect() method
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '1,50p'
# Get the full connect() method to understand the return behavior
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '180,225p'
# Find where patch_empty_cache() is defined
rg "def patch_empty_cache" --type py
# Check if patch_empty_cache() is imported and where it comes from
rg "patch_empty_cache" tensorrt_llm/_torch/memory/gpu_memory_backend.py -B 2 -A 2
# Search for any documentation or comments about what patch_empty_cache does
rg "patch_empty_cache" --type py -B 5 -A 5
# Check if there are any callers that depend on the return value of connect()
rg "\.connect\(\)" --type py
```
Return False from connect() if patch_empty_cache() fails.
The patch prevents segfaults on GMS-backed VMM tensors when torch.cuda.empty_cache() is called (documented as critical in model_loader.py). If patching fails, the connection should fail (return False), not silently succeed. Currently the method returns True despite the patch failure, creating a latent crash condition for callers.
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 215-215: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 213 - 216, In
connect(), if patch_empty_cache() raises an exception, log the error (keep or
upgrade the current logger.debug message) and immediately return False so the
connection fails instead of returning True; locate the try/except around
patch_empty_cache() in the connect() method and add a return False in the except
block (reference: connect() and patch_empty_cache()).
The flagged hunk:

```python
seen: set[int] = set()

with torch.no_grad():
    for _name, tensor, tensor_type in _iter_module_tensors(model):
        if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
            continue

        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen:
            continue
        seen.add(storage_ptr)

        if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
            continue

        nbytes = _storage_nbytes(tensor)
        base_va = gms_client.create_mapping(size=nbytes, tag=self._tag)
        replacement = _tensor_from_pointer(
            int(base_va),
            list(tensor.shape),
            list(tensor.stride()),
            tensor.dtype,
            device_index,
        )
        replacement.copy_(tensor)
        tensor.data = replacement
```
Rebind shared storages instead of skipping them.
`seen` deduplicates by the original storage pointer, but after the first tensor is migrated only that parameter is rebound to the GMS mapping. Any later alias that shared the same storage is skipped and keeps pointing outside the pool, which breaks tied weights and can leave `finalize_write()` with mixed tracked/untracked parameters.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 296 - 321, The
loop currently uses seen: set[int] and skips subsequent tensors sharing the same
storage_ptr, leaving aliases pointing outside GMS; change seen to a dict mapping
original storage_ptr -> replacement tensor (or its base_va) so after you create
replacement via gms_client.create_mapping and _tensor_from_pointer (and do
replacement.copy_), you store seen[storage_ptr] = replacement and for any later
tensor with the same storage_ptr (found in _iter_module_tensors) you assign
tensor.data = seen[storage_ptr] instead of continuing; keep the existing
_ptr_in_gms check (and treat tensors already in GMS as already-mapped) and reuse
_storage_nbytes, _tensor_from_pointer, replacement.copy_, and tensor.data = ...
symbols to locate where to implement this change.
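The dict-based rebinding the prompt describes can be sketched standalone; `FakeTensor` and `make_replacement` are illustrative stand-ins for torch storages and the GMS mapping/copy calls:

```python
class FakeTensor:
    def __init__(self, storage_id, data):
        self.storage_id = storage_id  # analog of tensor.untyped_storage().data_ptr()
        self.data = data

def migrate_to_pool(tensors, make_replacement):
    # Map original storage id -> replacement so later aliases are rebound,
    # not skipped (the bug with a plain `seen: set[int]`).
    seen: dict[int, object] = {}
    for t in tensors:
        if t.storage_id in seen:
            t.data = seen[t.storage_id]  # rebind alias to the same replacement
            continue
        repl = make_replacement(t.data)  # analog of create_mapping + copy_
        seen[t.storage_id] = repl
        t.data = repl

a = FakeTensor(0x1000, "w")  # tied weights: two tensors, one storage
b = FakeTensor(0x1000, "w")
migrate_to_pool([a, b], lambda d: ["pool", d])
assert a.data is b.data  # both aliases now point at the pooled replacement
```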
```python
env_overrides = {
    "MODEL_EXPRESS_URL": self._mx_server_url,
    "MODEL_NAME": resolved_name,
}
prior = {key: os.environ.get(key) for key in env_overrides}
for key, value in env_overrides.items():
    os.environ[key] = value

try:
    publish_model_params(model)
    logger.info(
        "Published weights to MX server at %s as model=%r",
        self._mx_server_url,
        resolved_name,
    )
except Exception as e:
    logger.warning(
        "Failed to publish weights to MX server at %s: %s",
        self._mx_server_url,
        e,
    )
finally:
    for key, prior_value in prior.items():
        if prior_value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = prior_value
```
Serialize the env-var override around MX publishing.
MODEL_EXPRESS_URL and MODEL_NAME live in process-global os.environ, so two concurrent publish_as_source() calls can interleave here and publish a model under the wrong identity or server URL. Please guard the override/publish/restore sequence with a module-level lock.
🔒 Proposed fix

```diff
+import threading
 ...
+_MX_PUBLISH_ENV_LOCK = threading.Lock()
+
 ...
-prior = {key: os.environ.get(key) for key in env_overrides}
-for key, value in env_overrides.items():
-    os.environ[key] = value
-
-try:
-    publish_model_params(model)
-    logger.info(
-        "Published weights to MX server at %s as model=%r",
-        self._mx_server_url,
-        resolved_name,
-    )
-except Exception as e:
-    logger.warning(
-        "Failed to publish weights to MX server at %s: %s",
-        self._mx_server_url,
-        e,
-    )
-finally:
-    for key, prior_value in prior.items():
-        if prior_value is None:
-            os.environ.pop(key, None)
-        else:
-            os.environ[key] = prior_value
+with _MX_PUBLISH_ENV_LOCK:
+    prior = {key: os.environ.get(key) for key in env_overrides}
+    for key, value in env_overrides.items():
+        os.environ[key] = value
+
+    try:
+        publish_model_params(model)
+        logger.info(
+            "Published weights to MX server at %s as model=%r",
+            self._mx_server_url,
+            resolved_name,
+        )
+    except Exception as e:
+        logger.warning(
+            "Failed to publish weights to MX server at %s: %s",
+            self._mx_server_url,
+            e,
+        )
+    finally:
+        for key, prior_value in prior.items():
+            if prior_value is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = prior_value
```

🧰 Tools
🪛 Ruff (0.15.10)
[warning] 310-310: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
295 - 321, The env var override + publish sequence in checkpoint_loader (the
block that sets MODEL_EXPRESS_URL/MODEL_NAME, calls publish_model_params(model),
and restores os.environ) must be serialized with a module-level lock to prevent
concurrent publish_as_source()/publish_model_params calls from interleaving; add
a module-level threading.Lock (e.g., _env_override_lock) and wrap the entire
override/publish/restore sequence in a with _env_override_lock: so the
try/except/finally and the restore logic run while holding the lock; ensure the
lock is declared at module scope so other functions in this module can reuse it.
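The lock-plus-restore sequence can be packaged as a context manager. A minimal standalone sketch — the helper name `env_overrides` is illustrative, not the PR's API:

```python
import os
import threading
from contextlib import contextmanager

_ENV_OVERRIDE_LOCK = threading.Lock()  # module-level, shared by all publishers

@contextmanager
def env_overrides(overrides):
    # Serialize mutation of the process-global os.environ across threads,
    # and restore the prior values even if the body raises.
    with _ENV_OVERRIDE_LOCK:
        prior = {k: os.environ.get(k) for k in overrides}
        try:
            os.environ.update(overrides)
            yield
        finally:
            for k, v in prior.items():
                if v is None:
                    os.environ.pop(k, None)
                else:
                    os.environ[k] = v

before = os.environ.get("MODEL_NAME")
with env_overrides({"MODEL_EXPRESS_URL": "grpc://a:1", "MODEL_NAME": "m"}):
    assert os.environ["MODEL_NAME"] == "m"   # override visible inside
assert os.environ.get("MODEL_NAME") == before  # restored outside
```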
```python
def _effective_tp_coords(module: Linear) -> tuple[int, int]:
    """Return (tp_size, tp_rank) accounting for ``_weights_presharded``.

    When a module's weights are already TP-sharded for the local rank
    (e.g. delivered via MX P2P transfer), TRT-LLM's standard TP slicing
    must be skipped for *all* tensors associated with the module — the
    main weight and bias *and* their quantization metadata
    (weight_scale, input_scale, pre_quant_scale, etc.). Otherwise
    full-size weights would be paired with re-sliced metadata, causing
    shape mismatches or silent numerical errors on quantized models.
    """
    if getattr(module, '_weights_presharded', False):
        return 1, 0
    return module.tp_size, module.tp_rank
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "Assignments to _weights_presharded:"
rg -n -C2 '_weights_presharded\s*=' tensorrt_llm
echo
echo "Potential reload / reuse entry points:"
rg -n -C3 '\b(pre_reload_weights|load_weights_from_target_model|load_weights\()' tensorrt_llm/_torch
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Get the actual load_weights method in Linear to see if it clears the flag
echo "=== Linear.load_weights implementation ==="
sed -n '2600,2700p' tensorrt_llm/_torch/modules/linear.py | head -50
echo
echo "=== Check for pre_reload_weights in Linear ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/modules/linear.py
echo
echo "=== Check if modules are reconstructed or reused in model_loader ==="
rg -n -B5 -A5 'module._weights_presharded.*True' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -40
echo
echo "=== Check module init flow in model_loader ==="
rg -n 'Linear\(' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -5
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# View the pre_reload_weights and load_weights methods in Linear
echo "=== Linear.pre_reload_weights (lines 2790-2800) ==="
sed -n '2790,2810p' tensorrt_llm/_torch/modules/linear.py
echo
echo "=== Linear.load_weights signature and start (lines 2795+) ==="
sed -n '2795,2850p' tensorrt_llm/_torch/modules/linear.py | head -60
echo
echo "=== Search for load_weights in Linear with context ==="
rg -n 'def load_weights' tensorrt_llm/_torch/modules/linear.py -A 10 | head -40
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Find the actual Linear class's load_weights method
echo "=== Find Linear class load_weights ==="
rg -n 'class Linear' tensorrt_llm/_torch/modules/linear.py | head -5
echo
echo "=== Linear.load_weights in class (searching after line 2300) ==="
sed -n '2700,2800p' tensorrt_llm/_torch/modules/linear.py | grep -n 'def load_weights' -A 20 | head -40
echo
echo "=== Try broader search for load_weights in Linear class ==="
sed -n '2750,2795p' tensorrt_llm/_torch/modules/linear.py
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Check if pre_reload_weights is called in the model_loader before load_weights
echo "=== Check model_loader for pre_reload_weights calls ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/pyexecutor/model_loader.py -B 3 -A 3
echo
echo "=== Check modeling_utils for pre_reload_weights calls ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/models/modeling_utils.py -B 3 -A 3
echo
echo "=== Verify that _weights_presharded is NOT reset anywhere ==="
rg -n '_weights_presharded\s*=\s*False' tensorrt_llm
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# More comprehensive check
echo "=== All calls to pre_reload_weights ==="
rg -r 'pre_reload_weights' tensorrt_llm --type py
echo
echo "=== All assignments to _weights_presharded ==="
rg '_weights_presharded\s*=' tensorrt_llm --type py
echo
echo "=== Check if modules.load_weights is called after MX P2P sets the flag ==="
sed -n '450,480p' tensorrt_llm/_torch/pyexecutor/model_loader.py
echo
echo "=== Context around weight loading after MX P2P ==="
sed -n '420,470p' tensorrt_llm/_torch/pyexecutor/model_loader.py
```

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Focus on the actual code flow - check if modules are reconstructed or reused
echo "=== Check if initialize_dummy_weights or model creation reconstructs Linear modules ==="
sed -n '490,500p' tensorrt_llm/_torch/pyexecutor/model_loader.py
echo
echo "=== Check the broader load flow - whether modules persist across loads ==="
sed -n '400,435p' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -40
echo
echo "=== Confirm: no reset of _weights_presharded before load_weights ==="
grep -n 'def load_weights' tensorrt_llm/_torch/modules/linear.py -A 20 | grep -E '(_weights_presharded|pre_reload)' || echo "No reset found"
```
Clear the `_weights_presharded` flag before loading non-presharded weights, or ensure modules are reconstructed on reload.
The flag is initialized to `False` only in `__init__` and set to `True` by the MX P2P and GMS paths, but is never reset before subsequent loads. If a module undergoes presharded loading and is then reloaded from a regular checkpoint, TP slicing incorrectly stays disabled: `_effective_tp_coords()` returns `(1, 0)` and full-size weights are copied into per-rank buffers, producing catastrophic shape mismatches or silent numerical errors on quantized models.
Either reset the flag before non-presharded loads in `Linear.load_weights()`, verify that all reload paths reconstruct `Linear` instances, or call `pre_reload_weights()` before each load to clear stale state.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/linear.py` around lines 178 - 191, The
_weights_presharded flag set by MX P2P/GMS is never cleared, so
_effective_tp_coords() can incorrectly return (1, 0) on subsequent
non-presharded loads; update Linear.load_weights (or add/ensure a call to
pre_reload_weights) to clear module._weights_presharded (or reconstruct the
Linear instance) before loading regular checkpoint weights so TP slicing is
re-enabled; specifically, ensure any reload path that calls Linear.load_weights
resets _weights_presharded (or invokes pre_reload_weights) so
_effective_tp_coords, the main weight/bias and quantization metadata are sliced
consistently.
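The interaction between the flag, TP slicing, and the suggested reset can be shown with a self-contained toy; `FakeLinear` and `shard_rows` are illustrative analogs, not TRT-LLM code:

```python
class FakeLinear:
    tp_size, tp_rank = 4, 1

    def __init__(self):
        self._weights_presharded = False

    def pre_reload_weights(self):
        # The suggested fix: clear per-load state so a regular
        # checkpoint reload re-enables TP slicing.
        self._weights_presharded = False

def _effective_tp_coords(module):
    if getattr(module, '_weights_presharded', False):
        return 1, 0
    return module.tp_size, module.tp_rank

def shard_rows(rows, module):
    # Stand-in for load_weight_shard: take this rank's slice.
    tp_size, tp_rank = _effective_tp_coords(module)
    n = len(rows) // tp_size
    return rows[tp_rank * n:(tp_rank + 1) * n]

full = list(range(8))
m = FakeLinear()
assert shard_rows(full, m) == [2, 3]  # rank 1 of 4 gets its slice
m._weights_presharded = True          # set by an MX P2P / GMS load
assert shard_rows(full, m) == full    # slicing skipped: weights arrive presharded
m.pre_reload_weights()                # before reloading a regular checkpoint
assert shard_rows(full, m) == [2, 3]  # TP slicing re-enabled
```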
```python
if not gms_backend.connect():
    raise RuntimeError("Failed to connect to GMS at "
                       f"{self.llm_args.gms_socket_path}")

if gms_backend.is_rw:
    # GMS RW path: load via checkpoint_loader under GMS
    # memory pool so allocations go into shared memory.
    # Note: MX P2P is NOT used in GMS RW mode because
    # parameters are still meta tensors (no CUDA buffers
    # for P2P to write into) and allocations must go through
    # the GMS pool. The checkpoint_loader loads from disk,
    # with the GMS pool intercepting all CUDA allocations.
    device = torch.device('cuda')

    with gms_backend.mem_pool_scope(device):
        load_weights_kwargs = {"mapping": self.mapping}
        # Honor model.llm_checkpoint_dir when the model
        # rewrites its checkpoint source — same precedence
        # as the AUTO branch above. Without this, GMS-only
        # would silently load from the original
        # checkpoint_dir for redirect-using model classes.
        weight_source = (model.llm_checkpoint_dir if hasattr(
            model, 'llm_checkpoint_dir') else checkpoint_dir)
        weights = checkpoint_loader.load_weights(
            weight_source, **load_weights_kwargs)

        if weights:
            self.weight_mapper = (
                checkpoint_loader.get_initialized_weight_mapper(
                    model, config))
            self._call_load_weights(model.load_weights, weights,
                                    self.weight_mapper)

        # Speculative decoding: load draft weights INSIDE
        # the same mem_pool_scope so they land in the GMS
        # pool and are picked up by finalize_write's model
        # walk. (Deferring commit until after both main and
        # draft weights are in the pool avoids the otherwise-
        # required separate commit-per-model dance flagged
        # in PR #13045 review.)
        if (self.spec_config is not None and self.spec_config.
                spec_dec_mode.need_load_draft_weights()):
            draft_weights = checkpoint_loader.load_weights(
                self.spec_config.speculative_model,
                mapping=self.mapping)

            draft_model_arch = (
                model.draft_config.pretrained_config.
                architectures[0])
            draft_weight_mapper = AutoCheckpointMapper.get(
                checkpoint_loader.checkpoint_format,
                draft_model_arch)
            draft_weight_mapper.init_model_and_config(
                model.draft_model, model.draft_config)

            self._call_load_weights(model.load_draft_weights,
                                    draft_weights,
                                    draft_weight_mapper)

        # Move any parameters that landed outside the GMS
        # pool (e.g., buffers created during weight-loading
        # transforms) into the pool, then drain the caching
        # allocator. Order and pool-scope membership mirror
        # the upstream gpu_memory_service.integrations.trtllm
        # ``_load_rw`` reference path:
        # https://github.com/ai-dynamo/dynamo/blob/main/lib/gpu_memory_service/integrations/trtllm/model_loader.py
        gms_backend.move_untracked_params(model)
        torch.cuda.empty_cache()

    # Commit weights for RO readers and transition to RO mode.
    # finalize_write walks the full model tree (including
    # model.draft_model when present) so a single commit
    # covers both main and draft weights.
    gms_backend.finalize_write(model)
    logger.info(
        "LoadFormat.GMS (RW): loaded and committed "
        "weights via %s", checkpoint_loader.checkpoint_format)
else:
    # GMS RO path: zero-copy import from existing GMS pool.
    # post_load_weights() must run BEFORE materialize so
    # that module aliases are set up correctly.
    for module in model.modules():
        if hasattr(module, 'post_load_weights') and not getattr(
                module, '_weights_removed', False):
            module.post_load_weights()

    gms_backend.materialize_module(model)
    # Note: MoE load balancer finalization (lines below) works
    # because register functions access module attributes at
    # call time, resolving to materialized CUDA tensors.
    # The patch_empty_cache() applied during connect() is
    # critical here — MoE's make_tensor_host_accessible()
    # calls torch.cuda.empty_cache() which would segfault
    # on VMM-backed GMS tensors without the patch.
    logger.info("LoadFormat.GMS (RO): zero-copy materialized "
                "weights")

# Store backend reference for cleanup during shutdown.
self._gms_backend = gms_backend
```
Unwind the GMS connection on partial-load failures.
`connect()` succeeds at Line 512, but `self._gms_backend` is not assigned until Line 609. If `load_weights()`, the draft-weight path, `finalize_write()`, or `materialize_module()` raises first, the live GMS handle and its CUDA hooks leak, and `cleanup()` cannot recover because `_gms_backend` is still `None`. Store the backend immediately after a successful `connect()` and clean it up in a try/except before re-raising.
Possible fix

```diff
 if not gms_backend.connect():
     raise RuntimeError("Failed to connect to GMS at "
                        f"{self.llm_args.gms_socket_path}")
+self._gms_backend = gms_backend
-if gms_backend.is_rw:
-    # GMS RW path: load via checkpoint_loader under GMS
-    # memory pool so allocations go into shared memory.
-    # Note: MX P2P is NOT used in GMS RW mode because
-    # parameters are still meta tensors (no CUDA buffers
-    # for P2P to write into) and allocations must go through
-    # the GMS pool. The checkpoint_loader loads from disk,
-    # with the GMS pool intercepting all CUDA allocations.
-    device = torch.device('cuda')
+try:
+    if gms_backend.is_rw:
+        # GMS RW path: load via checkpoint_loader under GMS
+        # memory pool so allocations go into shared memory.
+        device = torch.device('cuda')
-    with gms_backend.mem_pool_scope(device):
-        ...
+        with gms_backend.mem_pool_scope(device):
+            ...
-    gms_backend.finalize_write(model)
-    logger.info(
-        "LoadFormat.GMS (RW): loaded and committed "
-        "weights via %s", checkpoint_loader.checkpoint_format)
-else:
-    ...
-
-# Store backend reference for cleanup during shutdown.
-self._gms_backend = gms_backend
+        gms_backend.finalize_write(model)
+        logger.info(
+            "LoadFormat.GMS (RW): loaded and committed "
+            "weights via %s", checkpoint_loader.checkpoint_format)
+    else:
+        ...
+except Exception:
+    gms_backend.cleanup()
+    self._gms_backend = None
+    raise
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 512 - 610,
Assign self._gms_backend immediately after gms_backend.connect() returns True
and then wrap the subsequent GMS load logic (the RW/RO branches,
finalize_write/materialize_module calls, and any draft-weight handling) in a
try/except that on exception clears self._gms_backend and calls the backend’s
cleanup/disconnect method (e.g., gms_backend.disconnect() or the appropriate
close/cleanup API) before re-raising; this ensures the live GMS handle and CUDA
hooks are released if checkpoint_loader.load_weights, finalize_write,
materialize_module, or other calls fail.
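The store-early/unwind-on-failure pattern in miniature — all names below are stand-ins for the real backend and loader, not the PR's classes:

```python
class FakeGMSBackend:
    def __init__(self):
        self.cleaned = False

    def connect(self):
        return True

    def cleanup(self):
        self.cleaned = True

class Loader:
    def __init__(self):
        self._gms_backend = None

    def load(self, backend, do_load):
        if not backend.connect():
            raise RuntimeError("Failed to connect to GMS")
        # Store the handle immediately so a later failure can be unwound.
        self._gms_backend = backend
        try:
            do_load()
        except Exception:
            backend.cleanup()            # release handle + CUDA hooks
            self._gms_backend = None
            raise

def failing_load():
    raise RuntimeError("load failed mid-way")

loader = Loader()
backend = FakeGMSBackend()
try:
    loader.load(backend, failing_load)
except RuntimeError:
    pass
assert backend.cleaned and loader._gms_backend is None  # nothing leaked
```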
```python
if isinstance(v, int):
    return LoadFormat(v)
load_format = v.upper()
```
🧩 Analysis chain
🏁 Script executed:

```shell
python - <<'PY'
from enum import Enum

class LoadFormat(Enum):
    AUTO = 0
    DUMMY = 1
    VISION_ONLY = 2
    GMS = 3

print("bool_is_int:", isinstance(True, int))
print("enum_from_true:", LoadFormat(True))
PY
```

🏁 Script executed:

```shell
# First, let's examine the convert_load_format method in the file
sed -n '3820,3850p' tensorrt_llm/llmapi/llm_args.py
```

🏁 Script executed:

```shell
# Check the mx_preshard_strategy field definition
rg -A 5 'mx_preshard_strategy' tensorrt_llm/llmapi/llm_args.py | head -30
```

🏁 Script executed:

```shell
# Look for existing test file and tests related to load_format
find . -path '*/tests/unittest/llmapi/test_mx_gms_args.py' -exec cat {} \;
```

🏁 Script executed:

```shell
# Also check if there are other load_format validator tests
rg -A 10 'load_format.*True\|test.*load_format' tests/ --type py
```
Reject booleans before the new integer `load_format` path.
`bool` is an `int` in Python, so a config like `load_format: true` flows through `LoadFormat(v)` as `LoadFormat.DUMMY` (since `True == 1`) instead of failing fast. This silently switches the loader into random-weight initialization.
Proposed fix

```diff
 @field_validator('load_format', mode='before')
 @classmethod
 def convert_load_format(cls, v):
     if isinstance(v, LoadFormat):
         return v
+    if isinstance(v, bool):
+        raise ValueError(f"Invalid LoadFormat: {v}")
     if isinstance(v, int):
         return LoadFormat(v)
+    if not isinstance(v, str):
+        raise ValueError(f"Invalid LoadFormat: {v}")
     load_format = v.upper()
     if load_format not in LoadFormat.__members__:
         raise ValueError(f"Invalid LoadFormat: {v}")
     return LoadFormat[load_format]
```

Add a regression test in `tests/unittest/llmapi/test_mx_gms_args.py` for `load_format=True`.
Also, change `mx_preshard_strategy: str = Field(...)` to `mx_preshard_strategy: Literal["per_module", "global"] = Field(...)` to match the constrained value set, per Pydantic field typing guidelines.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/llmapi/llm_args.py` around lines 3833 - 3835, Reject boolean
inputs before treating numbers as LoadFormat: in the load_format validator (the
block that currently does "if isinstance(v, int): return LoadFormat(v)"), add an
explicit check for isinstance(v, bool) and raise a ValidationError/TypeError so
True/False do not get interpreted as integers; keep the existing int path to map
valid integers to LoadFormat. Add a regression test named
test_load_format_rejects_bool in test_mx_gms_args.py that passes
load_format=True and asserts a validation failure. Also change the field
declaration for mx_preshard_strategy from "mx_preshard_strategy: str =
Field(...)" to "mx_preshard_strategy: Literal['per_module','global'] =
Field(...)" so Pydantic enforces the allowed values.
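The whole validator with the bool guard in place runs standalone. The enum values below mirror the probe script earlier in this thread; the function is a plain-function analog of the Pydantic validator, not the actual `llm_args.py` code:

```python
from enum import Enum

class LoadFormat(Enum):
    AUTO = 0
    DUMMY = 1
    VISION_ONLY = 2
    GMS = 3

def convert_load_format(v):
    if isinstance(v, LoadFormat):
        return v
    if isinstance(v, bool):  # must precede the int check: bool subclasses int
        raise ValueError(f"Invalid LoadFormat: {v}")
    if isinstance(v, int):
        return LoadFormat(v)
    if not isinstance(v, str):
        raise ValueError(f"Invalid LoadFormat: {v}")
    name = v.upper()
    if name not in LoadFormat.__members__:
        raise ValueError(f"Invalid LoadFormat: {v}")
    return LoadFormat[name]

assert convert_load_format("gms") is LoadFormat.GMS
assert convert_load_format(1) is LoadFormat.DUMMY
try:
    convert_load_format(True)  # without the guard this would become DUMMY
    raise AssertionError("expected ValueError")
except ValueError:
    pass
```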
PR_Github #45034 [ run ] completed with state
Opened another PR (#13531) to extract the MX-only part. Changing this PR back to draft mode to keep it as a reference; no need for code review for this PR.
Summary by CodeRabbit
New Features
Tests
Chores
[TRTLLM-11851][feat] MX and GMS integration prototype for Dynamo weight sharing
Description
Summary
Prototype implementation of the two-axis integration model for MX (ModelExpress P2P transfer) and GMS (GPU Memory Service) in TRT-LLM's PyTorch backend. Enables fast cold-start via GPU-to-GPU weight streaming (MX) and zero-copy weight sharing with crash-resilient failover (GMS).
MX and GMS map onto two independent axes in TRT-LLM's existing loading pipeline:
- `checkpoint_format` (weight source axis): `"MX"` — P2P via upstream `MxLiveWeightLoader.load_weights()`, with automatic HF disk fallback (inherited from `HfCheckpointLoader`)
- `LoadFormat` (memory management axis): `GMS` — out-of-process GPU memory pool with RW/RO dual paths via upstream `gpu_memory_service` class-based API

These compose independently, giving four modes with no combinatorial config explosion:
| `checkpoint_format` | `LoadFormat` |
| --- | --- |
| `"HF"` (default) | `AUTO` (default) |
| `"MX"` | `AUTO` |
| `"HF"` | `GMS` |
| `"MX"` | `GMS` |

Aligned against the actually-merged upstream APIs:
`gpu-memory-service` 0.9.0 (from `ai-dynamo/dynamo`) and `modelexpress` 0.3.0 (from `ai-dynamo/modelexpress`). TRT-LLM owns the integration policy (Layer 3) and calls the upstream libraries' stable per-call primitives (Layer 2); we never duplicate the wire-protocol or runtime mechanics (Layer 1 — gRPC, NIXL RDMA, CUDA VMM).
What's included
MXCheckpointLoader (
`_torch/models/checkpoints/mx/checkpoint_loader.py`, ~380 lines)

- Subclasses `HfCheckpointLoader` — HF disk fallback is inherited, no separate fallback loader needed
- `MxLiveWeightLoader(mx_server=url).load_weights(checkpoint_dir, mapping=, model=)` (handles agent setup, source matching, dtype-cast handling, PVC fallback for size-mismatched tensors)
- `publish_model_params(torch_model)` via an env-var dance over `MODEL_EXPRESS_URL` and `MODEL_NAME`
- `p2p_succeeded` property — consumed by `ModelLoader` to set `_weights_presharded` on `Linear` modules
- Pre-`post_load_weights()` MX source publish timing so targets receive raw loaded state and run their own transforms
- `model_name` constructor parameter (plumbed from `llm_args.model`) so publish identity comes from the user-supplied Hub ID with HF-snapshot path unmangling, not the upstream `"unknown"` sentinel
- `MX_SOURCE_QUERY_TIMEOUT=30` setdefault when an MX URL is configured, capping upstream `MxLiveWeightLoader._query_source` polling at 30 s instead of the 1-hour default. The correct upstream fix is to adopt `RdmaStrategy`'s immediate-fallback pattern (single `list_sources` call, no polling) — tracked as MX-4

GMSBackend (
`_torch/memory/gpu_memory_backend.py`, ~440 lines)

- `GPUMemoryBackend` protocol + concrete `GMSBackend` implementation wrapping the upstream `GMSClientMemoryManager` and the `gms_use_mem_pool` context manager
- Lock type resolved at `connect()` time via `get_or_create_gms_client_memory_manager(socket, device, mode=RequestedLockType.RW_OR_RO)` and inspected via `granted_lock_type`
- RW loads run under the `gms_backend.mem_pool_scope(device)` context manager; after load, `move_untracked_params()` migrates stray buffers into the pool (inside the scope, matching upstream `_load_rw` order), then `finalize_write()` delegates to upstream `finalize_gms_write()` (register + sync + commit + RO reconnect + remap)
- `post_load_weights()` runs before `materialize_module()`, guard prevents double-execution; `materialize_module_from_gms(mgr, model, device_index=N)` does the zero-copy import
- Default tag `"weights"` matching the GMS library convention
- Socket auto-resolve via `get_socket_path(device, tag)` (UUID-based, stable across `CUDA_VISIBLE_DEVICES`)
- `connect()` applies `patch_empty_cache()` to prevent segfaults on VMM-backed allocations

ModelLoader changes (
`_torch/pyexecutor/model_loader.py`, +255/-19)

- `LoadFormat.GMS` branch with full RW/RO dual paths
- `model.llm_checkpoint_dir` redirect (matches AUTO branch); loads speculative-decoding draft weights inside the same `mem_pool_scope` so a single `finalize_write(model)` covers both main and draft via the model-tree walk
- `LoadFormat.AUTO` and GMS-RW modes, before `post_load_weights()`
- `_weights_presharded` flag set per-path (MX P2P in AUTO, GMS materialize in RO)

Linear module (
`_torch/modules/linear.py`)

- `_weights_presharded = False` declared in `__init__` — TP slicing skipped when `True`
- `_effective_tp_coords(module)` helper centralizes TP-coordinate logic: when `_weights_presharded` is `True`, returns `(tp_size=1, tp_rank=0)` for all `load_weight_shard(...)` calls — both the main weight/bias tensors AND their quantization metadata (`weight_scale`, `input_scale`, `pre_quant_scale`, etc.) across all quant-scheme loaders (FP8, NVFP4, MXFP4, W4A16-AWQ, W4A8-AWQ, etc.)

Configuration (
`llmapi/llm_args.py`)

- `mx_server_url`, `mx_preshard_strategy`, `gms_socket_path`, `gms_mode`, `gms_tag` — all `status="prototype"`
- `validate_mx_config()`, `validate_gms_config()` — honor upstream `MODEL_EXPRESS_URL` env var as fallback for `mx_server_url`; reject empty/whitespace `gms_tag` values
- `mx_preshard_strategy: per_module | global` (default `per_module`); `global` raises `NotImplementedError` until `LoadFormat.PRESHARDED` lands upstream

Optional dependencies (
`setup.py`)

The MX (`modelexpress`) and GMS (`gpu-memory-service`) Python packages are intentionally NOT bundled as `[mx]`/`[gms]`/`[dynamo]` extras while the integration is at prototype status. Users install manually:
Once installed, the MX/GMS code paths auto-detect them via
`import`; without them, the loaders fall back to plain HF disk loading.
Why no extras: (1)
`gpu-memory-service` is not published to PyPI (GMS-6), so NVIDIA's Blossom-CI vulnerability scanner can't resolve it; (2) `modelexpress` is brand-new on PyPI and not yet in the internal OSS allowlist (MX-7). Both blockers are tracked in §15 of the design doc. Restoring `pip install tensorrt_llm[dynamo]` is a single-hunk revert once both are cleared.
Running with MX or GMS enabled
The MX/GMS knobs are set via a YAML config passed to `trtllm-serve --config`, or as keyword arguments when using the Python `LLM` API directly.

| Knob | Behavior |
| --- | --- |
| `checkpoint_format` | `"MX"` to enable MX P2P loading. Default `None` → `"HF"`. |
| `mx_server_url` | Falls back to the `MODEL_EXPRESS_URL` env var. |
| `load_format` | `LoadFormat` `"GMS"` to enable GMS-backed loading. Default `"AUTO"`. |
| `gms_socket_path` | Default `None` → UUID-keyed auto-resolve. |
| `gms_mode` | `auto` (default), `rw`, `ro`. |
| `gms_tag` | Default `"weights"`. |

MX only:
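A minimal YAML sketch for the MX-only mode — the keys are the knobs listed above; the URL value is a placeholder:

```yaml
# MX only: P2P weight pull with automatic HF disk fallback
checkpoint_format: "MX"
mx_server_url: "grpc://mx-server:8001"   # or set MODEL_EXPRESS_URL instead
```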
GMS only:
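A matching sketch for a GMS-only config — again, the keys come from the knob table and the values are placeholders showing the documented defaults:

```yaml
# GMS only: shared-pool loading; first worker takes RW, later workers RO
load_format: "GMS"
gms_mode: "auto"          # auto | rw | ro
# gms_socket_path: null   # default -> UUID-keyed auto-resolve
# gms_tag: "weights"      # default tag, matches the GMS library convention
```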
Python API:
Design decisions
- Upstream's `setup_gms()` entry point patches `ModelLoader.load` from outside. We use `GMSBackend` as the explicit, reviewable boundary instead.
- `_weights_presharded` instead of `LoadFormat.PRESHARDED`. Composes more cleanly with `LoadFormat.GMS` and lets us mark only MX-delivered modules in mixed-success scenarios.
- `_effective_tp_coords(module)` ensures `_weights_presharded` skips TP slicing for weights AND their quantization metadata (`weight_scale`, `input_scale`, `pre_quant_scale`) — preventing shape mismatches on quantized models.

Relationship to existing work
- `_weights_presharded` concept and pre-`post_load_weights()` publish timing, using the two-axis model instead of `LoadFormat.PRESHARDED`.
- `MODEL_EXPRESS_URL` env-var fallback, the `MX_SOURCE_QUERY_TIMEOUT` defensive default (MX-4), and the `model_name` plumbing.
- `docs/design/mx-gms-integration/` on the `docs-and-plans` branch. §15 has the validation plan, upstream alignment requests, and test matrix.

Upstream alignment requests
Workarounds in the prototype that could be removed if the corresponding upstream change lands. Full details in §15.
- `LoadFormat.PRESHARDED` vs per-module flag: `mx_preshard_strategy='global'` raises until upstream lands
- `_build_trtllm_identity` to public API `publish_as_source`
- MX-3: per-rank addressing
- `RdmaStrategy` immediate-fallback in `MxLiveWeightLoader`: `MX_SOURCE_QUERY_TIMEOUT=30` (30 s poll vs. correct try-once)
- `modelexpress` into NVIDIA OSS allowlist: `setup.py`
- `GMSBackend` owns the integration: `_move_untracked_params` to public `GMSBackend`
- `gpu-memory-service` to PyPI: `setup.py`; install from source
pip install tensorrt_llm[dynamo]ergonomics.Test Coverage
148 CPU-only unit tests in
`tests/unittest/`, runs in <1 s, no MX/GMS deps required:

- `tests/unittest/llmapi/test_mx_gms_args.py` — config validation, defaults, cross-field warnings, `MODEL_EXPRESS_URL` env-var fallback, `gms_tag` empty-string rejection, `LoadFormat` int-conversion regression
- `tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py` — construction + registry, parametrized fallback paths, MX-success integration contract, `MX_SOURCE_QUERY_TIMEOUT` defensive default, `model_name` plumbing, `_normalize_model_identity` (HF-snapshot unmangling), `_resolve_mx_model_name` priority order, `publish_as_source` `MODEL_NAME` end-to-end with env restoration
- `tests/unittest/_torch/memory/test_gms_backend.py` — construction, pre-connect state, `connect()` graceful failure, RW-only method gating, `_ptr_in_gms` half-open intervals + zero-sentinel guards, `_storage_nbytes`, protocol conformance. All tests run on CPU CI (CUDA dependency monkeypatched via autouse fixture).
sys.modules-injection to forceImportError, success paths use fake module trees.PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.