
[TRTLLM-11851][feat] MX and GMS integration MVP for Dynamo weight sharing #13045

Draft

chienchunhung wants to merge 3 commits into NVIDIA:main from chienchunhung:dynamo-integration-prototype

Conversation

chienchunhung (Collaborator) commented Apr 14, 2026

Summary by CodeRabbit

  • New Features

    • GMS (GPU Memory Service) backend for disaggregated GPU weight memory
    • MX P2P weight-transfer support with automatic fallback and presharded-weight handling for Linear modules
    • New LoadFormat.GMS and user-facing config options: mx_server_url, mx_preshard_strategy, gms_socket_path, gms_mode, gms_tag
  • Tests

    • Added comprehensive unit tests for GMS backend, MX checkpoint loader, and MX/GMS argument validation
  • Chores

    • Packaging metadata updated (SPDX year) and clarifying install guidance added to setup notes

[TRTLLM-11851][feat] MX and GMS integration prototype for Dynamo weight sharing

Description

Summary

Prototype implementation of the two-axis integration model for MX (ModelExpress P2P transfer) and GMS (GPU Memory Service) in TRT-LLM's PyTorch backend. Enables fast cold-start via GPU-to-GPU weight streaming (MX) and zero-copy weight sharing with crash-resilient failover (GMS).

MX and GMS map onto two independent axes in TRT-LLM's existing loading pipeline:

  • checkpoint_format (weight source axis): "MX" — P2P via upstream MxLiveWeightLoader.load_weights(), with automatic HF disk fallback (inherited from HfCheckpointLoader)
  • LoadFormat (memory management axis): GMS — out-of-process GPU memory pool with RW/RO dual paths via upstream gpu_memory_service class-based API

These compose independently, giving four modes with no combinatorial config explosion:

| Mode | checkpoint_format | LoadFormat | Use case |
|---|---|---|---|
| Pure TRT-LLM | "HF" (default) | AUTO (default) | Current behavior, unchanged |
| MX only | "MX" | AUTO | Cross-node P2P weight transfer |
| GMS only | "HF" | GMS | Within-node weight sharing + crash resilience |
| MX + GMS | "MX" | GMS | Cross-node P2P + within-node sharing |

Aligned against the actually-merged upstream APIs: gpu-memory-service 0.9.0 (from ai-dynamo/dynamo) and modelexpress 0.3.0 (from ai-dynamo/modelexpress). TRT-LLM owns the integration policy (Layer 3) and calls the upstream libraries' stable per-call primitives (Layer 2); we never duplicate the wire-protocol or runtime mechanics (Layer 1 — gRPC, NIXL RDMA, CUDA VMM).

What's included

MXCheckpointLoader (_torch/models/checkpoints/mx/checkpoint_loader.py, ~380 lines)

  • Subclasses HfCheckpointLoader — HF disk fallback is inherited, no separate fallback loader needed
  • Delegates the actual NIXL/RDMA transfer to upstream MxLiveWeightLoader(mx_server=url).load_weights(checkpoint_dir, mapping=, model=) (handles agent setup, source matching, dtype-cast handling, PVC fallback for size-mismatched tensors)
  • Delegates the publish side to upstream publish_model_params(torch_model) via an env-var dance over MODEL_EXPRESS_URL and MODEL_NAME
  • Exposes p2p_succeeded property — consumed by ModelLoader to set _weights_presharded on Linear modules
  • MX source publish happens before post_load_weights(), so targets receive the raw loaded state and run their own transforms
  • model_name constructor parameter (plumbed from llm_args.model) so the publish identity comes from the user-supplied Hub ID, with HF-snapshot path unmangling, rather than the upstream "unknown" sentinel
  • Conservative mixed-success handling: if MX returns size-mismatched fallback weights, the loader falls through to a full disk load to avoid mixing presharded and non-presharded weights (see the sketch after this list)
  • Defensive MX_SOURCE_QUERY_TIMEOUT=30 setdefault when an MX URL is configured, capping upstream MxLiveWeightLoader._query_source polling at 30 s instead of the 1-hour default. The correct upstream fix is to adopt RdmaStrategy's immediate-fallback pattern (single list_sources call, no polling) — tracked as MX-4
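A condensed sketch of the control flow the bullets above describe. The subclassing relationship, the MxLiveWeightLoader call, the timeout setdefault, and the conservative fallback are from this PR; the import paths, exact signatures, and the assumption that MxLiveWeightLoader.load_weights() returns a dict of fallback weights are illustrative, not the literal implementation:

import os
from typing import Any, Optional

# Import path below is an assumption for illustration; HfCheckpointLoader is the
# existing TRT-LLM class this loader subclasses.
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader


class MXCheckpointLoaderSketch(HfCheckpointLoader):
    def __init__(self, *args, mx_server_url: Optional[str] = None,
                 model_name: Optional[str] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._mx_server_url = mx_server_url
        self._model_name = str(model_name) if model_name is not None else None
        self._p2p_succeeded = False
        if self._mx_server_url:
            # Defensive default (MX-4): cap upstream source polling at 30 s,
            # preserving any value the user set explicitly.
            os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")

    @property
    def p2p_succeeded(self) -> bool:
        return self._p2p_succeeded

    def load_weights(self, checkpoint_dir: str, mapping, **kwargs) -> dict[str, Any]:
        model = kwargs.pop("model", None)
        self._p2p_succeeded = False
        if self._mx_server_url is None or model is None:
            # Inherited HF disk path covers the "no MX URL" and "no model handle" cases.
            return super().load_weights(checkpoint_dir, mapping, **kwargs)
        from modelexpress.trtllm_live_transfer import MxLiveWeightLoader
        fallback = MxLiveWeightLoader(mx_server=self._mx_server_url).load_weights(
            checkpoint_dir, mapping=mapping, model=model)
        if fallback:
            # Conservative mixed-success: size-mismatched fallback tensors mean we
            # reload everything from disk rather than mix presharded and sharded weights.
            return super().load_weights(checkpoint_dir, mapping, **kwargs)
        self._p2p_succeeded = True
        return {}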

GMSBackend (_torch/memory/gpu_memory_backend.py, ~440 lines)

  • GPUMemoryBackend protocol + concrete GMSBackend implementation wrapping the upstream GMSClientMemoryManager and the gms_use_mem_pool context manager
  • Mode resolved at connect() time via get_or_create_gms_client_memory_manager(socket, device, mode=RequestedLockType.RW_OR_RO) and inspected via granted_lock_type
  • RW path: caller uses gms_backend.mem_pool_scope(device) context manager; after load, move_untracked_params() migrates stray buffers into the pool (inside the scope, matching upstream _load_rw order), then finalize_write() delegates to upstream finalize_gms_write() (register + sync + commit + RO reconnect + remap)
  • RO path: post_load_weights() runs before materialize_module(), and a guard prevents double-execution; materialize_module_from_gms(mgr, model, device_index=N) does the zero-copy import
  • Default tag is "weights" matching the GMS library convention
  • Default socket path resolves via get_socket_path(device, tag) (UUID-based, stable across CUDA_VISIBLE_DEVICES)
  • VMM safety: connect() applies patch_empty_cache() to prevent segfaults on VMM-backed allocations
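A trimmed-down sketch of the wrapper those bullets describe. The upstream primitive names (get_or_create_gms_client_memory_manager, RequestedLockType, gms_use_mem_pool, finalize_gms_write, materialize_module_from_gms, get_socket_path, patch_empty_cache) are the ones this PR cites; their module paths and the granted_lock_type comparison are assumptions, and the real GMSBackend is ~440 lines:

from contextlib import contextmanager

import torch.nn as nn

# Module paths are illustrative; only finalize_gms_write's path is quoted in this PR.
from gpu_memory_service import (RequestedLockType, get_or_create_gms_client_memory_manager,
                                get_socket_path, gms_use_mem_pool, materialize_module_from_gms,
                                patch_empty_cache)
from gpu_memory_service.integrations.common.utils import finalize_gms_write


class GMSBackendSketch:
    DEFAULT_TAG = "weights"  # GMS library convention

    def __init__(self, socket_path=None, tag=DEFAULT_TAG, device=0):
        self._tag, self._device = tag, device
        self._socket = socket_path or get_socket_path(device, tag)  # UUID-based default
        self._client = None

    def connect(self):
        # The daemon decides RW vs RO; patch_empty_cache() guards VMM-backed allocations.
        self._client = get_or_create_gms_client_memory_manager(
            self._socket, self._device, mode=RequestedLockType.RW_OR_RO)
        patch_empty_cache()

    @property
    def is_rw(self) -> bool:
        return "RW" in str(self._client.granted_lock_type)  # inspection; exact enum unknown

    @contextmanager
    def mem_pool_scope(self, device):
        # RW path: allocations made inside this scope land in the shared GMS pool.
        with gms_use_mem_pool(self._tag, device):
            yield

    def finalize_write(self, model: nn.Module) -> int:
        # Upstream helper: register + sync + commit + RO reconnect + remap.
        return int(finalize_gms_write(self._client, model))

    def materialize_module(self, model: nn.Module, device_index: int) -> None:
        # RO path: zero-copy import of the committed weights.
        materialize_module_from_gms(self._client, model, device_index=device_index)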

ModelLoader changes (_torch/pyexecutor/model_loader.py, +255/-19)

  • New LoadFormat.GMS branch with full RW/RO dual paths
  • GMS RW: respects model.llm_checkpoint_dir redirect (matches AUTO branch); loads speculative-decoding draft weights inside the same mem_pool_scope so a single finalize_write(model) covers both main and draft via the model-tree walk
  • MX source publish fires for LoadFormat.AUTO and GMS-RW modes, before post_load_weights()
  • _weights_presharded flag set per-path (MX P2P in AUTO, GMS materialize in RO)
  • Draft model modules excluded from presharded flag
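Putting the pieces together, the RW/RO split in the new LoadFormat.GMS branch looks roughly like this; gms_backend is the wrapper sketched above, and the helper names (checkpoint_loader, spec_config, is_rw, load_weights on the model) are illustrative stand-ins for the real ModelLoader plumbing:

import torch
import torch.nn as nn


def load_with_gms(model: nn.Module, gms_backend, checkpoint_loader, checkpoint_dir: str,
                  mapping, device: torch.device, spec_config=None) -> None:
    gms_backend.connect()
    if gms_backend.is_rw:
        # First worker: load from disk into the GMS pool, commit once for main + draft.
        ckpt = getattr(model, "llm_checkpoint_dir", None) or checkpoint_dir
        with gms_backend.mem_pool_scope(device):
            model.load_weights(checkpoint_loader.load_weights(ckpt, mapping=mapping))
            if spec_config is not None and spec_config.spec_dec_mode.need_load_draft_weights():
                # Draft weights load inside the same scope so a single finalize_write(model)
                # covers both main and draft via the model-tree walk.
                model.draft_model.load_weights(
                    checkpoint_loader.load_weights(spec_config.speculative_model, mapping=mapping))
            gms_backend.move_untracked_params(model)  # migrate stray buffers into the pool
            torch.cuda.empty_cache()                  # drain caching-allocator blocks in-scope
        gms_backend.finalize_write(model)
    else:
        # Subsequent workers: post_load_weights() first (guarded), then zero-copy import.
        model.post_load_weights()
        gms_backend.materialize_module(model, device_index=device.index)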

Linear module (_torch/modules/linear.py)

  • _weights_presharded = False declared in __init__ — TP slicing skipped when True
  • _effective_tp_coords(module) helper centralizes the TP-coordinate logic (sketched below): when _weights_presharded is True, it returns (tp_size=1, tp_rank=0) for all load_weight_shard(...) calls — both the main weight/bias tensors AND their quantization metadata (weight_scale, input_scale, pre_quant_scale, etc.) across all quant-scheme loaders (FP8, NVFP4, MXFP4, W4A16-AWQ, W4A8-AWQ, etc.)
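A sketch of that helper and how a quant-scheme loader would consume it; load_weight_shard's argument order is assumed for illustration:

def _effective_tp_coords(module) -> tuple[int, int]:
    # Presharded weights (delivered by MX P2P or GMS materialization) are already
    # sliced for this rank, so pretend the TP world is a single rank.
    if getattr(module, "_weights_presharded", False):
        return 1, 0  # tp_size=1, tp_rank=0 -> load_weight_shard() keeps the full tensor
    return module.tp_size, module.tp_rank


# Inside a quant-scheme loader, the same coords cover the weight AND its metadata:
#   tp_size, tp_rank = _effective_tp_coords(module)
#   w     = load_weight_shard(weights["weight"], tp_size, tp_rank, module.tp_mode)
#   scale = load_weight_shard(weights["weight_scale"], tp_size, tp_rank, module.tp_mode)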

Configuration (llmapi/llm_args.py)

  • mx_server_url, mx_preshard_strategy, gms_socket_path, gms_mode, gms_tag — all status="prototype"
  • Validators: validate_mx_config(), validate_gms_config() — honor upstream MODEL_EXPRESS_URL env var as fallback for mx_server_url; reject empty/whitespace gms_tag values
  • mx_preshard_strategy: per_module | global (default per_module); global raises NotImplementedError until LoadFormat.PRESHARDED lands upstream
  • API stability YAML updated
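A self-contained sketch of the two validators' behavior on a stand-in pydantic model (the real validators live on TorchLlmArgs; field defaults here follow the configuration table further down):

import os
from typing import Optional

from pydantic import BaseModel, field_validator, model_validator


class MxGmsArgsSketch(BaseModel):
    checkpoint_format: Optional[str] = None
    mx_server_url: Optional[str] = None
    gms_tag: str = "weights"

    @field_validator("gms_tag")
    @classmethod
    def _reject_blank_gms_tag(cls, v: str) -> str:
        # Reject empty / whitespace-only tags so gms_tag="" fails fast.
        if not v.strip():
            raise ValueError("gms_tag must be a non-empty string")
        return v

    @model_validator(mode="after")
    def _mx_env_fallback(self):
        # MODEL_EXPRESS_URL fallback only fires when MX is the active checkpoint
        # format and no explicit URL was given; an empty env value counts as unset.
        if self.checkpoint_format == "MX" and not self.mx_server_url:
            env_url = os.environ.get("MODEL_EXPRESS_URL", "").strip()
            if env_url:
                self.mx_server_url = env_url
        return self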

Optional dependencies (setup.py)

  • The MX (modelexpress) and GMS (gpu-memory-service) Python packages are intentionally NOT bundled as [mx] / [gms] / [dynamo] extras while the integration is at prototype status. Users install manually:

    # MX — Apache-2.0, on PyPI
    pip install "modelexpress>=0.3.0,<0.4.0"
    
    # GMS — Apache-2.0, source-only (not yet on PyPI)
    git clone https://github.com/ai-dynamo/dynamo
    pip install ./dynamo/lib/gpu_memory_service

    Once installed, the MX/GMS code paths auto-detect them via import; without them, the loaders fall back to plain HF disk loading.

  • Why no extras: (1) gpu-memory-service is not published to PyPI (GMS-6), so NVIDIA's Blossom-CI vulnerability scanner can't resolve it; (2) modelexpress is brand-new on PyPI and not yet in the internal OSS allowlist (MX-7). Both blockers are tracked in §15 of the design doc. Restoring pip install tensorrt_llm[dynamo] is a single-hunk revert once both are cleared.
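The auto-detection mentioned above is just an import probe; a minimal sketch (the flag names are illustrative):

# Availability check the MX/GMS code paths rely on; missing packages simply
# leave the plain HF disk loading path in effect.
try:
    import modelexpress  # noqa: F401
    MX_AVAILABLE = True
except ImportError:
    MX_AVAILABLE = False

try:
    import gpu_memory_service  # noqa: F401
    GMS_AVAILABLE = True
except ImportError:
    GMS_AVAILABLE = False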

Running with MX or GMS enabled

The MX/GMS knobs are set via a YAML config passed to trtllm-serve --config, or as keyword arguments when using the Python LLM API directly.

| Field | Type | Purpose |
|---|---|---|
| checkpoint_format | str | "MX" to enable MX P2P loading. Default None → "HF". |
| mx_server_url | str | URL of the ModelExpress server. Falls back to the MODEL_EXPRESS_URL env var. |
| load_format | str / LoadFormat | "GMS" to enable GMS-backed loading. Default "AUTO". |
| gms_socket_path | str | Unix socket of the GMS daemon. Default None → UUID-keyed auto-resolve. |
| gms_mode | str | auto (default), rw, ro. |
| gms_tag | str | GMS pool tag. Default "weights". |

MX only:

# config_mx.yaml
checkpoint_format: "MX"
mx_server_url: "http://mx-server:8001"

trtllm-serve <model> --config config_mx.yaml

GMS only:

# config_gms.yaml
load_format: "GMS"
# First worker: gets RW, loads from disk, commits to GMS pool
trtllm-serve <model> --config config_gms.yaml
# Subsequent workers: get RO, zero-copy import
trtllm-serve <model> --config config_gms.yaml

Python API:

from tensorrt_llm import LLM
llm = LLM(model="<model>", checkpoint_format="MX", mx_server_url="http://mx:8001")
# or
from tensorrt_llm.llmapi import LoadFormat
llm = LLM(model="<model>", load_format=LoadFormat.GMS)
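The fourth mode from the table above simply combines both axes (server URL is a placeholder):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import LoadFormat

llm = LLM(
    model="<model>",
    checkpoint_format="MX",                 # weight-source axis: P2P via ModelExpress
    mx_server_url="http://mx-server:8001",
    load_format=LoadFormat.GMS,             # memory-management axis: GMS pool
)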

Design decisions

  • No monkey-patching. The upstream setup_gms() entry point patches ModelLoader.load from outside. We use GMSBackend as the explicit, reviewable boundary instead.
  • MX P2P is NOT used in GMS RW mode. Model params are meta tensors at that point — no CUDA buffers for P2P to write into. Disk loading under the GMS pool is the correct behavior.
  • Per-module _weights_presharded instead of LoadFormat.PRESHARDED. Composes more cleanly with LoadFormat.GMS and lets us mark only MX-delivered modules in mixed-success scenarios.
  • Quantization-metadata-aware TP bypass. _effective_tp_coords(module) ensures _weights_presharded skips TP slicing for weights AND their quantization metadata (weight_scale, input_scale, pre_quant_scale) — preventing shape mismatches on quantized models.

Relationship to existing work

  • PR #12898 (MX team prototype, closed): we adopt the _weights_presharded concept and the pre-post_load_weights() publish timing, but use the two-axis model instead of LoadFormat.PRESHARDED.
  • dynamo PR #7575 (TRT-LLM sleep/wake with GMS): we use the same Layer 2 primitives but invoke them directly, not via monkey-patch.
  • MX team's downstream PR chienchunhung/TensorRT-LLM #1: drove the MODEL_EXPRESS_URL env-var fallback, the MX_SOURCE_QUERY_TIMEOUT defensive default (MX-4), and the model_name plumbing.
  • Design doc: docs/design/mx-gms-integration/ on the docs-and-plans branch. §15 has the validation plan, upstream alignment requests, and test matrix.

Upstream alignment requests

Workarounds in the prototype that could be removed if the corresponding upstream change lands. Full details in §15.

| ID | To | Title | Workaround in this PR | Blocks merge? |
|---|---|---|---|---|
| MX-1 | MX | LoadFormat.PRESHARDED vs per-module flag | mx_preshard_strategy='global' raises until upstream lands | No |
| MX-2 | MX | Promote _build_trtllm_identity to public API | env-var dance in publish_as_source | No |
| MX-3 | MX | Per-rank addressing | Resolved — upstream uses global MPI rank on both sides | — |
| MX-4 | MX | Adopt RdmaStrategy immediate-fallback in MxLiveWeightLoader | MX_SOURCE_QUERY_TIMEOUT=30 (30 s poll vs correct try-once) | No |
| MX-5 | MX | NIXL ownership for MX+GMS composition | MX P2P bypassed in GMS-RW path | Yes for MX+GMS |
| MX-7 | MX/TRT-LLM | Onboard modelexpress into NVIDIA OSS allowlist | extras removed from setup.py | No |
| GMS-1 | GMS | Non-monkey-patch integration path | GMSBackend owns the integration | No |
| GMS-2 | GMS | Promote _move_untracked_params to public | Re-implemented in GMSBackend | No |
| GMS-6 | GMS | Publish gpu-memory-service to PyPI | extras removed from setup.py; install from source | No |

MX-1 and MX-5 are the design items most likely to come up in review. MX-7 and GMS-6 block restoring pip install tensorrt_llm[dynamo] ergonomics.

Test Coverage

148 CPU-only unit tests in tests/unittest/; they run in under 1 s and require no MX/GMS dependencies:

  • tests/unittest/llmapi/test_mx_gms_args.py — config validation, defaults, cross-field warnings, MODEL_EXPRESS_URL env-var fallback, gms_tag empty-string rejection, LoadFormat int-conversion regression
  • tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py — construction + registry, parametrized fallback paths, MX-success integration contract, MX_SOURCE_QUERY_TIMEOUT defensive default, model_name plumbing, _normalize_model_identity (HF-snapshot unmangling), _resolve_mx_model_name priority order, publish_as_source MODEL_NAME end-to-end with env restoration
  • tests/unittest/_torch/memory/test_gms_backend.py — construction, pre-connect state, connect() graceful failure, RW-only method gating, _ptr_in_gms half-open intervals + zero-sentinel guards, _storage_nbytes, protocol conformance. All tests run on CPU CI (CUDA dependency monkeypatched via autouse fixture).

Tests do NOT exercise the upstream libraries — fallback paths use sys.modules-injection to force ImportError, success paths use fake module trees.
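For illustration, the sys.modules-injection pattern that forces the ImportError fallback paths looks like this (fixture and test names are made up):

import sys

import pytest


@pytest.fixture
def without_modelexpress(monkeypatch):
    # Injecting None makes any "import modelexpress" raise ImportError,
    # exercising the HF disk fallback path without the real dependency installed.
    monkeypatch.setitem(sys.modules, "modelexpress", None)


def test_falls_back_to_disk_when_mx_missing(without_modelexpress):
    ...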

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 15, 2026
…th implementation

The protocol listing in Section 5 (challenge 8) had stale method signatures
(commit/upgrade_lock) that diverged from the canonical implementation in
Section 4 and the prototype PR NVIDIA#13045. Updated to match: finalize_write
takes tag parameter, cleanup replaces commit+upgrade_lock.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 18, 2026
…13045

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 18, 2026
Update Section 15 (Prototype Validation Plan) with the actual progress
of integrating PR NVIDIA#13045 with the §10/§11 benchmark infrastructure:

- Replace the speculative Branch Strategy with concrete branch info:
  dynamo/proto-rebased (prototype on current upstream/main, 0 rebase
  conflicts) and dynamo/proto-bench-integration-v2 (= proto-rebased
  + 7 bench commits cherry-picked). Both pushed to fork; the two
  original branches (dynamo-integration-prototype and
  dynamo/startup-profiling) are untouched.

- Document all 5 conflict resolutions (py_executor_creator.py +
  model_loader.py) on the foundational bench commit. Strategy: keep
  prototype semantics, add §10 timers without changing control flow.

- Add an Execution Status section capturing:
    * Phase A done (branch integration + smoke verification)
    * Smoke confirms M1 (AUTO/HF) baseline path is unaffected and
      bench profiler captures the full hierarchy on the integrated
      branch (Qwen 7B TP=1 warm NFS: 67.6s server / 36.3s worker).
    * Phase B blocker: prototype's GMSBackend uses a high-level API
      (gms_client.connect, get_mem_pool, materialize_module_from_gms
      as top-level functions) that does NOT exist in the actually-
      merged ai-dynamo/dynamo PR NVIDIA#7575. Merged GMS exposes a class-
      based API (GMSClientMemoryManager + gms_use_mem_pool context
      manager) plus an official setup_gms() monkey-patch integration
      at gpu_memory_service.integrations.trtllm. The prototype's
      "Align GMS backend with merged GMS library API (PR NVIDIA#7575)"
      commit (2026-04-16) was written against an earlier, un-merged
      iteration of the API. Three resolution paths laid out for
      decision.
    * MX install not feasible on current node (no Rust, no Docker);
      ai-dynamo/modelexpress would need separate node prep.

- Recommend resolving the GMS API mismatch first (cheaper, exercises
  ~half the prototype's new code) before MX node prep.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 20, 2026
… (DONE)

Update Section 15 to reflect the resolution of the GMS API mismatch
blocker that was open as of the previous Section 15 commit. PR NVIDIA#13045
has been:

1. Rebased onto current upstream/main (7b84136 — 95 commits ahead of
   the original prototype base 4a848cc). Zero rebase conflicts.

2. Refactored to align both adapters with what was actually merged
   upstream:
   - GMS (gpu_memory_service v0.9.0): use the class-based API
     (GMSClientMemoryManager + gms_use_mem_pool context manager +
     materialize_module_from_gms + finalize_gms_write) instead of the
     non-existent gms_client.connect() / get_mem_pool() / commit() /
     disconnect() convenience functions. Default tag changes from
     'model_weights' to 'weights' to match the GMS convention.
   - MX (modelexpress v0.3.0): delegate the actual NIXL transfer to
     MxLiveWeightLoader and publishing to publish_model_params instead
     of calling non-existent mx_client.connect() / list_sources() /
     receive() / register_source() helpers. Old SourceIdentity dict
     schema replaced by structured fields.

3. Added an mx_preshard_strategy config knob (per_module | global)
   defaulting to per_module. 'global' raises NotImplementedError until
   LoadFormat.PRESHARDED lands upstream.

4. Declared modelexpress and gpu-memory-service as setuptools optional
   extras with pinned upper bounds.

The new live working branch is dynamo-integration-prototype-rebased
(tip: 62ac40f). It will be force-pushed to dynamo-integration-prototype
once the alignment work is reviewed.

Pre-alignment branches (dynamo/proto-rebased, dynamo/proto-bench-integration-v2)
remain as historical snapshots and are slated for cleanup after Phase B
validation completes.

Doc changes:
- Status header updated from "Phase B blocked on GMS API alignment" to
  "Phase B ready once MX server + GMS daemon are reachable on a test node"
- Branch Strategy table adds the new dynamo-integration-prototype-rebased
  row and clarifies that the pre-alignment branches are historical
- ASCII tree updated to show the new branch on top of current upstream/main
- "Phase B blocked" subsection replaced with "API Alignment (DONE)"
  documenting what was wrong, what the alignment commit does, and the
  three-layer separation principle (TRT-LLM owns Layer 3; we call upstream
  Layer 2 primitives; we never duplicate Layer 1 wire protocol mechanics)
- Recommended Next Step now reflects that the API blocker is resolved
  and the next gates are environment-side (force-push the PR, GMS-only
  end-to-end test, MX setup on a separate node)

No changes to the Test Matrix, Verification Tests, Critical Diagnostic,
Service Setup Reference, or Benchmark Script Changes sections.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung force-pushed the dynamo-integration-prototype branch from 84dfb2a to 62ac40f on April 20, 2026 21:22
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 20, 2026
Update Sections 3, 4, 5, 6 to reflect the actually-merged GMS and MX
Python APIs that PR NVIDIA#13045 (commit 62ac40f) now uses. Section 15
already documents the alignment work itself; these edits bring the
rest of the design doc into sync so cross-references match the code.

Section 3 (Architecture):
- New replica startup diagram: torch.cuda.use_mem_pool(gms_pool) ->
  gms_backend.mem_pool_scope(device); add move_untracked_params() to
  the RW commit chain.
- Component roles diagram: GMS RW node now mentions
  mem_pool_scope(device) and move_untracked_params(); MX publish node
  shows it delegates to publish_model_params().
- Root-cause callout for the MX-in-GMS-RW limitation now points at
  MX-5 in Section 15 as the upstream alignment item.

Section 4 (Implementation & API Design):
- MXCheckpointLoader code block rewritten against current MX API:
  delegates load to MxLiveWeightLoader, delegates publish to
  publish_model_params (with the MODEL_EXPRESS_URL env-var dance for
  per-instance URLs), drops the non-existent mx_client.connect /
  list_sources / receive / register_source convenience calls and the
  proto.SourceIdentity(extra_params=dict) schema. Mixed-success
  conservative behavior documented.
- Identity schema section explains the new structured
  p2p_pb2.SourceIdentity fields and forward-references MX-2 / MX-3.
- GMS code block in ModelLoader.load() updated: gms_pool /
  use_mem_pool(gms_pool) -> mem_pool_scope(device); add
  move_untracked_params() before finalize_write(); add
  empty_cache() drain inside the scope (mirrors upstream RW reference);
  default tag changes to "weights" via GMSBackend.DEFAULT_TAG.
- GPUMemoryBackend protocol listing rewritten: get_mem_pool() ->
  MemPool replaced by mem_pool_scope(device) -> ContextManager;
  finalize_write() now returns int (bytes committed) and takes no
  tag arg; release() removed (cleanup() handles teardown);
  has_committed_weights() takes no arg.
- New "Layer-2 mapping" table shows each protocol method's upstream
  call so reviewers can see where API drift hits.
- "What TRT-LLM Implements vs GMS Library" table refreshed with
  new method names + the optional-extras install hint.
- Configuration section: added mx_preshard_strategy field with
  per_module/global semantics; gms_tag default changes to "weights"
  with comment pointing to GMS_TAGS upstream; gms_socket_path default
  comment points at upstream get_socket_path().
- Validators section updated to mention mx_preshard_strategy.
- Explicitly notes we do NOT use upstream setup_gms() monkey-patch
  and links to GMS-1 in Section 15.

Section 5 (Challenges):
- Challenge 3 (TP rank matching): code example rewritten to use the
  current MxClient.list_sources + get_metadata flow; adds a callout
  for MX-3 (per-rank addressing in identity/metadata schema).
- Challenge 6 (CUDA VMM Integration): use_mem_pool(gms_pool) ->
  mem_pool_scope(device); update method names to match the new
  GPUMemoryBackend protocol; clarify cleanup() responsibilities.
- Challenge 8 (GMS API Stability): full rewrite. Reflects that GMS
  has stabilized as of PR NVIDIA#7575 and our adapter calls the merged
  Layer 2 primitives. New protocol listing matches code. Layer-2
  mapping table added. Explicitly explains why we do NOT use the
  upstream setup_gms() monkey-patch entry point. References GMS-2
  (move_untracked_params upstream promotion).
- Challenge 9 (Transfer backend): client.receive() and
  client.register_source() examples replaced with current
  MxLiveWeightLoader.load_weights() and publish_model_params().

Section 6 (Executor & Failover):
- Section intro: "gms_client.upgrade_lock()" replaced with the
  actual primitive sequence (unmap_all_vas + abort + reconnect-RW
  + remap_all_vas).
- _shadow_loop pseudocode: drop the gms_client.heartbeat() call
  (the unix socket is the lock; OS keepalives suffice). Note that
  GMS-3 (peek RPC) would give a cheaper shadow-health signal here.
- _activate_from_shadow pseudocode: gms_client.upgrade_lock() ->
  gms_backend.upgrade_to_rw() with a comment that GMSBackend should
  grow this method as a one-call wrapper around the primitive
  sequence.
- _init_shadow_with_gms pseudocode: gms_client= -> gms_backend=,
  documents that this delegates to upstream
  materialize_module_from_gms(mgr, model, device_index=N).
- Sleep/wake mapping table: tag="model_weights" -> tag="weights",
  and each "GMS Library Call" column entry updated to show the
  actual current API call. Adds a note explaining the "weights" /
  "kv_cache" naming convention from GMS_TAGS.

Sections 1, 2, 7-11 had no stale GMS/MX API references and are
unchanged.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 21, 2026
…ounds

Add a dedicated "Upstream Alignment Requests" section to §15 with the
full MX-1..MX-6 / GMS-1..GMS-5 request inventory. The summary
table has the priority + workaround status + merge-blocking flag for
each item. Cross-references in §3, §4, §5, §6 (~8 of them) have been
updated to point to the new section's anchor.

Updates two of the existing entries to reflect new workarounds added
to PR NVIDIA#13045 in this same series of follow-up commits:

  MX-2 (promote _build_trtllm_identity to public API): expand the
    workaround note to mention that `publish_as_source()` now sets
    BOTH MODEL_EXPRESS_URL and MODEL_NAME env vars temporarily (we
    used to only set MODEL_EXPRESS_URL). The two env-var dances
    would collapse into one direct call if MX-2 lands. The
    MODEL_NAME resolution itself is now plumbed cleanly through
    `llm_args.model → MXCheckpointLoader(model_name=...)` with HF
    snapshot path unmangling for the basename-fallback case.

  MX-4 (non-blocking source-query API): add the
    `MX_SOURCE_QUERY_TIMEOUT=30` defensive setdefault as the
    current workaround. `MXCheckpointLoader.__init__` calls
    `os.environ.setdefault(...)` whenever an MX URL is configured
    so first-replica cold-cluster startup degrades from 1 hour
    (upstream default) to 30s. setdefault semantics preserve any
    explicit user value.

The new section also captures items that were previously only in chat
context (MX-1, MX-3, MX-5, MX-6, GMS-1..5) so the cross-references
from §3-§6 actually resolve to a stable in-doc location.

No changes to the test matrix, verification tests, conflict-resolution
table, or the rest of §15.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 21, 2026
… timeout, model_name plumbing

Three discrete improvements to the MX side of PR NVIDIA#13045 driven by
review feedback from MX team's downstream PR
(chienchunhung/TensorRT-LLM #1) — three orchestration ergonomics fixes
landed as one focused commit so reviewers see them as a clean slice
on top of the prototype.

(1) MODEL_EXPRESS_URL env-var fallback — at validator level

  TorchLlmArgs.validate_mx_config now honors the upstream
  ``MODEL_EXPRESS_URL`` env var when ``checkpoint_format='MX'`` and
  ``mx_server_url`` is unset. Resolution happens at validator time so
  the value ends up on ``llm_args.mx_server_url`` (visible to
  logging, /startup_metrics, downstream code) instead of being
  silently re-read from env by the loader.

  Lets orchestrators (Dynamo) configure MX via the environment
  without plumbing every CLI knob, while keeping resolution in one
  place. Explicit ``mx_server_url=`` always wins. The env-var
  fallback only fires when MX is the active checkpoint format
  (so HF-only configs aren't surprised by an unrelated env var).
  Empty string in env is treated as unset.

(2) MX_SOURCE_QUERY_TIMEOUT defensive default

  MXCheckpointLoader.__init__ calls
  ``os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")`` whenever
  an MX server URL is configured. Caps cold-cluster first-replica
  startup at 30 s instead of upstream's 1-hour default (the polling
  in MxLiveWeightLoader._query_source). setdefault semantics preserve
  any explicit user value. HF-only loads (no MX URL) don't touch
  the env at all.

  The proper upstream-side fix is a non-blocking source-query API
  (tracked as MX-4 in §15 of the design doc); this defensive default
  caps the worst case until that lands.

(3) model_name plumbing with HF-snapshot-aware resolver

  Plumbs ``llm_args.model → MXCheckpointLoader(model_name=...)`` so
  upstream's ``publish_model_params()`` publishes under the
  user-supplied Hub ID (e.g. "Qwen/Qwen2.5-72B-Instruct") instead of
  the "unknown" sentinel.

  - MXCheckpointLoader takes a new optional ``model_name``
    constructor arg (Union[str, Path]). Coerced to str at
    construction time.
  - publish_as_source() now sets BOTH MODEL_EXPRESS_URL and
    MODEL_NAME env vars (resolving identity via the priority order
    below) and restores both env vars in finally.
    publish_model_params() reads them via env, as documented.
  - Identity resolution order: explicit constructor arg →
    MODEL_NAME env → checkpoint_dir basename (with HF-snapshot path
    unmangling) → "unknown".
  - HF cache layout (".../models--<org>--<name>/snapshots/<sha>/")
    is unmangled back to "<org>/<name>" instead of returning the
    commit hash.
  - _construct_checkpoint_loader plumbs ``mx_model_name`` through;
    py_executor_creator.py extracts it from llm_args.model.

  Both env-var dances (MODEL_EXPRESS_URL + MODEL_NAME) collapse into
  one direct call when MX-2 (public build_identity) lands upstream.

Tests for these three additions are in the next commit.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
chienchunhung force-pushed the dynamo-integration-prototype branch 2 times, most recently from b78dafc to edb4f0f on April 21, 2026 22:00
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 21, 2026
… timeout, model_name plumbing

chienchunhung (Collaborator, Author) commented:

/bot run --disable-fail-fast

chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 21, 2026
… timeout, model_name plumbing

chienchunhung force-pushed the dynamo-integration-prototype branch from edb4f0f to 0a2a950 on April 21, 2026 22:57
chienchunhung (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #44810 [ run ] triggered by Bot. Commit: 0a2a950 Link to invocation

chienchunhung marked this pull request as ready for review April 22, 2026 20:59
chienchunhung requested review from a team as code owners April 22, 2026 20:59
        with gms_use_mem_pool(self._tag, target_device):
            yield

    def move_untracked_params(self, model: nn.Module) -> None:
Collaborator:

are we sure this won't cause an OOM? typically when you migrate something into GMS you will have to request a new allocation for the same piece of memory, then copy it in. simply remapping will not store it in GMS.

Collaborator (Author):

Good catch. To clarify: GMSBackend.move_untracked_params is not a remap — it walks the model's parameters, calls gms_client.create_mapping(size=nbytes, tag=self._tag) to allocate fresh GMS-pool memory, then replacement.copy_(tensor) and reassigns tensor.data = replacement. Mirrors upstream gpu_memory_service.integrations.trtllm.model_loader._move_untracked_params 1:1. So the bytes do end up in GMS.

The transient peak you're worried about is real though: for the duration of each per-tensor replacement.copy_(tensor), both the old (out-of-pool) buffer and the new (in-pool) buffer exist before the old one is GC'd. In practice this only fires for stray params that landed outside mem_pool_scope (post-load transforms, etc.), which is a small fraction of total weights — but for very large models on tight budgets it could spike.

Also, in this round I moved move_untracked_params inside mem_pool_scope (commit c6c6e57) to match the upstream _load_rw reference path exactly. That doesn't change the per-tensor peak, but it does ensure any intermediate buffers tensor.copy_() may allocate via PyTorch's caching allocator end up in the GMS pool and get drained by the subsequent torch.cuda.empty_cache().

Worth tracking as an upstream ask: a streaming/in-place migrator that rebinds an existing physical alloc to a GMS-mapped VA without a copy. What do you think?
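For readers following along, the per-tensor migration described above is roughly the following sketch; create_mapping is assumed to return a tensor-like view matching the parameter's shape and dtype, and ptr_in_gms stands in for the backend's pool-membership check (_ptr_in_gms in the tests):

import torch.nn as nn


def move_untracked_params_sketch(gms_client, model: nn.Module, tag: str, ptr_in_gms) -> None:
    for param in model.parameters():
        if ptr_in_gms(param.data_ptr()):
            continue  # already backed by the GMS pool
        nbytes = param.untyped_storage().nbytes()
        replacement = gms_client.create_mapping(size=nbytes, tag=tag)  # fresh in-pool buffer
        replacement.copy_(param.data)  # transient peak: old + new buffers coexist here
        param.data = replacement       # old out-of-pool buffer is freed once unreferenced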


from gpu_memory_service.integrations.common.utils import finalize_gms_write

bytes_committed = int(finalize_gms_write(self._client, model))
Collaborator:

finalize_gms_write is not quite the right abstraction if we are going to want to support speculative decoding models in this path as well, if I remember correctly. we would need to defer commit until draft model weights are loaded into GMS.

Collaborator (Author):

Right — the original RW branch only loaded the main model and committed immediately, which would have left model.draft_model either uncommitted or out-of-pool. Just pushed c6c6e57 which loads draft weights inside the same mem_pool_scope (mirroring the AUTO branch's if self.spec_config is not None and ...need_load_draft_weights(): flow) so they land in the GMS pool. finalize_gms_write(model) then picks them up via its existing model-tree walk (finalize_gms_write() walks named_parameters, so anything reachable from model — including model.draft_model — gets registered in the same commit).

coderabbitai (Contributor, Bot) commented Apr 22, 2026

📝 Walkthrough

Adds GMS GPU-memory backend and MX (ModelExpress) P2P checkpoint loader, integrates MX/GMS options into model-loading paths, adds Linear presharding flag, and extends CLI/config args and tests to support MX/GMS workflows.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Setup & Args: setup.py, tensorrt_llm/llmapi/llm_args.py | Updated SPDX header and added comment about extras; introduced LoadFormat.GMS and MX/GMS CLI/config fields (mx_server_url, mx_preshard_strategy, gms_socket_path, gms_mode, gms_tag) with validation. |
| GPU Memory Backend: tensorrt_llm/_torch/memory/__init__.py, tensorrt_llm/_torch/memory/gpu_memory_backend.py | New GPUMemoryBackend protocol and GMSBackend implementation: lazy daemon connect, RW mem-pool scope, move/commit helpers, RO materialization, cleanup, and pointer/storage helpers. |
| MX Checkpoint Loader: tensorrt_llm/_torch/models/checkpoints/mx/... | Added MXCheckpointLoader (registered "MX"): performs MX P2P transfer via modelexpress, disk fallback on failures, publish_as_source() hook, model-name resolution, and related helpers. |
| HF Checkpoint Registrations: tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py, .../weight_loader.py, .../weight_mapper.py, .../__init__.py | Registered existing HF loaders/mappers/config loader to also handle "MX" and exported MXCheckpointLoader from the checkpoints package. |
| Model Loading Integration: tensorrt_llm/_torch/pyexecutor/model_loader.py, .../model_engine.py, .../py_executor_creator.py, tensorrt_llm/executor/base_worker.py | Forwarded MX/GMS args into checkpoint construction; ModelLoader supports GMS RW/RO flows (connect, mem_pool_scope, finalize_write, materialize), skips certain prealloc/materialize steps for GMS, marks modules presharded after MX success, and adds cleanup(). |
| Linear Module: tensorrt_llm/_torch/modules/linear.py | Introduced Linear._weights_presharded and updated weight-loading helpers to use effective TP coords so already-sharded weights are not re-sharded. |
| Tests & API Stability: tests/unittest/_torch/memory/test_gms_backend.py, tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py, tests/unittest/llmapi/test_mx_gms_args.py, tests/unittest/api_stability/references/llm.yaml | Added unit tests for GMSBackend, MXCheckpointLoader, and TorchLlmArgs MX/GMS validators; updated API stability reference with new prototype parameters. |

Sequence Diagrams

sequenceDiagram
    participant Client
    participant MXLoader as MXCheckpointLoader
    participant Modelexpress
    participant HFDisk as HF Checkpoint Disk

    Client->>MXLoader: load_weights(checkpoint_dir, mapping, model=...)
    MXLoader->>MXLoader: validate MX url & model arg
    alt MX available
        MXLoader->>Modelexpress: request P2P transfer
        Modelexpress->>MXLoader: response (fallback_weights?)
        alt no fallback
            MXLoader->>Client: return {} (p2p_succeeded=True)
        else partial/fallback
            MXLoader->>HFDisk: load from disk
            HFDisk-->>MXLoader: weights dict
            MXLoader->>Client: return disk-loaded weights (p2p_succeeded=False)
        end
    else MX unavailable/exception
        MXLoader->>HFDisk: load from disk
        HFDisk-->>MXLoader: weights dict
        MXLoader->>Client: return disk-loaded weights (p2p_succeeded=False)
    end
sequenceDiagram
    participant Client
    participant Loader as ModelLoader
    participant GMS as GMSBackend
    participant Daemon as GMS Daemon
    participant Model

    Note over Client,Model: GMS RW path (loading)
    Client->>Loader: load(format=GMS, mode=RW)
    Loader->>GMS: connect()
    GMS->>Daemon: request RW lock & mappings
    Daemon-->>GMS: connected + mappings
    Loader->>GMS: enter mem_pool_scope()
    Loader->>Model: load weights (allocations routed to GMS pool)
    Loader->>GMS: move_untracked_params(model)
    GMS-->>Model: migrate params into GMS mappings
    Loader->>GMS: finalize_write(model)
    GMS->>Daemon: commit weights -> transition to RO

    Note over Client,Model: GMS RO path (inference)
    Client->>Loader: load(format=GMS, mode=RO)
    Loader->>GMS: connect()
    GMS->>Daemon: request RO lock
    Daemon-->>GMS: connected + mappings
    Loader->>GMS: materialize_module(model)
    GMS-->>Model: zero-copy materialize weights, mark presharded

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 22.89%, which is insufficient; the required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly summarizes the main change: prototype implementation of MX (ModelExpress) and GMS (GPU Memory Service) integration for weight sharing in TRT-LLM's Dynamo backend, which matches the primary focus of the extensive changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured, covering all major sections of the template including a detailed summary, test coverage details, and a completed checklist. |


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai (Contributor, Bot) left a comment

Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
setup.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the copyright year for this modified file.

This file was meaningfully modified in 2026, but the header still ends at 2025.

Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines: "Add NVIDIA copyright header on ALL new files and update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@setup.py` at line 1, Update the SPDX header year range in the setup.py file
from "2022-2025" to include 2026 (e.g., "2022-2026") so the copyright line
reflects the file was modified in 2026; locate and edit the existing header
string at the top of setup.py accordingly.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/memory/gpu_memory_backend.py (1)

395-415: Log the exception in client.close() instead of silently passing.

Static analysis flags the try-except-pass pattern (S110). Even for best-effort cleanup, logging at debug level helps diagnose shutdown issues without cluttering normal output.

🔧 Proposed fix to log the close() exception
             try:
                 client.close()
-            except Exception:
+            except Exception as e:
                 # Best-effort: even if close() fails, evict so a future
                 # connect() in the same process can re-establish state.
-                pass
+                logger.debug("GMS client.close() failed (best-effort): %s", e)
             evict_gms_client_memory_manager(client)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 395 - 415, In
cleanup(), when calling client.close() inside the inner try/except, do not
silently pass on exceptions; catch the Exception and log it (at debug level) so
shutdown issues are visible. Locate the cleanup method and the inner try around
client.close() (variable client and function evict_gms_client_memory_manager)
and replace the empty except with a logger.debug call that includes the
exception details (e.g., logger.debug("GMS client.close() error: %s", e,
exc_info=True)) before continuing to evict and clear self._client.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py`:
- Around line 7-8: Add the required SPDX/NVIDIA copyright/header block at the
top of tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py so the file
includes the repository-standard NVIDIA copyright header with the correct year
of latest meaningful modification and SPDX identifier; place it above the
existing decorators (e.g., above `@register_config_loader`("MX") /
`@register_config_loader`("HF")) and ensure the header matches other source files
in format and content.

In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py`:
- Around line 142-231: The logger.warning in load_weights currently passes
len(fallback_weights) twice to a message that implies two different values;
update the log text in the fallback handling inside load_weights (in
tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader) to avoid the
misleading "delivered %d out of N" phrasing—either compute and pass an actual
delivered count if available or reword the message to only report the fallback
count (len(fallback_weights)) and the MX server URL or total tensors if you can
obtain them; ensure the call no longer supplies duplicate format args and
reference MxLiveWeightLoader/fallback_weights/_fallback_to_disk while making
this change.
- Around line 247-318: The publish_as_source method uses implicit Optional
defaults; update its signature to use explicit union types: change mapping:
Mapping = None to mapping: Mapping | None and checkpoint_dir: str = None to
checkpoint_dir: str | None (keep model's type as-is or add/qualify with
torch.nn.Module if you want an explicit model type); ensure Python 3.10+ union
syntax is supported in this project or use typing.Optional if not, and add any
missing imports (e.g., from torch import nn if you change model to nn.Module) or
use a forward/string annotation to avoid import cycles.

In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 187-195: Compute the effective TP coordinates once (e.g.,
effective_tp_size = 1 if getattr(module, '_weights_presharded', False) else
module.tp_size and effective_tp_rank = 0 if getattr(module,
'_weights_presharded', False) else module.tp_rank) and reuse those variables for
every load_weight_shard(...) call in this module (including the calls that load
quantization metadata such as weight_scale, input_scale, pre_quant_scale and any
other quantized tensors), replacing direct uses of module.tp_size/module.tp_rank
so that both weights and their quantization metadata are sliced consistently
based on module._weights_presharded and module.tp_mode.
- Line 2533: The file is missing the repo-standard SPDX/NVIDIA copyright header;
add the required NVIDIA copyright/SPDX header block at the very top of the
source file (using the year of the latest meaningful modification) so all source
files include the standard notice; update the header in the module that defines
the Linear class (the file containing the assignment self._weights_presharded =
False) to match the repository's canonical NVIDIA header format.

In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 502-575: GMS branches skip loading draft-model weights so
model.draft_model remains uninitialized; after the GMS RW/RO handling (but
before storing self._gms_backend) replicate the existing speculative/draft load
flow: check self.spec_config.spec_dec_mode.need_load_draft_weights() and if true
load weights from self.spec_config.speculative_model using the same
checkpoint_loader pattern (use load_weights with mapping=self.mapping, call
checkpoint_loader.get_initialized_weight_mapper to set up a mapper, then invoke
self._call_load_weights to apply weights), assign the resulting module to
model.draft_model (or set via the same initializer used in other branches) and
ensure any GMS-specific memory scope is respected (for RW use
gms_backend.mem_pool_scope as needed) so draft weights are loaded outside/inside
GMS consistently.
- Around line 526-529: The GMS branch is using the hard-coded checkpoint_dir
variable which ignores a model's redirected checkpoint source; inside the
gms_backend.mem_pool_scope block (around load_weights_kwargs and
checkpoint_loader.load_weights) change the call to pass the model's redirected
path when present by resolving checkpoint_source = getattr(self, "model", None)
and using getattr(self.model, "llm_checkpoint_dir", checkpoint_dir) (or
equivalent) so that checkpoint_loader.load_weights(...) uses
model.llm_checkpoint_dir when available, falling back to checkpoint_dir
otherwise.

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3437-3438: This file (tensorrt_llm/llmapi/llm_args.py) was
modified but is missing the required NVIDIA copyright header; add the
repository-mandated NVIDIA file header (including the project name and the year
of latest meaningful modification) to the top of the file so all Python sources
comply, ensuring it appears before any code or comments (i.e., above the section
that defines GMS = 3 and the "Load weights..." comment) and update the year if
this is a modification rather than a new file.
- Around line 3716-3725: Add an explicit Pydantic field validator for gms_tag to
reject empty or whitespace-only strings so configs like gms_tag="" fail fast
instead of collapsing to the default; implement a `@field_validator`("gms_tag")
`@classmethod` named e.g. validate_gms_tag that checks v.strip() and raises
ValueError("gms_tag must be non-empty") when empty, otherwise returns v; place
it in the same class that defines the gms_tag Field so the validation runs
during model parsing.

In `@tests/unittest/_torch/memory/test_gms_backend.py`:
- Around line 39-41: Remove the class-level pytest.mark.skipif and instead make
tests run on CPU by patching torch.cuda.current_device or by instantiating
GMSBackend without calling __init__; specifically, in tests that exercise
GMSBackend (e.g., where GMSBackend.__init__, connect(), and RW/RO gating are
involved) replace the skip with a setup that monkeypatches
torch.cuda.current_device to return a valid device (e.g.,
monkeypatch.setattr(torch.cuda, "current_device", lambda: 0)) so __init__ can
run under mocks, or create the object via GMSBackend.__new__(GMSBackend) and
manually set any internal attributes needed for the mocked connect()/protocol
code paths to execute; apply this change to the blocks referenced (around lines
with the prior skip and the other ranges called out) so the mocked happy/failure
connect() and gating paths execute on CPU CI.

In `@tests/unittest/llmapi/test_mx_gms_args.py`:
- Around line 279-289: Add a regression test to ensure integer load_format
values map to the new enum member by calling TorchLlmArgs.convert_load_format
with the raw integer 3 and asserting it yields LoadFormat.GMS; update the
TestLoadFormatEnum suite (or add a new test method) to include an assertion that
TorchLlmArgs.convert_load_format(3) == LoadFormat.GMS while keeping the existing
checks for AUTO/DUMMY/VISION_ONLY unchanged.

---

Outside diff comments:
In `@setup.py`:
- Line 1: Update the SPDX header year range in the setup.py file from
"2022-2025" to include 2026 (e.g., "2022-2026") so the copyright line reflects
the file was modified in 2026; locate and edit the existing header string at the
top of setup.py accordingly.

---

Nitpick comments:
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py`:
- Around line 395-415: In cleanup(), when calling client.close() inside the
inner try/except, do not silently pass on exceptions; catch the Exception and
log it (at debug level) so shutdown issues are visible. Locate the cleanup
method and the inner try around client.close() (variable client and function
evict_gms_client_memory_manager) and replace the empty except with a
logger.debug call that includes the exception details (e.g., logger.debug("GMS
client.close() error: %s", e, exc_info=True)) before continuing to evict and
clear self._client.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 88b1a034-9c30-4c0d-a6cd-250f1d7100ae

📥 Commits

Reviewing files that changed from the base of the PR and between 6e5a339 and 0a2a950.

📒 Files selected for processing (19)
  • setup.py
  • tensorrt_llm/_torch/memory/__init__.py
  • tensorrt_llm/_torch/memory/gpu_memory_backend.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
  • tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/memory/test_gms_backend.py
  • tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
  • tests/unittest/api_stability/references/llm.yaml
  • tests/unittest/llmapi/test_mx_gms_args.py

Comment on lines +7 to 8
@register_config_loader("MX")
@register_config_loader("HF")
Contributor

⚠️ Potential issue | 🟡 Minor

Add the required SPDX/NVIDIA header to this modified source file.

The registry change is fine, but this TensorRT-LLM source file still lacks the repo-standard copyright header block.

As per coding guidelines "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py` around lines 7 -
8, Add the required SPDX/NVIDIA copyright/header block at the top of
tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py so the file includes
the repository-standard NVIDIA copyright header with the correct year of latest
meaningful modification and SPDX identifier; place it above the existing
decorators (e.g., above `@register_config_loader`("MX") /
`@register_config_loader`("HF")) and ensure the header matches other source files
in format and content.
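
For reference, a header of the kind the guideline asks for typically looks like the SPDX block below; the exact wording and year range should be copied from an existing TensorRT-LLM source file rather than from this sketch:

    # SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    # SPDX-License-Identifier: Apache-2.0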

Comment on lines +142 to +231
    def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
        """Load weights, preferring MX P2P transfer when available.

        Delegates the actual transfer to the upstream
        ``modelexpress.trtllm_live_transfer.MxLiveWeightLoader``,
        which handles NIXL setup, source discovery, name matching,
        dtype casting, and PVC fallback for size-mismatched tensors.

        Args:
            checkpoint_dir: Path to the HF checkpoint directory.
            mapping: Distributed mapping configuration.
            **kwargs: Additional keyword arguments. When ``model`` is
                passed it is used as the target for direct P2P writes.

        Returns:
            A weights dict. Empty when MX P2P fully succeeded (weights
            already in model params); populated when falling back to
            disk loading for some or all weights.
        """
        model = kwargs.pop("model", None)
        self._p2p_succeeded = False

        if self._mx_server_url is None or model is None:
            return self._fallback_to_disk(
                checkpoint_dir,
                mapping,
                reason=(
                    "no MX server URL configured"
                    if self._mx_server_url is None
                    else "no model reference passed (cannot do P2P writes)"
                ),
                **kwargs,
            )

        try:
            from modelexpress.trtllm_live_transfer import (  # type: ignore[import-not-found]
                MxLiveWeightLoader,
            )
        except ImportError:
            logger.warning(
                "modelexpress library not installed; cannot use MX P2P "
                "weight transfer. Install from "
                "https://github.com/ai-dynamo/modelexpress (Python client at "
                "modelexpress_client/python). Falling back to disk loading."
            )
            return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

        try:
            mx_loader = MxLiveWeightLoader(mx_server=self._mx_server_url)
            fallback_weights = mx_loader.load_weights(
                checkpoint_dir,
                mapping=mapping,
                model=model,
            )
        except Exception as e:
            logger.warning(
                "MX P2P transfer failed (%s). Falling back to disk loading.",
                e,
            )
            return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

        if fallback_weights:
            # Mixed-success case: MX delivered most weights into model
            # params via P2P, but returned a dict of size-mismatched
            # tensors that still need to be loaded via the standard
            # pipeline. Marking Linear modules as fully presharded would
            # then incorrectly skip TP slicing for those fallback
            # weights too. The conservative correct behavior is to fall
            # through to the standard disk path entirely, which loads
            # *all* weights normally and ignores the MX-written ones.
            #
            # (When LoadFormat.PRESHARDED lands upstream, the proper
            # fix is per-tensor presharded marking based on whether
            # MX delivered that tensor.)
            logger.warning(
                "MX P2P delivered %d out of N tensors but returned %d "
                "fallback weights (size mismatch). Falling back to full "
                "disk load to avoid mixing presharded and non-presharded "
                "weights.",
                len(fallback_weights),  # this is a count of fallback only
                len(fallback_weights),
            )
            return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs)

        self._p2p_succeeded = True
        logger.info(
            "MX P2P weight transfer succeeded from %s",
            self._mx_server_url,
        )
        return {}
Contributor

⚠️ Potential issue | 🟡 Minor

Misleading log message format.

The log message on lines 216-222 uses len(fallback_weights) for both format arguments, but the message text implies two different values: "delivered %d out of N" (successful transfers) and "returned %d fallback weights" (failures). Since only the fallback count is available, the message should be clarified.

📝 Proposed fix for clearer logging
             logger.warning(
-                "MX P2P delivered %d out of N tensors but returned %d "
-                "fallback weights (size mismatch). Falling back to full "
+                "MX P2P returned %d fallback weights (size mismatch). "
+                "Falling back to full "
                 "disk load to avoid mixing presharded and non-presharded "
                 "weights.",
-                len(fallback_weights),  # this is a count of fallback only
                 len(fallback_weights),
             )
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 196-196: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
142 - 231, The logger.warning in load_weights currently passes
len(fallback_weights) twice to a message that implies two different values;
update the log text in the fallback handling inside load_weights (in
tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader) to avoid the
misleading "delivered %d out of N" phrasing—either compute and pass an actual
delivered count if available or reword the message to only report the fallback
count (len(fallback_weights)) and the MX server URL or total tensors if you can
obtain them; ensure the call no longer supplies duplicate format args and
reference MxLiveWeightLoader/fallback_weights/_fallback_to_disk while making
this change.

Comment on lines +247 to +318
    def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str = None) -> None:
        """Publish this instance's weights so other ranks can pull via P2P.

        Called by the integration in ``model_loader.py`` *before*
        ``post_load_weights()`` so targets receive raw loaded state and
        can apply their own post-load transforms.

        Delegates to the upstream
        ``modelexpress.trtllm_live_transfer.publish_model_params``
        helper, which handles the per-rank NIXL setup, tensor
        registration, and gRPC publish.

        Args:
            model: The model whose weights to publish.
            mapping: Distributed mapping. Currently unused — kept for
                signature symmetry with the prior prototype API and for
                forward-compat with future upstream signatures.
            checkpoint_dir: Checkpoint directory. Used as a last-resort
                fallback for resolving the ``MODEL_NAME`` identity when
                neither ``model_name`` was passed to the constructor nor
                ``MODEL_NAME`` is set in the environment.
        """
        del mapping  # currently unused; see docstring.

        if self._mx_server_url is None:
            return

        try:
            from modelexpress.trtllm_live_transfer import (  # type: ignore[import-not-found]
                publish_model_params,
            )
        except ImportError:
            logger.debug("modelexpress library not installed; skipping MX publish.")
            return

        # Upstream publish_model_params reads MODEL_EXPRESS_URL and
        # MODEL_NAME from the environment. Set both from our resolved
        # configuration so per-instance values (URL passed via
        # llm_args.mx_server_url, identity from llm_args.model) are
        # respected, then restore prior state. Tracked as MX-2 in §15
        # (the env-var dance goes away when upstream exports a public
        # ``build_identity()`` we can call directly).
        resolved_name = _resolve_mx_model_name(self._model_name, checkpoint_dir)

        env_overrides = {
            "MODEL_EXPRESS_URL": self._mx_server_url,
            "MODEL_NAME": resolved_name,
        }
        prior = {key: os.environ.get(key) for key in env_overrides}
        for key, value in env_overrides.items():
            os.environ[key] = value

        try:
            publish_model_params(model)
            logger.info(
                "Published weights to MX server at %s as model=%r",
                self._mx_server_url,
                resolved_name,
            )
        except Exception as e:
            logger.warning(
                "Failed to publish weights to MX server at %s: %s",
                self._mx_server_url,
                e,
            )
        finally:
            for key, prior_value in prior.items():
                if prior_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = prior_value

Contributor

⚠️ Potential issue | 🟡 Minor

Use explicit | None for optional parameter types.

Per PEP 484 and coding guidelines, implicit Optional (via = None default) is prohibited. The parameters mapping and checkpoint_dir need explicit union type annotations.

🔧 Proposed fix for type annotations
-    def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str = None) -> None:
+    def publish_as_source(
+        self, model: "nn.Module", mapping: Mapping | None = None, checkpoint_dir: str | None = None
+    ) -> None:

Note: You'll need to add from torch import nn to imports if not already present, or use a string annotation for model.

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 247-247: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


[warning] 306-306: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
247 - 318, The publish_as_source method uses implicit Optional defaults; update
its signature to use explicit union types: change mapping: Mapping = None to
mapping: Mapping | None and checkpoint_dir: str = None to checkpoint_dir: str |
None (keep model's type as-is or add/qualify with torch.nn.Module if you want an
explicit model type); ensure Python 3.10+ union syntax is supported in this
project or use typing.Optional if not, and add any missing imports (e.g., from
torch import nn if you change model to nn.Module) or use a forward/string
annotation to avoid import cycles.

Comment on lines +187 to 195
        # Skip TP slicing when weights are already sharded (e.g. MX P2P transfer).
        tp_size = 1 if getattr(module, '_weights_presharded',
                               False) else module.tp_size
        tp_rank = 0 if getattr(module, '_weights_presharded',
                               False) else module.tp_rank

        weight = load_weight_shard(weights[0]['weight'], tp_size, tp_rank,
                                   module.tp_mode,
                                   device) if "weight" in weights[0] else None
Contributor

⚠️ Potential issue | 🟠 Major

Don't bypass TP slicing for weights without also bypassing it for quantization metadata.

_weights_presharded now stops extra slicing for the main weight/bias tensors, but the quantized loaders later in this file still shard weight_scale, input_scale, pre_quant_scale, and similar tensors with module.tp_size/module.tp_rank. For presharded MX/GMS loads, that leaves full local weights paired with re-sliced metadata, which will either shape-mismatch or silently produce wrong results on quantized models. Please centralize the “effective TP coords” logic and reuse it for every load_weight_shard(...) call in this file, not just these helpers.

Also applies to: 234-248, 293-304

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/linear.py` around lines 187 - 195, Compute the
effective TP coordinates once (e.g., effective_tp_size = 1 if getattr(module,
'_weights_presharded', False) else module.tp_size and effective_tp_rank = 0 if
getattr(module, '_weights_presharded', False) else module.tp_rank) and reuse
those variables for every load_weight_shard(...) call in this module (including
the calls that load quantization metadata such as weight_scale, input_scale,
pre_quant_scale and any other quantized tensors), replacing direct uses of
module.tp_size/module.tp_rank so that both weights and their quantization
metadata are sliced consistently based on module._weights_presharded and
module.tp_mode.
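
A sketch of the centralization the comment asks for. The helper name echoes the _effective_tp_coords() mentioned in a later review pass, but the exact shape here is illustrative rather than the actual linear.py code:

    def _effective_tp_coords(module) -> tuple[int, int]:
        """Return the (tp_size, tp_rank) to use when slicing shards.

        Presharded weights (e.g. delivered by MX P2P) must not be sliced
        again, so they are loaded as if tp_size == 1.
        """
        if getattr(module, "_weights_presharded", False):
            return 1, 0
        return module.tp_size, module.tp_rank

    # Every load_weight_shard(...) call in the file, including the ones for
    # weight_scale, input_scale, and pre_quant_scale, would then share:
    tp_size, tp_rank = _effective_tp_coords(module)
    weight = load_weight_shard(weights[0]["weight"], tp_size, tp_rank,
                               module.tp_mode, device)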

self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
self.disable_deep_gemm = disable_deep_gemm
self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping
self._weights_presharded = False
Contributor

⚠️ Potential issue | 🟡 Minor

Add the required SPDX/NVIDIA header to this modified source file.

This file now has meaningful changes but still does not carry the repo-standard copyright header block.

As per coding guidelines "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/linear.py` at line 2533, The file is missing the
repo-standard SPDX/NVIDIA copyright header; add the required NVIDIA
copyright/SPDX header block at the very top of the source file (using the year
of the latest meaningful modification) so all source files include the standard
notice; update the header in the module that defines the Linear class (the file
containing the assignment self._weights_presharded = False) to match the
repository's canonical NVIDIA header format.

Comment on lines +526 to +529
                    with gms_backend.mem_pool_scope(device):
                        load_weights_kwargs = {"mapping": self.mapping}
                        weights = checkpoint_loader.load_weights(
                            checkpoint_dir, **load_weights_kwargs)
Contributor

⚠️ Potential issue | 🟠 Major

GMS RW drops the model.llm_checkpoint_dir redirect.

The AUTO path already loads from model.llm_checkpoint_dir when the model rewrites its checkpoint source, but the GMS RW branch hard-codes checkpoint_dir. That creates a GMS-only failure mode for the same model class.

💡 Suggested fix
-                        weights = checkpoint_loader.load_weights(
-                            checkpoint_dir, **load_weights_kwargs)
+                        weight_source = (
+                            model.llm_checkpoint_dir
+                            if hasattr(model, "llm_checkpoint_dir")
+                            else checkpoint_dir
+                        )
+                        weights = checkpoint_loader.load_weights(
+                            weight_source, **load_weights_kwargs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-                    with gms_backend.mem_pool_scope(device):
-                        load_weights_kwargs = {"mapping": self.mapping}
-                        weights = checkpoint_loader.load_weights(
-                            checkpoint_dir, **load_weights_kwargs)
+                    with gms_backend.mem_pool_scope(device):
+                        load_weights_kwargs = {"mapping": self.mapping}
+                        weight_source = (
+                            model.llm_checkpoint_dir
+                            if hasattr(model, "llm_checkpoint_dir")
+                            else checkpoint_dir
+                        )
+                        weights = checkpoint_loader.load_weights(
+                            weight_source, **load_weights_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 526 - 529, The
GMS branch is using the hard-coded checkpoint_dir variable which ignores a
model's redirected checkpoint source; inside the gms_backend.mem_pool_scope
block (around load_weights_kwargs and checkpoint_loader.load_weights) change the
call to pass the model's redirected path when present by resolving
checkpoint_source = getattr(self, "model", None) and using getattr(self.model,
"llm_checkpoint_dir", checkpoint_dir) (or equivalent) so that
checkpoint_loader.load_weights(...) uses model.llm_checkpoint_dir when
available, falling back to checkpoint_dir otherwise.

Comment on lines +3437 to +3438
# Load weights from GPU Memory Service (read-only GPU memory pool).
GMS = 3
Contributor

⚠️ Potential issue | 🟡 Minor

Add the required NVIDIA file header.

This file is modified in this PR but still lacks the repository-mandated NVIDIA copyright header / latest-modification year.

As per coding guidelines, **/*.{cpp,cc,h,hpp,c,py} must “Add NVIDIA copyright header on ALL new files and update year on modified files” and “All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 3437 - 3438, This file
(tensorrt_llm/llmapi/llm_args.py) was modified but is missing the required
NVIDIA copyright header; add the repository-mandated NVIDIA file header
(including the project name and the year of latest meaningful modification) to
the top of the file so all Python sources comply, ensuring it appears before any
code or comments (i.e., above the section that defines GMS = 3 and the "Load
weights..." comment) and update the year if this is a modification rather than a
new file.

Comment on lines +3716 to +3725
    gms_tag: str = Field(
        default="weights",
        description="Tag identifying the weight set in the GMS memory pool. "
        "Used to distinguish multiple models or versions in the same "
        "GMS instance. Defaults to 'weights' to match the GMS library "
        "convention (gpu_memory_service uses 'weights' for the model "
        "weight tag and 'kv_cache' for the KV cache tag). "
        "Only used when load_format='GMS'.",
        status="prototype",
    )
Contributor

⚠️ Potential issue | 🟠 Major

Reject empty gms_tag values.

gms_tag="" currently parses, and the GMS load path later collapses falsey tags to the shared default "weights". That turns a bad config into a silent pool collision instead of failing fast.

💡 Suggested fix
-    gms_tag: str = Field(
+    gms_tag: str = Field(
         default="weights",
         description="Tag identifying the weight set in the GMS memory pool. "
         "Used to distinguish multiple models or versions in the same "
         "GMS instance. Defaults to 'weights' to match the GMS library "
         "convention (gpu_memory_service uses 'weights' for the model "
         "weight tag and 'kv_cache' for the KV cache tag). "
         "Only used when load_format='GMS'.",
         status="prototype",
     )
+    @field_validator("gms_tag")
+    @classmethod
+    def validate_gms_tag(cls, v: str) -> str:
+        if not v.strip():
+            raise ValueError("gms_tag must be non-empty")
+        return v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 3716 - 3725, Add an explicit
Pydantic field validator for gms_tag to reject empty or whitespace-only strings
so configs like gms_tag="" fail fast instead of collapsing to the default;
implement a `@field_validator`("gms_tag") `@classmethod` named e.g. validate_gms_tag
that checks v.strip() and raises ValueError("gms_tag must be non-empty") when
empty, otherwise returns v; place it in the same class that defines the gms_tag
Field so the validation runs during model parsing.

Comment on lines +39 to +41
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="GMSBackend.__init__ calls torch.cuda.current_device()"
)
Contributor

⚠️ Potential issue | 🟠 Major

These skips leave the new GMS control-path tests effectively uncovered in CPU CI.

Almost everything exercised here is mocked; the only real CUDA dependency is GMSBackend.__init__ touching torch.cuda.current_device(). With the class-level skipif, the pre-connect, connect(), RW/RO gating, and protocol checks all disappear from the default CPU-only unit suite, and this file still never asserts a successful RW or RO connect() path. Please patch torch.cuda.current_device() (or construct via __new__) so the mocked happy/failure paths run without a GPU. QA list updates are unnecessary here because this is still unittest-only coverage.

As per coding guidelines "Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix."

Also applies to: 88-90, 129-148, 156-158, 268-276

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/memory/test_gms_backend.py` around lines 39 - 41,
Remove the class-level pytest.mark.skipif and instead make tests run on CPU by
patching torch.cuda.current_device or by instantiating GMSBackend without
calling __init__; specifically, in tests that exercise GMSBackend (e.g., where
GMSBackend.__init__, connect(), and RW/RO gating are involved) replace the skip
with a setup that monkeypatches torch.cuda.current_device to return a valid
device (e.g., monkeypatch.setattr(torch.cuda, "current_device", lambda: 0)) so
__init__ can run under mocks, or create the object via
GMSBackend.__new__(GMSBackend) and manually set any internal attributes needed
for the mocked connect()/protocol code paths to execute; apply this change to
the blocks referenced (around lines with the prior skip and the other ranges
called out) so the mocked happy/failure connect() and gating paths execute on
CPU CI.
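
A sketch of the CPU-only setup the comment describes; the GMSBackend import path matches the file list above, while the constructor arguments and assertions are illustrative:

    import pytest
    import torch

    from tensorrt_llm._torch.memory.gpu_memory_backend import GMSBackend


    @pytest.fixture
    def fake_cuda_device(monkeypatch):
        # Let GMSBackend.__init__ run on a CPU-only CI node.
        monkeypatch.setattr(torch.cuda, "current_device", lambda: 0)


    def test_connect_fails_gracefully_without_gms(fake_cuda_device):
        backend = GMSBackend(mode="auto")
        # With gpu_memory_service not importable, connect() should return False.
        assert backend.connect() is False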

Comment on lines +279 to +289
class TestLoadFormatEnum:
    def test_gms_enum_present(self):
        # The prototype adds LoadFormat.GMS = 3.
        assert hasattr(LoadFormat, "GMS")
        assert LoadFormat.GMS.name == "GMS"

    def test_pre_existing_enums_unchanged(self):
        # Sanity — make sure the prototype didn't reshuffle enum values.
        assert LoadFormat.AUTO.value == 0
        assert LoadFormat.DUMMY.value == 1
        assert LoadFormat.VISION_ONLY.value == 2
Contributor

⚠️ Potential issue | 🟡 Minor

Add a regression test for integer load_format inputs.

TorchLlmArgs.convert_load_format() now accepts raw ints, but this suite only checks enum members/constants. A case like load_format=3 -> LoadFormat.GMS should be pinned down here.

🧪 Suggested test
 class TestLoadFormatEnum:
+    def test_int_value_is_converted(self):
+        args = _make_args(load_format=3)
+        assert args.load_format == LoadFormat.GMS
+
     def test_gms_enum_present(self):
         # The prototype adds LoadFormat.GMS = 3.
         assert hasattr(LoadFormat, "GMS")
         assert LoadFormat.GMS.name == "GMS"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/llmapi/test_mx_gms_args.py` around lines 279 - 289, Add a
regression test to ensure integer load_format values map to the new enum member
by calling TorchLlmArgs.convert_load_format with the raw integer 3 and asserting
it yields LoadFormat.GMS; update the TestLoadFormatEnum suite (or add a new test
method) to include an assertion that TorchLlmArgs.convert_load_format(3) ==
LoadFormat.GMS while keeping the existing checks for AUTO/DUMMY/VISION_ONLY
unchanged.

@chienchunhung chienchunhung changed the title [TRTLLM-11851][feat] MX and GMS integration prototype for Dynamo weight sharing [TRTLLM-11851][feat] MX and GMS integration MVP for Dynamo weight sharing Apr 22, 2026
chienchunhung added a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 22, 2026
… timeout, model_name plumbing

Three discrete improvements to the MX side of PR NVIDIA#13045 driven by
review feedback from MX team's downstream PR
(chienchunhung/TensorRT-LLM #1) — three orchestration ergonomics fixes
landed as one focused commit so reviewers see them as a clean slice
on top of the prototype.

(1) MODEL_EXPRESS_URL env-var fallback — at validator level

  TorchLlmArgs.validate_mx_config now honors the upstream
  ``MODEL_EXPRESS_URL`` env var when ``checkpoint_format='MX'`` and
  ``mx_server_url`` is unset. Resolution happens at validator time so
  the value ends up on ``llm_args.mx_server_url`` (visible to
  logging, /startup_metrics, downstream code) instead of being
  silently re-read from env by the loader.

  Lets orchestrators (Dynamo) configure MX via the environment
  without plumbing every CLI knob, while keeping resolution in one
  place. Explicit ``mx_server_url=`` always wins. The env-var
  fallback only fires when MX is the active checkpoint format
  (so HF-only configs aren't surprised by an unrelated env var).
  Empty string in env is treated as unset.
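
  Roughly, the resolution described above amounts to the following
  (a sketch of the policy, not the actual validator code in llm_args.py):

      import os

      def _resolve_mx_server_url(checkpoint_format: str,
                                 mx_server_url: str | None) -> str | None:
          if mx_server_url:                      # explicit value always wins
              return mx_server_url
          if checkpoint_format == "MX":          # fallback only when MX is active
              env_url = os.environ.get("MODEL_EXPRESS_URL", "").strip()
              if env_url:                        # empty string treated as unset
                  return env_url
          return None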

(2) MX_SOURCE_QUERY_TIMEOUT defensive default

  MXCheckpointLoader.__init__ calls
  ``os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")`` whenever
  an MX server URL is configured. Caps cold-cluster first-replica
  startup at 30 s instead of upstream's 1-hour default (the polling
  in MxLiveWeightLoader._query_source). setdefault semantics preserve
  any explicit user value. HF-only loads (no MX URL) don't touch
  the env at all.

  The proper upstream-side fix is a non-blocking source-query API
  (tracked as MX-4 in §15 of the design doc); this defensive default
  caps the worst case until that lands.

(3) model_name plumbing with HF-snapshot-aware resolver

  Plumbs ``llm_args.model → MXCheckpointLoader(model_name=...)`` so
  upstream's ``publish_model_params()`` publishes under the
  user-supplied Hub ID (e.g. "Qwen/Qwen2.5-72B-Instruct") instead of
  the "unknown" sentinel.

  - MXCheckpointLoader takes a new optional ``model_name``
    constructor arg (Union[str, Path]). Coerced to str at
    construction time.
  - publish_as_source() now sets BOTH MODEL_EXPRESS_URL and
    MODEL_NAME env vars (resolving identity via the priority order
    below) and restores both env vars in finally.
    publish_model_params() reads them via env, as documented.
  - Identity resolution order: explicit constructor arg →
    MODEL_NAME env → checkpoint_dir basename (with HF-snapshot path
    unmangling) → "unknown".
  - HF cache layout (".../models--<org>--<name>/snapshots/<sha>/")
    is unmangled back to "<org>/<name>" instead of returning the
    commit hash.
  - _construct_checkpoint_loader plumbs ``mx_model_name`` through;
    py_executor_creator.py extracts it from llm_args.model.

  Both env-var dances (MODEL_EXPRESS_URL + MODEL_NAME) collapse into
  one direct call when MX-2 (public build_identity) lands upstream.
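
  The HF-cache unmangling in that resolution order is roughly this shape
  (a sketch only; the real resolver also handles the explicit-arg and
  MODEL_NAME-env steps described above):

      from pathlib import Path

      def _unmangle_hf_snapshot(checkpoint_dir: str) -> str:
          # ".../models--Qwen--Qwen2.5-72B-Instruct/snapshots/<sha>/"
          #   -> "Qwen/Qwen2.5-72B-Instruct"
          parts = Path(checkpoint_dir).parts
          for i, part in enumerate(parts):
              if part.startswith("models--") and i + 1 < len(parts) \
                      and parts[i + 1] == "snapshots":
                  return part[len("models--"):].replace("--", "/")
          return Path(checkpoint_dir).name  # plain directory: use the basename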

Tests for these three additions are in the next commit.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
@chienchunhung chienchunhung force-pushed the dynamo-integration-prototype branch from 0a2a950 to 6bdcfd8 on April 22, 2026 22:48
…ht sharing

Prototype implementation of the two-axis integration model for MX
(ModelExpress P2P transfer) and GMS (GPU Memory Service) in TRT-LLM's
PyTorch backend. Enables fast cold-start via GPU-to-GPU weight
streaming (MX) and zero-copy weight sharing with crash-resilient
failover (GMS).

The key design insight is that MX and GMS map onto two **independent
axes** in TRT-LLM's existing loading pipeline:

  - ``checkpoint_format`` (weight source axis): ``"MX"`` — P2P via
    upstream ``modelexpress.trtllm_live_transfer.MxLiveWeightLoader``,
    with automatic HF disk fallback (inherited from HfCheckpointLoader).
  - ``LoadFormat`` (memory management axis): ``GMS`` — out-of-process
    GPU memory pool with RW/RO dual paths via the upstream
    ``gpu_memory_service`` class-based API.

These compose independently, giving four modes with no combinatorial
config explosion:

| Mode         | checkpoint_format | LoadFormat | Use case                                 |
|--------------|-------------------|------------|------------------------------------------|
| Pure TRT-LLM | "HF" (default)    | AUTO       | Current behavior, unchanged              |
| MX only      | "MX"              | AUTO       | Cross-node P2P weight transfer           |
| GMS only     | "HF"              | GMS        | Within-node weight sharing + failover    |
| MX + GMS     | "MX"              | GMS        | Cross-node P2P + within-node sharing     |

Design intent: TRT-LLM owns the integration policy (Layer 3) and
calls the upstream libraries' stable per-call primitives (Layer 2).
We never duplicate the wire-protocol or runtime mechanics (Layer 1 —
gRPC, NIXL RDMA, CUDA VMM). We deliberately do NOT use the upstream
``gpu_memory_service.integrations.trtllm.setup_gms()`` entry point
because it monkey-patches ``ModelLoader.load`` from outside, which is
opaque at code-review time and conflicts with the two-axis design.

What's included:

  MXCheckpointLoader (_torch/models/checkpoints/mx/checkpoint_loader.py)
    - Subclasses HfCheckpointLoader — disk fallback inherited.
    - Delegates the actual NIXL/RDMA transfer to upstream
      MxLiveWeightLoader.load_weights(checkpoint_dir, mapping=, model=)
      (handles agent setup, source matching, dtype-cast handling,
      PVC fallback for size-mismatched tensors).
    - Delegates the publish side to upstream publish_model_params(model).
    - Exposes ``p2p_succeeded`` property — consumed by ModelLoader to
      set ``_weights_presharded`` on Linear modules.
    - Pre-post_load_weights() MX source publish timing so targets
      receive raw loaded state and run their own transforms.
    - Conservative mixed-success behavior: if MX returns
      size-mismatched fallback weights, falls through to a full disk
      load to avoid mixing presharded and non-presharded weights.

  GMSBackend (_torch/memory/gpu_memory_backend.py)
    - GPUMemoryBackend protocol + concrete GMSBackend implementation
      wrapping upstream GMSClientMemoryManager and the gms_use_mem_pool
      context manager.
    - Mode resolved at connect() via
      get_or_create_gms_client_memory_manager(socket, device,
      mode=RequestedLockType.RW_OR_RO) and inspected via granted_lock_type.
    - RW path: caller uses gms_backend.mem_pool_scope(device) context
      manager (delegates to gms_use_mem_pool("weights", device)); after
      load, move_untracked_params() migrates stray buffers into the
      pool, then finalize_write() delegates to upstream
      finalize_gms_write() (handles register + sync + commit + RO
      reconnect + remap in one call).
    - RO path: post_load_weights() runs *before* materialize_module()
      (sets up module aliases first); guard prevents double-execution;
      materialize_module_from_gms(mgr, model, device_index=N) does
      the zero-copy import.
    - Default tag is ``"weights"`` matching the GMS library convention
      (``weights`` for model weights, ``kv_cache`` for KV cache).
    - Default socket path resolves via get_socket_path(device, tag)
      (UUID-based, stable across CUDA_VISIBLE_DEVICES).
    - VMM safety: connect() applies patch_empty_cache() to prevent
      segfaults on VMM-backed allocations.
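
    A rough caller-side sketch of the two paths above. The method names
    are the ones listed in this section, but the argument lists and
    surrounding details are illustrative, not the actual model_loader code:

        backend = GMSBackend(socket_path=gms_socket_path, mode="auto", tag="weights")
        if backend.connect():
            if backend.is_rw:
                # RW: load from disk directly into the GMS pool, then commit.
                with backend.mem_pool_scope(device):
                    weights = checkpoint_loader.load_weights(checkpoint_dir, mapping=mapping)
                    model.load_weights(weights)
                backend.move_untracked_params(model, device)
                backend.finalize_write()
            else:
                # RO: set up module aliases first, then zero-copy import from the pool.
                model.post_load_weights()
                backend.materialize_module(model, device)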

  ModelLoader changes (_torch/pyexecutor/model_loader.py)
    - New LoadFormat.GMS branch with full RW/RO dual paths.
    - MX source publish fires for LoadFormat.AUTO and GMS-RW modes,
      before post_load_weights().
    - _weights_presharded flag set per-path (MX P2P in AUTO, GMS
      materialize in RO).
    - Draft model modules excluded from presharded flag.
    - Wires the mx_preshard_strategy knob (per_module / global).

  Configuration (llmapi/llm_args.py)
    - mx_server_url, mx_preshard_strategy, gms_socket_path, gms_mode,
      gms_tag — all status="prototype".
    - mx_preshard_strategy: per_module | global (default per_module);
      global raises NotImplementedError until LoadFormat.PRESHARDED
      lands upstream (tracked as MX-1 in the design doc §15).
    - Validators: validate_mx_config(), validate_gms_config() reject
      bad values with friendly messages.
    - API stability YAML updated.

  Linear module (_torch/modules/linear.py)
    - _weights_presharded = False declared in __init__ — TP slicing
      skipped when True (concept from the closed PR NVIDIA#12898).

  Optional dependencies (setup.py)
    - The MX (modelexpress) and GMS (gpu-memory-service) Python
      packages are intentionally NOT declared as ``[mx]`` / ``[gms]`` /
      ``[dynamo]`` extras_require entries while the integration is at
      prototype status. Until both packages are PyPI-published *and*
      onboarded into NVIDIA's OSS allowlist, users install them
      manually:

        pip install "modelexpress>=0.3.0,<0.4.0"
        git clone https://github.com/ai-dynamo/dynamo \
          && pip install ./dynamo/lib/gpu_memory_service

      Restoring one-line ``pip install tensorrt_llm[dynamo]``
      ergonomics is a single-hunk revert once upstream publish + OSS
      allowlist steps are complete (tracked in §15 as MX-7 and GMS-6).
      Rationale and the suppressed extras block are inline-documented
      in setup.py at the extras_require site.

Design decisions:

  - No monkey-patching. The upstream setup_gms() entry point patches
    ModelLoader.load from outside. We deliberately do not use it:
    TRT-LLM owns the integration policy, and GMSBackend is the
    explicit, reviewable boundary.
  - MX P2P is NOT used in GMS RW mode. Model params are meta tensors
    at that point — no CUDA buffers for P2P to write into, and
    MX/NIXL would allocate buffers outside the GMS pool. Disk
    loading under the GMS pool is the correct behavior.
  - Property-based ``checkpoint_format`` override on
    MXCheckpointLoader — the parent sets _checkpoint_format = "HF",
    the subclass overrides both the property and backing attribute.

CLI usage:

    # MX only (cross-node P2P)
    trtllm-serve <model> --checkpoint-format mx --mx-server-url http://mx:8001

    # GMS only (within-node sharing + crash resilience)
    trtllm-serve <model> --load-format gms --gms-socket-path /tmp/gms-0.sock

    # MX + GMS (full two-axis composition)
    trtllm-serve <model> --checkpoint-format mx --mx-server-url http://mx:8001 \
      --load-format gms --gms-socket-path /tmp/gms-0.sock

Relationship to existing work:
  - Adopts and adapts the _weights_presharded concept and pre-
    post_load_weights() publish timing from MX team's PR NVIDIA#12898
    (closed in favor of this two-axis model).
  - Uses the same upstream Layer-2 primitives as dynamo PR NVIDIA#7575 (the
    official TRT-LLM sleep/wake integration with GMS), but invokes
    them directly rather than via setup_gms() monkey-patch.
  - Full design doc at docs/design/mx-gms-integration/ on the
    docs-and-plans branch.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
…dapters

Adds 136 CPU-only unit tests covering the MX/GMS prototype's TRT-LLM-side
surface. Tests do NOT exercise the upstream ``modelexpress`` or
``gpu_memory_service`` libraries — fallback paths are exercised by
injecting None into ``sys.modules`` (forcing ImportError on ``import``),
and success paths use fake module trees. Real CUDA is gated with
``@pytest.mark.skipif(not torch.cuda.is_available())``.
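
A sketch of the two mocking patterns (module names are the upstream ones this
PR imports; the elided loader construction is illustrative):

    import sys
    from unittest import mock

    # Fallback path: make "import modelexpress.trtllm_live_transfer" raise ImportError.
    with mock.patch.dict(sys.modules, {"modelexpress": None,
                                       "modelexpress.trtllm_live_transfer": None}):
        ...  # construct the loader and call load_weights(); it must fall back to disk

    # Success path: pre-install a fake module tree so the import resolves to mocks.
    fake_transfer = mock.MagicMock()
    with mock.patch.dict(sys.modules,
                         {"modelexpress": mock.MagicMock(trtllm_live_transfer=fake_transfer),
                          "modelexpress.trtllm_live_transfer": fake_transfer}):
        ...  # fake_transfer.MxLiveWeightLoader is what the loader now imports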

Test files added:

tests/unittest/llmapi/test_mx_gms_args.py (41 tests)

  - Default values for all 5 prototype config fields. Notable:
    ``gms_tag`` defaults to ``"weights"`` (not ``"model_weights"``)
    matching the upstream GMS_TAGS convention.
  - ``mx_preshard_strategy`` validation: accepts per_module/global,
    rejects everything else, parametrized cross-field warnings on
    (checkpoint_format, strategy).
  - ``mx_server_url`` cross-field warning: parametrized over the
    {explicit-with-MX, explicit-without-MX, unset} matrix.
  - ``MODEL_EXPRESS_URL`` env-var fallback (validator-level): env
    populates / explicit wins / no-env stays None / env ignored when
    not MX / empty env is no-op / no spurious cross-field warning on
    env-fallback path.
  - ``gms_mode`` validator: only enforced when ``load_format=GMS``;
    accepts auto/rw/ro, rejects others.
  - ``gms_socket_path`` cross-field warning: parametrized over the
    {GMS, AUTO} matrix.
  - Two-axis composition sanity: pure TRT-LLM, MX-only, GMS-only,
    MX+GMS all construct without error.
  - LoadFormat enum: GMS=3 added, pre-existing values unshifted.

  Uses ``unittest.mock.patch`` on
  ``tensorrt_llm.llmapi.llm_args.logger`` (TRT-LLM uses a custom
  Singleton logger that does NOT route through stdlib logging, so
  pytest's ``caplog`` does not intercept it).
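
  In practice that pattern is just (the assertion is illustrative):

      from unittest import mock

      with mock.patch("tensorrt_llm.llmapi.llm_args.logger") as mock_logger:
          ...  # construct TorchLlmArgs with the config under test
      mock_logger.warning.assert_called()  # instead of relying on caplog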

tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py (53 tests)

  - Construction: subclasses HfCheckpointLoader, exposes
    mx_server_url/p2p_succeeded/checkpoint_format/model_name
    properties, registered under ``checkpoint_format="MX"`` in the
    BaseCheckpointLoader registry (verified via the same call shape
    that ``_construct_checkpoint_loader`` uses).

  - Disk fallback paths (no upstream library involved) — parametrized
    on the fallback trigger (no_url / no_model / lib_unavailable /
    upstream_raises). Each case asserts the same observable contract:
    p2p_succeeded stays False, super().load_weights is invoked once
    and its return value is propagated unchanged.

  - MX path (mocked upstream):
    * Empty fallback dict → p2p_succeeded=True, returns {}, plus
      assertions that MxLiveWeightLoader was constructed with
      mx_server=URL and load_weights was called with the right args
      (pins the integration contract).
    * Non-empty fallback dict (mixed-success) → falls through to
      full disk load to avoid mixing presharded and non-presharded
      weights (tracked as MX-1 in §15 of the design doc).

  - publish_as_source: no-op when no URL or upstream missing,
    delegates to publish_model_params(model), env-var dance for both
    MODEL_EXPRESS_URL and MODEL_NAME (sets during call, restores
    prior value), swallows upstream exceptions.

  - MX_SOURCE_QUERY_TIMEOUT defensive default: unset gets "30",
    explicit user value preserved (setdefault semantics), no MX URL
    leaves env untouched.

  - model_name plumbing: constructor accepts str/Path/None, coerced
    to str at construction.

  - _normalize_model_identity (8 parametrized cases): pass-through
    for Hub IDs and bare names; basename for absolute, relative,
    and home-expansion paths; HF-snapshot unmangling for both simple
    and nested-org HF cache layouts.

  - _resolve_mx_model_name (6 cases): priority order
    (explicit > env > basename > "unknown"); explicit-arg
    normalization.

  - publish_as_source MODEL_NAME end-to-end (6 cases): explicit
    constructor value wins, basename fallback, HF-snapshot
    unmangling, constructor priority over env, env-only path,
    "unknown" sentinel.

tests/unittest/_torch/memory/test_gms_backend.py (42 tests)

  - Construction: validates mode in {rw, ro, auto}, default tag is
    "weights" (matches DEFAULT_TAG class attribute), default socket
    path is resolved lazily inside connect().

  - Pre-connect state: is_rw is None, has_committed_weights returns
    False (not raise), all RW-only methods raise RuntimeError with
    "not connected" — parametrized on (method_name, invoke), cleanup
    is safe.

  - connect() failure: returns False (not raises) when
    gpu_memory_service is not importable; returns False when the
    upstream factory raises (e.g. socket missing).

  - RW-only method gating: parametrized — mem_pool_scope and
    finalize_write raise "only valid in RW mode" when granted lock
    is RO.

  - _ptr_in_gms helper: half-open interval semantics (parametrized
    over inside / start / just-below-end / at-end / just-below-start
    / far-outside cases), no-mappings-attr handling, zero-va/size
    sentinel guards (parametrized).

  - _storage_nbytes helper: dtype-dependent sizing, view-independent
    sizing on real CPU tensors.

  - Protocol conformance: GMSBackend instances satisfy the
    runtime-checkable GPUMemoryBackend protocol; the protocol
    exposes the exact method set ``model_loader.py`` invokes (so a
    different backend like a future CudaIpcBackend can be plugged
    in without touching call sites).

Total runtime: ~0.6 s for all 136 tests on the development node.

Coverage gaps deliberately left for follow-ups:
  - Real GMS daemon end-to-end (requires gpu-memory-service install
    + GPU). See Section 15 of the design doc for the full E2E test
    plan (B2 GMS shadow memory, B6 failover floor).
  - Real MX server end-to-end (requires MX server build via Rust +
    Docker). See Section 15 (B1 bit-exactness, B3 P2P throughput,
    B4 cold-start headline).
  - LoadFormat.GMS branch in model_loader.py end-to-end (requires
    real model + real GMS daemon).

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
@chienchunhung chienchunhung force-pushed the dynamo-integration-prototype branch from 6bdcfd8 to 06491b6 on April 22, 2026 23:54
@chienchunhung
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@chienchunhung
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45034 [ run ] triggered by Bot. Commit: 06491b6 Link to invocation

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)

332-402: ⚠️ Potential issue | 🟠 Major

Keep buffer materialization separate from the GMS weight skip.

This now skips _apply_to_buffers_only(...) and the meta init path for all tensors in GMS mode, not just parameters. Any registered buffer that is not restored from the checkpoint will stay on meta, and the later post_load_weights() / MoE finalization path can still trip over it. The GMS special-case should bypass parameter preallocation, but buffers still need to be materialized.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 332 - 402, The
current GMS branch skips calling _apply_to_buffers_only and the buffer meta-init
path, leaving registered buffers on meta; ensure buffers are always
materialized: move the call to _apply_to_buffers_only(model,
allocate_buffer_on_cuda) out of the conditional that checks
self.model_weights_memory_tag and load_format != LoadFormat.GMS so it runs even
when load_format == LoadFormat.GMS, and likewise allow the init_meta_tensor path
to materialize buffers (i.e., don't gate buffer initialization by load_format !=
LoadFormat.GMS) while keeping the special-case skip for parameter preallocation
(allocate_weights_on_cuda and its virtual_memory_scope) when load_format ==
LoadFormat.GMS; key symbols to change: _apply_to_buffers_only,
allocate_buffer_on_cuda, init_meta_tensor, allocate_weights_on_cuda,
self.model_weights_memory_tag, and LoadFormat.GMS.
♻️ Duplicate comments (1)
tests/unittest/_torch/memory/test_gms_backend.py (1)

87-92: ⚠️ Potential issue | 🟠 Major

Add a successful connect() path to this suite.

The CPU-CI enablement is fixed, but the tests still only cover pre-connect behavior and failure paths. The main control path is still unpinned: connect() successfully granting RW/RO, setting _client / _is_rw, and resolving socket_path=None via the upstream helper. A fake happy-path client here would catch regressions in the path model_loader actually depends on. QA list updates are unnecessary because this is still unittest-only coverage.

As per coding guidelines, “Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix.”

Also applies to: 139-183

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/memory/test_gms_backend.py` around lines 87 - 92, Add a
“happy-path” connect() flow to the test test_socket_path_none_resolved_lazily by
creating a fake client and stubbing the upstream socket-path resolver so that
GMSBackend.connect() succeeds, sets backend._client and backend._is_rw
appropriately, and resolves socket_path when initialized with socket_path=None;
specifically, in the test arrange a MagicMock/fake client that returns expected
RW/RO capability, patch the get_socket_path (or equivalent helper used by
GMSBackend.connect) to return a concrete path, call backend.connect(), and
assert backend._client is the fake client, backend._is_rw is True/False as
expected, and backend._socket_path is the resolved path so the real control path
is covered.
🧹 Nitpick comments (1)
tensorrt_llm/llmapi/llm_args.py (1)

3684-3697: Use Literal[...] for mx_preshard_strategy.

This is a user-facing Pydantic field with a closed value set, but typing it as str hides those constraints from schema/docs and pushes basic value validation into the model validator. Literal["per_module", "global"] would make the contract explicit while keeping the cross-field warning in validate_mx_config.

As per coding guidelines, user-facing Python Pydantic fields should “Use Literal["value1", "value2"] instead of str in Python Pydantic fields when only certain values are allowed.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 3684 - 3697, Change the
mx_preshard_strategy field's type from str to a Literal restricted to the
allowed values (Literal["per_module", "global"]) so the Pydantic schema and docs
expose the closed set; add the necessary import for Literal (from typing or
typing_extensions as used in the repo), keep the default="per_module" and
existing description/status, and leave any cross-field validation logic in
validate_mx_config unchanged (it should continue to enforce/emit warnings or
raise NotImplementedError when needed).
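
Concretely, a sketch of the field change only (description elided; the Field keywords follow the current definition):

    from typing import Literal

    mx_preshard_strategy: Literal["per_module", "global"] = Field(
        default="per_module",
        description="...",  # unchanged
        status="prototype",
    )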
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py`:
- Around line 296-321: The loop currently uses seen: set[int] and skips
subsequent tensors sharing the same storage_ptr, leaving aliases pointing
outside GMS; change seen to a dict mapping original storage_ptr -> replacement
tensor (or its base_va) so after you create replacement via
gms_client.create_mapping and _tensor_from_pointer (and do replacement.copy_),
you store seen[storage_ptr] = replacement and for any later tensor with the same
storage_ptr (found in _iter_module_tensors) you assign tensor.data =
seen[storage_ptr] instead of continuing; keep the existing _ptr_in_gms check
(and treat tensors already in GMS as already-mapped) and reuse _storage_nbytes,
_tensor_from_pointer, replacement.copy_, and tensor.data = ... symbols to locate
where to implement this change.
- Around line 213-216: In connect(), if patch_empty_cache() raises an exception,
log the error (keep or upgrade the current logger.debug message) and immediately
return False so the connection fails instead of returning True; locate the
try/except around patch_empty_cache() in the connect() method and add a return
False in the except block (reference: connect() and patch_empty_cache()).

In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py`:
- Around line 295-321: The env var override + publish sequence in
checkpoint_loader (the block that sets MODEL_EXPRESS_URL/MODEL_NAME, calls
publish_model_params(model), and restores os.environ) must be serialized with a
module-level lock to prevent concurrent publish_as_source()/publish_model_params
calls from interleaving; add a module-level threading.Lock (e.g.,
_env_override_lock) and wrap the entire override/publish/restore sequence in a
with _env_override_lock: so the try/except/finally and the restore logic run
while holding the lock; ensure the lock is declared at module scope so other
functions in this module can reuse it.

In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 178-191: The _weights_presharded flag set by MX P2P/GMS is never
cleared, so _effective_tp_coords() can incorrectly return (1, 0) on subsequent
non-presharded loads; update Linear.load_weights (or add/ensure a call to
pre_reload_weights) to clear module._weights_presharded (or reconstruct the
Linear instance) before loading regular checkpoint weights so TP slicing is
re-enabled; specifically, ensure any reload path that calls Linear.load_weights
resets _weights_presharded (or invokes pre_reload_weights) so
_effective_tp_coords, the main weight/bias and quantization metadata are sliced
consistently.

In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 512-610: Assign self._gms_backend immediately after
gms_backend.connect() returns True and then wrap the subsequent GMS load logic
(the RW/RO branches, finalize_write/materialize_module calls, and any
draft-weight handling) in a try/except that on exception clears
self._gms_backend and calls the backend’s cleanup/disconnect method (e.g.,
gms_backend.disconnect() or the appropriate close/cleanup API) before
re-raising; this ensures the live GMS handle and CUDA hooks are released if
checkpoint_loader.load_weights, finalize_write, materialize_module, or other
calls fail.

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3833-3835: Reject boolean inputs before treating numbers as
LoadFormat: in the load_format validator (the block that currently does "if
isinstance(v, int): return LoadFormat(v)"), add an explicit check for
isinstance(v, bool) and raise a ValidationError/TypeError so True/False do not
get interpreted as integers; keep the existing int path to map valid integers to
LoadFormat. Add a regression test named test_load_format_rejects_bool in
test_mx_gms_args.py that passes load_format=True and asserts a validation
failure. Also change the field declaration for mx_preshard_strategy from
"mx_preshard_strategy: str = Field(...)" to "mx_preshard_strategy:
Literal['per_module','global'] = Field(...)" so Pydantic enforces the allowed
values.

---

Outside diff comments:
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 332-402: The current GMS branch skips calling
_apply_to_buffers_only and the buffer meta-init path, leaving registered buffers
on meta; ensure buffers are always materialized: move the call to
_apply_to_buffers_only(model, allocate_buffer_on_cuda) out of the conditional
that checks self.model_weights_memory_tag and load_format != LoadFormat.GMS so
it runs even when load_format == LoadFormat.GMS, and likewise allow the
init_meta_tensor path to materialize buffers (i.e., don't gate buffer
initialization by load_format != LoadFormat.GMS) while keeping the special-case
skip for parameter preallocation (allocate_weights_on_cuda and its
virtual_memory_scope) when load_format == LoadFormat.GMS; key symbols to change:
_apply_to_buffers_only, allocate_buffer_on_cuda, init_meta_tensor,
allocate_weights_on_cuda, self.model_weights_memory_tag, and LoadFormat.GMS.

---

Duplicate comments:
In `@tests/unittest/_torch/memory/test_gms_backend.py`:
- Around line 87-92: Add a “happy-path” connect() flow to the test
test_socket_path_none_resolved_lazily by creating a fake client and stubbing the
upstream socket-path resolver so that GMSBackend.connect() succeeds, sets
backend._client and backend._is_rw appropriately, and resolves socket_path when
initialized with socket_path=None; specifically, in the test arrange a
MagicMock/fake client that returns expected RW/RO capability, patch the
get_socket_path (or equivalent helper used by GMSBackend.connect) to return a
concrete path, call backend.connect(), and assert backend._client is the fake
client, backend._is_rw is True/False as expected, and backend._socket_path is
the resolved path so the real control path is covered.
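A hypothetical shape for that happy-path test. The patched names (`get_socket_path`, `GmsClient`) and the private attributes asserted at the end follow the comment above and would need to be aligned with the real `gpu_memory_backend` module:

```python
from unittest.mock import MagicMock, patch

from tensorrt_llm._torch.memory.gpu_memory_backend import GMSBackend

_MODULE = "tensorrt_llm._torch.memory.gpu_memory_backend"


def test_connect_resolves_socket_path_lazily():
    backend = GMSBackend(socket_path=None)

    fake_client = MagicMock()
    fake_client.is_rw.return_value = True  # fake an RW lease from the service

    # Patch the (assumed) socket-path resolver and client class so connect()
    # succeeds without a running GPU Memory Service.
    with patch(f"{_MODULE}.get_socket_path", return_value="/tmp/gms.sock"), \
         patch(f"{_MODULE}.GmsClient", return_value=fake_client):
        assert backend.connect() is True

    assert backend._client is fake_client
    assert backend._is_rw is True
    assert backend._socket_path == "/tmp/gms.sock"
```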

---

Nitpick comments:
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3684-3697: Change the mx_preshard_strategy field's type from str
to a Literal restricted to the allowed values (Literal["per_module", "global"])
so the Pydantic schema and docs expose the closed set; add the necessary import
for Literal (from typing or typing_extensions as used in the repo), keep the
default="per_module" and existing description/status, and leave any cross-field
validation logic in validate_mx_config unchanged (it should continue to
enforce/emit warnings or raise NotImplementedError when needed).
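A minimal sketch of the suggested field declaration inside the args model; the description string here is illustrative, not the one used in `llm_args.py`:

```python
from typing import Literal

from pydantic import Field

mx_preshard_strategy: Literal["per_module", "global"] = Field(
    default="per_module",
    description="Strategy used to preshard weights before MX P2P transfer.",
)
```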

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: ca8dadee-00ef-4aaf-b997-796f3c04c71d

📥 Commits

Reviewing files that changed from the base of the PR and between 0a2a950 and 06491b6.

📒 Files selected for processing (19)
  • setup.py
  • tensorrt_llm/_torch/memory/__init__.py
  • tensorrt_llm/_torch/memory/gpu_memory_backend.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
  • tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/memory/test_gms_backend.py
  • tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
  • tests/unittest/api_stability/references/llm.yaml
  • tests/unittest/llmapi/test_mx_gms_args.py
✅ Files skipped from review due to trivial changes (6)
  • tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
  • setup.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
  • tests/unittest/api_stability/references/llm.yaml
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py

Comment on lines +213 to +216
try:
patch_empty_cache()
except Exception as e:
logger.debug("GMS patch_empty_cache failed (non-fatal): %s", e)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's find and read the file to understand the context
find . -name "gpu_memory_backend.py" -type f

Repository: NVIDIA/TensorRT-LLM

Length of output: 113


🏁 Script executed:

# Read the file to understand the context around lines 213-216
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '200,230p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1189


🏁 Script executed:

# Read more context including move_untracked_params() and the full connect() method
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '1,50p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 2733


🏁 Script executed:

# Get the full connect() method to understand the return behavior
cat -n tensorrt_llm/_torch/memory/gpu_memory_backend.py | sed -n '180,225p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1956


🏁 Script executed:

# Find where patch_empty_cache() is defined
rg "def patch_empty_cache" --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check if patch_empty_cache() is imported and where it comes from
rg "patch_empty_cache" tensorrt_llm/_torch/memory/gpu_memory_backend.py -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 590


🏁 Script executed:

# Search for any documentation or comments about what patch_empty_cache does
rg "patch_empty_cache" --type py -B 5 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 4942


🏁 Script executed:

# Check if there are any callers that depend on the return value of connect()
rg "\.connect\(\)" --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 345


Return False from connect() if patch_empty_cache() fails.

The patch prevents segfaults on GMS-backed VMM tensors when torch.cuda.empty_cache() is called (documented as critical in model_loader.py). If patching fails, the connection should fail (return False), not silently succeed. Currently the method returns True despite the patch failure, creating a latent crash condition for callers.
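A minimal sketch of the suggested change inside `connect()` (only the `patch_empty_cache()` block is shown; the rest of the method is assumed unchanged):

```python
try:
    patch_empty_cache()
except Exception as e:
    # Without the patch, torch.cuda.empty_cache() can later segfault on
    # VMM-backed GMS tensors, so a failed patch means a failed connection.
    logger.error("GMS patch_empty_cache failed: %s", e)
    return False
```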

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 215-215: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 213 - 216, In
connect(), if patch_empty_cache() raises an exception, log the error (keep or
upgrade the current logger.debug message) and immediately return False so the
connection fails instead of returning True; locate the try/except around
patch_empty_cache() in the connect() method and add a return False in the except
block (reference: connect() and patch_empty_cache()).

Comment on lines +296 to +321
seen: set[int] = set()

with torch.no_grad():
for _name, tensor, tensor_type in _iter_module_tensors(model):
if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
continue

storage_ptr = tensor.untyped_storage().data_ptr()
if storage_ptr in seen:
continue
seen.add(storage_ptr)

if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
continue

nbytes = _storage_nbytes(tensor)
base_va = gms_client.create_mapping(size=nbytes, tag=self._tag)
replacement = _tensor_from_pointer(
int(base_va),
list(tensor.shape),
list(tensor.stride()),
tensor.dtype,
device_index,
)
replacement.copy_(tensor)
tensor.data = replacement

⚠️ Potential issue | 🟠 Major

Rebind shared storages instead of skipping them.

seen deduplicates by the original storage pointer, but after the first tensor is migrated only that parameter is rebound to the GMS mapping. Any later alias that shared the same storage is skipped and keeps pointing outside the pool, which breaks tied weights and can leave finalize_write() with mixed tracked/untracked parameters.
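A sketch of the rebinding described above, assuming aliases share the full storage with the same shape and stride (as tied weights do) and reusing the helpers named in the diff:

```python
seen: dict[int, torch.Tensor] = {}

with torch.no_grad():
    for _name, tensor, tensor_type in _iter_module_tensors(model):
        if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
            continue

        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen:
            # Rebind the alias to the already-migrated GMS tensor instead of
            # leaving it pointing at the old, untracked storage.
            tensor.data = seen[storage_ptr]
            continue

        if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
            continue

        nbytes = _storage_nbytes(tensor)
        base_va = gms_client.create_mapping(size=nbytes, tag=self._tag)
        replacement = _tensor_from_pointer(
            int(base_va), list(tensor.shape), list(tensor.stride()),
            tensor.dtype, device_index)
        replacement.copy_(tensor)
        tensor.data = replacement
        seen[storage_ptr] = replacement
```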

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/memory/gpu_memory_backend.py` around lines 296 - 321, The
loop currently uses seen: set[int] and skips subsequent tensors sharing the same
storage_ptr, leaving aliases pointing outside GMS; change seen to a dict mapping
original storage_ptr -> replacement tensor (or its base_va) so after you create
replacement via gms_client.create_mapping and _tensor_from_pointer (and do
replacement.copy_), you store seen[storage_ptr] = replacement and for any later
tensor with the same storage_ptr (found in _iter_module_tensors) you assign
tensor.data = seen[storage_ptr] instead of continuing; keep the existing
_ptr_in_gms check (and treat tensors already in GMS as already-mapped) and reuse
_storage_nbytes, _tensor_from_pointer, replacement.copy_, and tensor.data = ...
symbols to locate where to implement this change.

Comment on lines +295 to +321
env_overrides = {
"MODEL_EXPRESS_URL": self._mx_server_url,
"MODEL_NAME": resolved_name,
}
prior = {key: os.environ.get(key) for key in env_overrides}
for key, value in env_overrides.items():
os.environ[key] = value

try:
publish_model_params(model)
logger.info(
"Published weights to MX server at %s as model=%r",
self._mx_server_url,
resolved_name,
)
except Exception as e:
logger.warning(
"Failed to publish weights to MX server at %s: %s",
self._mx_server_url,
e,
)
finally:
for key, prior_value in prior.items():
if prior_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = prior_value

⚠️ Potential issue | 🟠 Major

Serialize the env-var override around MX publishing.

MODEL_EXPRESS_URL and MODEL_NAME live in process-global os.environ, so two concurrent publish_as_source() calls can interleave here and publish a model under the wrong identity or server URL. Please guard the override/publish/restore sequence with a module-level lock.

🔒 Proposed fix
+import threading
 ...
+_MX_PUBLISH_ENV_LOCK = threading.Lock()
+
 ...
-        prior = {key: os.environ.get(key) for key in env_overrides}
-        for key, value in env_overrides.items():
-            os.environ[key] = value
-
-        try:
-            publish_model_params(model)
-            logger.info(
-                "Published weights to MX server at %s as model=%r",
-                self._mx_server_url,
-                resolved_name,
-            )
-        except Exception as e:
-            logger.warning(
-                "Failed to publish weights to MX server at %s: %s",
-                self._mx_server_url,
-                e,
-            )
-        finally:
-            for key, prior_value in prior.items():
-                if prior_value is None:
-                    os.environ.pop(key, None)
-                else:
-                    os.environ[key] = prior_value
+        with _MX_PUBLISH_ENV_LOCK:
+            prior = {key: os.environ.get(key) for key in env_overrides}
+            for key, value in env_overrides.items():
+                os.environ[key] = value
+
+            try:
+                publish_model_params(model)
+                logger.info(
+                    "Published weights to MX server at %s as model=%r",
+                    self._mx_server_url,
+                    resolved_name,
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to publish weights to MX server at %s: %s",
+                    self._mx_server_url,
+                    e,
+                )
+            finally:
+                for key, prior_value in prior.items():
+                    if prior_value is None:
+                        os.environ.pop(key, None)
+                    else:
+                        os.environ[key] = prior_value
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 310-310: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines
295 - 321, The env var override + publish sequence in checkpoint_loader (the
block that sets MODEL_EXPRESS_URL/MODEL_NAME, calls publish_model_params(model),
and restores os.environ) must be serialized with a module-level lock to prevent
concurrent publish_as_source()/publish_model_params calls from interleaving; add
a module-level threading.Lock (e.g., _env_override_lock) and wrap the entire
override/publish/restore sequence in a with _env_override_lock: so the
try/except/finally and the restore logic run while holding the lock; ensure the
lock is declared at module scope so other functions in this module can reuse it.

Comment on lines +178 to +191
def _effective_tp_coords(module: Linear) -> tuple[int, int]:
"""Return (tp_size, tp_rank) accounting for ``_weights_presharded``.

When a module's weights are already TP-sharded for the local rank
(e.g. delivered via MX P2P transfer), TRT-LLM's standard TP slicing
must be skipped for *all* tensors associated with the module — the
main weight and bias *and* their quantization metadata
(weight_scale, input_scale, pre_quant_scale, etc.). Otherwise
full-size weights would be paired with re-sliced metadata, causing
shape mismatches or silent numerical errors on quantized models.
"""
if getattr(module, '_weights_presharded', False):
return 1, 0
return module.tp_size, module.tp_rank

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Assignments to _weights_presharded:"
rg -n -C2 '_weights_presharded\s*=' tensorrt_llm

echo
echo "Potential reload / reuse entry points:"
rg -n -C3 '\b(pre_reload_weights|load_weights_from_target_model|load_weights\()' tensorrt_llm/_torch

Repository: NVIDIA/TensorRT-LLM

Length of output: 50376


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the actual load_weights method in Linear to see if it clears the flag
echo "=== Linear.load_weights implementation ==="
sed -n '2600,2700p' tensorrt_llm/_torch/modules/linear.py | head -50

echo
echo "=== Check for pre_reload_weights in Linear ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/modules/linear.py

echo
echo "=== Check if modules are reconstructed or reused in model_loader ==="
rg -n -B5 -A5 'module._weights_presharded.*True' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -40

echo
echo "=== Check module init flow in model_loader ==="
rg -n 'Linear\(' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 3020


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# View the pre_reload_weights and load_weights methods in Linear
echo "=== Linear.pre_reload_weights (lines 2790-2800) ==="
sed -n '2790,2810p' tensorrt_llm/_torch/modules/linear.py

echo
echo "=== Linear.load_weights signature and start (lines 2795+) ==="
sed -n '2795,2850p' tensorrt_llm/_torch/modules/linear.py | head -60

echo
echo "=== Search for load_weights in Linear with context ==="
rg -n 'def load_weights' tensorrt_llm/_torch/modules/linear.py -A 10 | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 4646


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find the actual Linear class's load_weights method
echo "=== Find Linear class load_weights ==="
rg -n 'class Linear' tensorrt_llm/_torch/modules/linear.py | head -5

echo
echo "=== Linear.load_weights in class (searching after line 2300) ==="
sed -n '2700,2800p' tensorrt_llm/_torch/modules/linear.py | grep -n 'def load_weights' -A 20 | head -40

echo
echo "=== Try broader search for load_weights in Linear class ==="
sed -n '2750,2795p' tensorrt_llm/_torch/modules/linear.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3123


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Check if pre_reload_weights is called in the model_loader before load_weights
echo "=== Check model_loader for pre_reload_weights calls ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/pyexecutor/model_loader.py -B 3 -A 3

echo
echo "=== Check modeling_utils for pre_reload_weights calls ==="
rg -n 'pre_reload_weights' tensorrt_llm/_torch/models/modeling_utils.py -B 3 -A 3

echo
echo "=== Verify that _weights_presharded is NOT reset anywhere ==="
rg -n '_weights_presharded\s*=\s*False' tensorrt_llm

Repository: NVIDIA/TensorRT-LLM

Length of output: 118


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# More comprehensive check
echo "=== All calls to pre_reload_weights ==="
rg -r 'pre_reload_weights' tensorrt_llm --type py

echo
echo "=== All assignments to _weights_presharded ==="
rg '_weights_presharded\s*=' tensorrt_llm --type py

echo
echo "=== Check if modules.load_weights is called after MX P2P sets the flag ==="
sed -n '450,480p' tensorrt_llm/_torch/pyexecutor/model_loader.py

echo
echo "=== Context around weight loading after MX P2P ==="
sed -n '420,470p' tensorrt_llm/_torch/pyexecutor/model_loader.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 50377


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Focus on the actual code flow - check if modules are reconstructed or reused
echo "=== Check if initialize_dummy_weights or model creation reconstructs Linear modules ==="
sed -n '490,500p' tensorrt_llm/_torch/pyexecutor/model_loader.py

echo
echo "=== Check the broader load flow - whether modules persist across loads ==="
sed -n '400,435p' tensorrt_llm/_torch/pyexecutor/model_loader.py | head -40

echo
echo "=== Confirm: no reset of _weights_presharded before load_weights ==="
grep -n 'def load_weights' tensorrt_llm/_torch/modules/linear.py -A 20 | grep -E '(_weights_presharded|pre_reload)' || echo "No reset found"

Repository: NVIDIA/TensorRT-LLM

Length of output: 2816


Clear the _weights_presharded flag before loading non-presharded weights, or ensure modules are reconstructed on reload.

The flag is initialized to False only in __init__ and set to True by MX P2P and GMS paths, but is never reset before subsequent loads. If a module undergoes presharded loading and is then reloaded from a regular checkpoint, TP slicing will incorrectly stay disabled, causing _effective_tp_coords() to return (1, 0) and full-size weights to be copied into per-rank buffers with catastrophic shape mismatches or silent numerical errors on quantized models.

Either reset the flag before non-presharded loads in Linear.load_weights(), verify that all reload paths reconstruct Linear instances, or call pre_reload_weights() before each load to clear any state.
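One possible shape of the fix, hedged: clear the flag whenever the module is about to be reloaded so a regular checkpoint re-enables TP slicing. Whether this belongs in `pre_reload_weights()` or at the top of `Linear.load_weights()` depends on which reload paths actually reuse the module instance:

```python
def pre_reload_weights(self) -> None:
    # Drop presharded state left over from a previous MX P2P / GMS load so
    # _effective_tp_coords() returns the real (tp_size, tp_rank) again.
    self._weights_presharded = False
    ...  # existing pre_reload_weights logic, if any
```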

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/linear.py` around lines 178 - 191, The
_weights_presharded flag set by MX P2P/GMS is never cleared, so
_effective_tp_coords() can incorrectly return (1, 0) on subsequent
non-presharded loads; update Linear.load_weights (or add/ensure a call to
pre_reload_weights) to clear module._weights_presharded (or reconstruct the
Linear instance) before loading regular checkpoint weights so TP slicing is
re-enabled; specifically, ensure any reload path that calls Linear.load_weights
resets _weights_presharded (or invokes pre_reload_weights) so
_effective_tp_coords, the main weight/bias and quantization metadata are sliced
consistently.

Comment on lines +512 to +610
if not gms_backend.connect():
raise RuntimeError("Failed to connect to GMS at "
f"{self.llm_args.gms_socket_path}")

if gms_backend.is_rw:
# GMS RW path: load via checkpoint_loader under GMS
# memory pool so allocations go into shared memory.
# Note: MX P2P is NOT used in GMS RW mode because
# parameters are still meta tensors (no CUDA buffers
# for P2P to write into) and allocations must go through
# the GMS pool. The checkpoint_loader loads from disk,
# with the GMS pool intercepting all CUDA allocations.
device = torch.device('cuda')

with gms_backend.mem_pool_scope(device):
load_weights_kwargs = {"mapping": self.mapping}
# Honor model.llm_checkpoint_dir when the model
# rewrites its checkpoint source — same precedence
# as the AUTO branch above. Without this, GMS-only
# would silently load from the original
# checkpoint_dir for redirect-using model classes.
weight_source = (model.llm_checkpoint_dir if hasattr(
model, 'llm_checkpoint_dir') else checkpoint_dir)
weights = checkpoint_loader.load_weights(
weight_source, **load_weights_kwargs)

if weights:
self.weight_mapper = (
checkpoint_loader.get_initialized_weight_mapper(
model, config))
self._call_load_weights(model.load_weights, weights,
self.weight_mapper)

# Speculative decoding: load draft weights INSIDE
# the same mem_pool_scope so they land in the GMS
# pool and are picked up by finalize_write's model
# walk. (Deferring commit until after both main and
# draft weights are in the pool avoids the otherwise-
# required separate commit-per-model dance flagged
# in PR #13045 review.)
if (self.spec_config is not None and self.spec_config.
spec_dec_mode.need_load_draft_weights()):
draft_weights = checkpoint_loader.load_weights(
self.spec_config.speculative_model,
mapping=self.mapping)

draft_model_arch = (
model.draft_config.pretrained_config.
architectures[0])
draft_weight_mapper = AutoCheckpointMapper.get(
checkpoint_loader.checkpoint_format,
draft_model_arch)
draft_weight_mapper.init_model_and_config(
model.draft_model, model.draft_config)

self._call_load_weights(model.load_draft_weights,
draft_weights,
draft_weight_mapper)

# Move any parameters that landed outside the GMS
# pool (e.g., buffers created during weight-loading
# transforms) into the pool, then drain the caching
# allocator. Order and pool-scope membership mirror
# the upstream gpu_memory_service.integrations.trtllm
# ``_load_rw`` reference path:
# https://github.com/ai-dynamo/dynamo/blob/main/lib/gpu_memory_service/integrations/trtllm/model_loader.py
gms_backend.move_untracked_params(model)
torch.cuda.empty_cache()

# Commit weights for RO readers and transition to RO mode.
# finalize_write walks the full model tree (including
# model.draft_model when present) so a single commit
# covers both main and draft weights.
gms_backend.finalize_write(model)
logger.info(
"LoadFormat.GMS (RW): loaded and committed "
"weights via %s", checkpoint_loader.checkpoint_format)
else:
# GMS RO path: zero-copy import from existing GMS pool.
# post_load_weights() must run BEFORE materialize so
# that module aliases are set up correctly.
for module in model.modules():
if hasattr(module, 'post_load_weights') and not getattr(
module, '_weights_removed', False):
module.post_load_weights()

gms_backend.materialize_module(model)
# Note: MoE load balancer finalization (lines below) works
# because register functions access module attributes at
# call time, resolving to materialized CUDA tensors.
# The patch_empty_cache() applied during connect() is
# critical here — MoE's make_tensor_host_accessible()
# calls torch.cuda.empty_cache() which would segfault
# on VMM-backed GMS tensors without the patch.
logger.info("LoadFormat.GMS (RO): zero-copy materialized "
"weights")

# Store backend reference for cleanup during shutdown.
self._gms_backend = gms_backend

⚠️ Potential issue | 🟠 Major

Unwind the GMS connection on partial-load failures.

connect() succeeds at Line 512, but self._gms_backend is not assigned until Line 609. If load_weights(), the draft-weight path, finalize_write(), or materialize_module() raises first, the live GMS handle and its CUDA hooks leak, and cleanup() cannot recover because _gms_backend is still None. Store the backend immediately after a successful connect() and clean it up in a try/except before re-raising.

Possible fix
                 if not gms_backend.connect():
                     raise RuntimeError("Failed to connect to GMS at "
                                        f"{self.llm_args.gms_socket_path}")
+                self._gms_backend = gms_backend
 
-                if gms_backend.is_rw:
-                    # GMS RW path: load via checkpoint_loader under GMS
-                    # memory pool so allocations go into shared memory.
-                    # Note: MX P2P is NOT used in GMS RW mode because
-                    # parameters are still meta tensors (no CUDA buffers
-                    # for P2P to write into) and allocations must go through
-                    # the GMS pool.  The checkpoint_loader loads from disk,
-                    # with the GMS pool intercepting all CUDA allocations.
-                    device = torch.device('cuda')
+                try:
+                    if gms_backend.is_rw:
+                        # GMS RW path: load via checkpoint_loader under GMS
+                        # memory pool so allocations go into shared memory.
+                        device = torch.device('cuda')
 
-                    with gms_backend.mem_pool_scope(device):
-                        ...
+                        with gms_backend.mem_pool_scope(device):
+                            ...
 
-                    gms_backend.finalize_write(model)
-                    logger.info(
-                        "LoadFormat.GMS (RW): loaded and committed "
-                        "weights via %s", checkpoint_loader.checkpoint_format)
-                else:
-                    ...
-
-                # Store backend reference for cleanup during shutdown.
-                self._gms_backend = gms_backend
+                        gms_backend.finalize_write(model)
+                        logger.info(
+                            "LoadFormat.GMS (RW): loaded and committed "
+                            "weights via %s", checkpoint_loader.checkpoint_format)
+                    else:
+                        ...
+                except Exception:
+                    gms_backend.cleanup()
+                    self._gms_backend = None
+                    raise
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 512 - 610,
Assign self._gms_backend immediately after gms_backend.connect() returns True
and then wrap the subsequent GMS load logic (the RW/RO branches,
finalize_write/materialize_module calls, and any draft-weight handling) in a
try/except that on exception clears self._gms_backend and calls the backend’s
cleanup/disconnect method (e.g., gms_backend.disconnect() or the appropriate
close/cleanup API) before re-raising; this ensures the live GMS handle and CUDA
hooks are released if checkpoint_loader.load_weights, finalize_write,
materialize_module, or other calls fail.

Comment on lines +3833 to 3835
if isinstance(v, int):
return LoadFormat(v)
load_format = v.upper()

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
from enum import Enum
class LoadFormat(Enum):
    AUTO = 0
    DUMMY = 1
    VISION_ONLY = 2
    GMS = 3

print("bool_is_int:", isinstance(True, int))
print("enum_from_true:", LoadFormat(True))
PY

Repository: NVIDIA/TensorRT-LLM

Length of output: 113


🏁 Script executed:

# First, let's examine the convert_load_format method in the file
sed -n '3820,3850p' tensorrt_llm/llmapi/llm_args.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1326


🏁 Script executed:

# Check the mx_preshard_strategy field definition
rg -A 5 'mx_preshard_strategy' tensorrt_llm/llmapi/llm_args.py | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 1126


🏁 Script executed:

# Look for existing test file and tests related to load_format
find . -path '*/tests/unittest/llmapi/test_mx_gms_args.py' -exec cat {} \;

Repository: NVIDIA/TensorRT-LLM

Length of output: 16427


🏁 Script executed:

# Also check if there are other load_format validator tests
rg -A 10 'load_format.*True\|test.*load_format' tests/ --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


Reject booleans before the new integer load_format path.

bool is an int in Python, so a config like load_format: true flows through LoadFormat(v) as LoadFormat.DUMMY (since True == 1) instead of failing fast. This silently switches the loader into random-weight initialization.

Proposed fix
    @field_validator('load_format', mode='before')
    @classmethod
    def convert_load_format(cls, v):
        if isinstance(v, LoadFormat):
            return v
+       if isinstance(v, bool):
+           raise ValueError(f"Invalid LoadFormat: {v}")
        if isinstance(v, int):
            return LoadFormat(v)
+       if not isinstance(v, str):
+           raise ValueError(f"Invalid LoadFormat: {v}")
        load_format = v.upper()
        if load_format not in LoadFormat.__members__:
            raise ValueError(f"Invalid LoadFormat: {v}")
        return LoadFormat[load_format]

Add a regression test in tests/unittest/llmapi/test_mx_gms_args.py for load_format=True.
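A hypothetical version of that regression test; `TorchLlmArgs` and the placeholder `model=` value are assumptions and may need to match whatever the existing tests in `test_mx_gms_args.py` construct:

```python
import pytest
from pydantic import ValidationError

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


def test_load_format_rejects_bool():
    # True is an int subclass, so without an explicit bool check it would be
    # silently coerced to LoadFormat.DUMMY; after the fix it must fail.
    with pytest.raises((ValidationError, ValueError)):
        TorchLlmArgs(model="dummy-model", load_format=True)
```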

Also, change mx_preshard_strategy: str = Field(...) to mx_preshard_strategy: Literal["per_module", "global"] = Field(...) to match the constrained value set, per Pydantic field typing guidelines.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 3833 - 3835, Reject boolean
inputs before treating numbers as LoadFormat: in the load_format validator (the
block that currently does "if isinstance(v, int): return LoadFormat(v)"), add an
explicit check for isinstance(v, bool) and raise a ValidationError/TypeError so
True/False do not get interpreted as integers; keep the existing int path to map
valid integers to LoadFormat. Add a regression test named
test_load_format_rejects_bool in test_mx_gms_args.py that passes
load_format=True and asserts a validation failure. Also change the field
declaration for mx_preshard_strategy from "mx_preshard_strategy: str =
Field(...)" to "mx_preshard_strategy: Literal['per_module','global'] =
Field(...)" so Pydantic enforces the allowed values.

@tensorrt-cicd
Collaborator

PR_Github #45034 [ run ] completed with state SUCCESS. Commit: 06491b6
/LLM/main/L0_MergeRequest_PR pipeline #35341 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chienchunhung
Collaborator Author

Opened another PR (#13531) to extract the MX-only part.

Changing this PR back to draft mode to keep it as a reference; no further code review is needed for this PR.
