Commit 1b7b013
[TRTLLM-11851][feat] MX adapter improvements: env-var fallback, query timeout, model_name plumbing
Three discrete improvements to the MX side of PR NVIDIA#13045, driven by review feedback from the MX team's downstream PR (chienchunhung/TensorRT-LLM #1). Three orchestration-ergonomics fixes land as one focused commit so reviewers see them as a clean slice on top of the prototype.

(1) MODEL_EXPRESS_URL env-var fallback. At the validator level, TorchLlmArgs.validate_mx_config now honors the upstream ``MODEL_EXPRESS_URL`` env var when ``checkpoint_format='MX'`` and ``mx_server_url`` is unset. Resolution happens at validator time so the value ends up on ``llm_args.mx_server_url`` (visible to logging, /startup_metrics, and downstream code) instead of being silently re-read from the environment by the loader. This lets orchestrators (e.g. Dynamo) configure MX via the environment without plumbing every CLI knob, while keeping resolution in one place. An explicit ``mx_server_url=`` always wins. The env-var fallback only fires when MX is the active checkpoint format (so HF-only configs aren't surprised by an unrelated env var), and an empty string in the environment is treated as unset.

(2) MX_SOURCE_QUERY_TIMEOUT defensive default. MXCheckpointLoader.__init__ calls ``os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")`` whenever an MX server URL is configured. This caps cold-cluster first-replica startup at 30 s instead of upstream's 1-hour default (the polling in MxLiveWeightLoader._query_source). setdefault semantics preserve any explicit user value, and HF-only loads (no MX URL) don't touch the environment at all. The proper upstream-side fix is a non-blocking source-query API (tracked as MX-4 in §15 of the design doc); this defensive default caps the worst case until that lands.

(3) model_name plumbing with an HF-snapshot-aware resolver. Plumbs ``llm_args.model → MXCheckpointLoader(model_name=...)`` so upstream's ``publish_model_params()`` publishes under the user-supplied Hub ID (e.g. "Qwen/Qwen2.5-72B-Instruct") instead of the "unknown" sentinel.
- MXCheckpointLoader takes a new optional ``model_name`` constructor arg (Union[str, Path]), coerced to str at construction time.
- publish_as_source() now sets BOTH the MODEL_EXPRESS_URL and MODEL_NAME env vars (resolving identity via the priority order below) and restores both in finally. publish_model_params() reads them from the environment, as documented.
- Identity resolution order: explicit constructor arg → MODEL_NAME env → checkpoint_dir basename (with HF-snapshot path unmangling) → "unknown".
- The HF cache layout (".../models--<org>--<name>/snapshots/<sha>/") is unmangled back to "<org>/<name>" instead of returning the commit hash.
- _construct_checkpoint_loader plumbs ``mx_model_name`` through; py_executor_creator.py extracts it from llm_args.model.

Both env-var dances (MODEL_EXPRESS_URL + MODEL_NAME) collapse into one direct call when MX-2 (public build_identity) lands upstream. Tests for these three additions are in the next commit.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
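The setdefault semantics described in point (2) can be exercised standalone; this is a minimal illustration of the stdlib behavior the loader relies on, not the loader code itself:

```python
import os

# Cold start: the variable is unset, so the defensive 30 s default applies.
os.environ.pop("MX_SOURCE_QUERY_TIMEOUT", None)
os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")
print(os.environ["MX_SOURCE_QUERY_TIMEOUT"])  # 30

# An operator's explicit value survives: setdefault never overwrites.
os.environ["MX_SOURCE_QUERY_TIMEOUT"] = "120"
os.environ.setdefault("MX_SOURCE_QUERY_TIMEOUT", "30")
print(os.environ["MX_SOURCE_QUERY_TIMEOUT"])  # 120
```

This is why an explicit user value always wins without any extra branching in the loader.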
1 parent 8ecfa78 commit 1b7b013

4 files changed (153 additions & 17 deletions)

tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py

Lines changed: 126 additions & 15 deletions
@@ -28,7 +28,9 @@
 ``HfCheckpointLoader`` base class.
 """

-from typing import Any, Optional
+import os
+from pathlib import Path
+from typing import Any, Optional, Union

 from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
 from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
@@ -38,6 +40,16 @@
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

+# Defensive default for the upstream ``MX_SOURCE_QUERY_TIMEOUT`` env var.
+# The upstream ``MxLiveWeightLoader`` polls the MX server every 5 s for up
+# to ``MX_SOURCE_QUERY_TIMEOUT`` seconds (default 3600 = 1 hour) waiting
+# for a source. On a cold cluster (no donor up yet), this means the very
+# first replica blocks for an hour before falling back to disk. We cap
+# the default at 30 s so first-replica startup degrades gracefully; users
+# can still override via the env var or a future per-loader knob.
+# Tracked as MX-4 in §15 (non-blocking source-query API upstream).
+_MX_SOURCE_QUERY_TIMEOUT_DEFAULT_S = "30"
+

 @register_checkpoint_loader("MX")
 class MXCheckpointLoader(HfCheckpointLoader):
@@ -68,6 +80,7 @@ def __init__(
         weight_mapper: Optional[BaseWeightMapper] = None,
         config_loader: Optional[BaseConfigLoader] = None,
         mx_server_url: Optional[str] = None,
+        model_name: Optional[Union[str, Path]] = None,
     ):
         super().__init__(
             weight_loader=weight_loader,
@@ -78,8 +91,24 @@ def __init__(
         # caller reading self._checkpoint_format directly also sees "MX".
         self._checkpoint_format = "MX"
         self._mx_server_url = mx_server_url
+        # ``model_name`` is the human-readable identity to publish/look up
+        # under on the MX server. Typically the user-supplied
+        # ``llm_args.model`` (a Hub ID like ``"Qwen/Qwen2.5-72B-Instruct"``
+        # or a local path). ``publish_as_source()`` resolves it via
+        # :func:`_resolve_mx_model_name` (with HF-snapshot path fallback).
+        self._model_name = str(model_name) if model_name is not None else None
         self._p2p_succeeded = False

+        # Defensive default for upstream's source-query timeout. Only
+        # applied when an MX server URL is configured (so HF-only loads
+        # are unaffected). Uses ``setdefault`` so an explicit user value
+        # always wins.
+        if mx_server_url is not None:
+            os.environ.setdefault(
+                "MX_SOURCE_QUERY_TIMEOUT",
+                _MX_SOURCE_QUERY_TIMEOUT_DEFAULT_S,
+            )
+
     @property
     def checkpoint_format(self) -> str:
         """Override parent's checkpoint_format to return 'MX'."""
@@ -89,6 +118,17 @@ def checkpoint_format(self) -> str:
     def mx_server_url(self) -> Optional[str]:
         return self._mx_server_url

+    @property
+    def model_name(self) -> Optional[str]:
+        """Explicit model identity passed to the constructor (if any).
+
+        Note this is the *as-configured* value (e.g. ``llm_args.model``),
+        not the final resolved identity that ends up in the published
+        ``MODEL_NAME``. The full resolution (with env var and basename
+        fallbacks) happens inside :meth:`publish_as_source`.
+        """
+        return self._model_name
+
     @property
     def p2p_succeeded(self) -> bool:
         """Whether the last load_weights() call used P2P transfer.
@@ -221,11 +261,12 @@ def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str
             mapping: Distributed mapping. Currently unused — kept for
                 signature symmetry with the prior prototype API and for
                 forward-compat with future upstream signatures.
-            checkpoint_dir: Checkpoint directory. Currently unused —
-                upstream uses the ``MODEL_NAME`` env var for identity.
+            checkpoint_dir: Checkpoint directory. Used as a last-resort
+                fallback for resolving the ``MODEL_NAME`` identity when
+                neither ``model_name`` was passed to the constructor nor
+                ``MODEL_NAME`` is set in the environment.
         """
-        # mapping/checkpoint_dir are deliberately unused; see docstring.
-        del mapping, checkpoint_dir
+        del mapping  # currently unused; see docstring.

         if self._mx_server_url is None:
             return
@@ -238,17 +279,29 @@ def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str
             logger.debug("modelexpress library not installed; skipping MX publish.")
             return

-        # Upstream publish_model_params reads MODEL_EXPRESS_URL from env;
-        # set it from our config so the per-server URL is respected.
-        import os
+        # Upstream publish_model_params reads MODEL_EXPRESS_URL and
+        # MODEL_NAME from the environment. Set both from our resolved
+        # configuration so per-instance values (URL passed via
+        # llm_args.mx_server_url, identity from llm_args.model) are
+        # respected, then restore prior state. Tracked as MX-2 in §15
+        # (the env-var dance goes away when upstream exports a public
+        # ``build_identity()`` we can call directly).
+        resolved_name = _resolve_mx_model_name(self._model_name, checkpoint_dir)
+
+        env_overrides = {
+            "MODEL_EXPRESS_URL": self._mx_server_url,
+            "MODEL_NAME": resolved_name,
+        }
+        prior = {key: os.environ.get(key) for key in env_overrides}
+        for key, value in env_overrides.items():
+            os.environ[key] = value

-        prior_url = os.environ.get("MODEL_EXPRESS_URL")
-        os.environ["MODEL_EXPRESS_URL"] = self._mx_server_url
         try:
             publish_model_params(model)
             logger.info(
-                "Published weights to MX server at %s",
+                "Published weights to MX server at %s as model=%r",
                 self._mx_server_url,
+                resolved_name,
             )
         except Exception as e:
             logger.warning(
@@ -257,7 +310,65 @@ def publish_as_source(self, model, mapping: Mapping = None, checkpoint_dir: str
                 e,
             )
         finally:
-            if prior_url is None:
-                os.environ.pop("MODEL_EXPRESS_URL", None)
-            else:
-                os.environ["MODEL_EXPRESS_URL"] = prior_url
+            for key, prior_value in prior.items():
+                if prior_value is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = prior_value
+
+
+# ---------------------------------------------------------------------------
+# Module-level helpers
+# ---------------------------------------------------------------------------
+
+
+def _resolve_mx_model_name(model_name_arg: Optional[str], checkpoint_dir: Optional[str]) -> str:
+    """Resolve a stable model identity for publishing to the MX server.
+
+    Resolution order (first non-empty wins):
+
+    1. ``model_name_arg`` — the explicit value passed at construction
+       time (typically ``llm_args.model``: a Hub ID like
+       ``"Qwen/Qwen2.5-72B-Instruct"`` or a local path).
+    2. ``MODEL_NAME`` env var — upstream's existing convention.
+    3. ``checkpoint_dir`` basename, with HF-snapshot path fallback so
+       ``.../models--<org>--<name>/snapshots/<sha>/`` resolves to
+       ``"<org>/<name>"`` instead of the commit hash.
+    4. Literal ``"unknown"`` — matches upstream's own sentinel.
+    """
+    candidate = model_name_arg or os.environ.get("MODEL_NAME") or checkpoint_dir
+    if not candidate:
+        return "unknown"
+    return _normalize_model_identity(str(candidate))
+
+
+def _normalize_model_identity(s: str) -> str:
+    """Convert a model identifier to a stable, human-readable name.
+
+    Hub IDs (``"org/name"``) and arbitrary user-provided strings are
+    returned unchanged. Filesystem paths are reduced to a basename, with
+    HuggingFace cache snapshot layouts (``snapshots/<commit-sha>/``)
+    walked up to recover the original ``"org/name"`` identity.
+    """
+    if not s:
+        return "unknown"
+
+    # Heuristic: a Hub ID is bare ``"name"`` or ``"org/name"``. Anything
+    # that starts with a path separator/expansion or contains more than
+    # one "/" is treated as a path. Single-"/" strings remain ambiguous;
+    # we side with the Hub ID interpretation unless the path also exists
+    # on disk (in which case we assume the user gave us a real path).
+    looks_like_path = s.startswith(("/", "./", "../", "~")) or s.count("/") > 1 or os.path.exists(s)
+    if not looks_like_path:
+        return s
+
+    p = Path(s).expanduser()
+    name = p.name
+    if name and "snapshots" in p.parts:
+        # HF cache layout: ``.../models--<org>--<name>/snapshots/<sha>/``.
+        # Walk up to find the ``models--<org>--<name>`` directory and
+        # un-mangle it back to ``"<org>/<name>"``.
+        for ancestor in p.parents:
+            if ancestor.name.startswith("models--"):
+                return ancestor.name[len("models--") :].replace("--", "/")
+    return name or "unknown"
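The resolver's behavior can be checked standalone. Below is a self-contained replica of the ``_normalize_model_identity`` logic from the diff above, for illustration only; the cache path in the example is hypothetical:

```python
import os
from pathlib import Path

def normalize_model_identity(s: str) -> str:
    """Replica of the diff's identity normalizer: Hub IDs pass through,
    paths reduce to a basename, HF cache snapshots un-mangle to org/name."""
    if not s:
        return "unknown"
    # Same heuristic as the diff: path-like prefixes, >1 slash, or an
    # existing on-disk path mean "treat as filesystem path".
    looks_like_path = s.startswith(("/", "./", "../", "~")) or s.count("/") > 1 or os.path.exists(s)
    if not looks_like_path:
        return s
    p = Path(s).expanduser()
    name = p.name
    if name and "snapshots" in p.parts:
        for ancestor in p.parents:
            if ancestor.name.startswith("models--"):
                return ancestor.name[len("models--"):].replace("--", "/")
    return name or "unknown"

print(normalize_model_identity("Qwen/Qwen2.5-72B-Instruct"))
# -> "Qwen/Qwen2.5-72B-Instruct" (Hub ID unchanged)
print(normalize_model_identity(
    "/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/abc123"))
# -> "Qwen/Qwen2.5-72B-Instruct" (cache layout un-mangled, not "abc123")
print(normalize_model_identity("/models/llama-3-8b"))
# -> "llama-3-8b" (plain path reduces to basename)
```

This is the behavior the commit message calls "HF-snapshot path unmangling": without the ``models--`` walk-up, the published identity would be the commit SHA.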

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 6 additions & 2 deletions
@@ -171,6 +171,7 @@ def _construct_checkpoint_loader(
     checkpoint_format: Optional[str],
     *,
     mx_server_url: Optional[str] = None,
+    mx_model_name: Optional[str] = None,
 ) -> Optional[BaseCheckpointLoader]:
     if backend == "_autodeploy":
         return None
@@ -187,8 +188,11 @@ def _construct_checkpoint_loader(

     # Pass extra kwargs for format-specific loaders (e.g. MX).
     extra_kwargs: dict = {}
-    if checkpoint_format == "MX" and mx_server_url is not None:
-        extra_kwargs["mx_server_url"] = mx_server_url
+    if checkpoint_format == "MX":
+        if mx_server_url is not None:
+            extra_kwargs["mx_server_url"] = mx_server_url
+        if mx_model_name is not None:
+            extra_kwargs["model_name"] = mx_model_name

     checkpoint_loader = BaseCheckpointLoader.get(
         checkpoint_format=checkpoint_format,

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 0 deletions
@@ -252,11 +252,16 @@ def create_py_executor(
     skip_est = os.environ.get("TRTLLM_SKIP_KV_CACHE_ESTIMATION", '0') == '1'
     torch.cuda.set_per_process_memory_fraction(1.0)
     # Apply model-specific defaults early, before destructuring llm_args fields
+    # Pass llm_args.model through to MXCheckpointLoader so it can publish
+    # to the MX server under the user-supplied identity (Hub ID or local
+    # path basename) instead of defaulting to "unknown".
     checkpoint_loader = _construct_checkpoint_loader(
         llm_args.backend,
         llm_args.checkpoint_loader,
         llm_args.checkpoint_format,
         mx_server_url=llm_args.mx_server_url,
+        mx_model_name=str(llm_args.model)
+        if llm_args.model is not None else None,
     )
     llm_args = ModelLoader.load_config_and_apply_defaults(
         checkpoint_dir, llm_args, checkpoint_loader)

tensorrt_llm/llmapi/llm_args.py

Lines changed: 16 additions & 0 deletions
@@ -3940,6 +3940,22 @@ def validate_checkpoint_format(self):

     @model_validator(mode="after")
     def validate_mx_config(self) -> 'TorchLlmArgs':
+        # When MX is the active checkpoint format and the user did not
+        # explicitly set ``mx_server_url``, honor the ``MODEL_EXPRESS_URL``
+        # env var that the upstream ``modelexpress`` library reads
+        # (see ``modelexpress.client._get_server_url``). This lets
+        # orchestrators (e.g. Dynamo) configure MX via the environment
+        # without plumbing every CLI knob through, while keeping the
+        # resolved value visible on ``llm_args.mx_server_url`` for
+        # logging, ``/startup_metrics``, and downstream code paths.
+        if (self.checkpoint_format == "MX" and self.mx_server_url is None):
+            env_url = os.environ.get("MODEL_EXPRESS_URL")
+            if env_url:
+                logger.info(
+                    "mx_server_url not set; using MODEL_EXPRESS_URL=%s "
+                    "from environment.", env_url)
+                self.mx_server_url = env_url
+
         if self.mx_server_url is not None and self.checkpoint_format != "MX":
             logger.warning(
                 "mx_server_url is set but checkpoint_format is '%s', not "
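The precedence this validator implements can be sketched as a standalone function (an illustration under assumed names, not the actual validator; the URL is hypothetical):

```python
import os
from typing import Optional

def resolve_mx_server_url(checkpoint_format: Optional[str],
                          mx_server_url: Optional[str]) -> Optional[str]:
    """Explicit mx_server_url always wins; the MODEL_EXPRESS_URL env
    fallback fires only when MX is the active checkpoint format, and an
    empty env value counts as unset."""
    if checkpoint_format == "MX" and mx_server_url is None:
        env_url = os.environ.get("MODEL_EXPRESS_URL")
        if env_url:  # empty string is falsy, so "" is treated as unset
            return env_url
    return mx_server_url

os.environ["MODEL_EXPRESS_URL"] = "grpc://mx.internal:9001"  # hypothetical
print(resolve_mx_server_url("MX", None))          # env fallback fires
print(resolve_mx_server_url("MX", "grpc://a:1"))  # explicit value wins
print(resolve_mx_server_url("HF", None))          # non-MX config untouched
```

Resolving at validator time (rather than inside the loader) is what makes the final value visible on ``llm_args.mx_server_url`` for logging and metrics.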
