Merged
8 changes: 5 additions & 3 deletions .github/workflows/blossom-ci.yml
@@ -51,6 +51,7 @@ jobs:
"amirkl94",
"amitz-nv",
"amukkara",
"anikaj-eng",
"anish-shanbhag",
"arekay",
"arysef",
@@ -144,8 +145,8 @@ jobs:
"Jackch-NV",
"JadoTu",
"jaedeok-nvidia",
"jdemouth-nvidia",
"janbernloehr",
"jdemouth-nvidia",
"JennyLiu-nv",
"jershi425",
"jgangani",
@@ -172,8 +173,8 @@ jobs:
"katec846",
"Kefeng-Duan",
"KingsleyLiu-NV",
"KrishnanPrash",
"kris1025",
"KrishnanPrash",
"kunlunl",
"kxdc",
"kyleliang-nv",
@@ -303,6 +304,7 @@ jobs:
"tijyojwad",
"timlee0212",
"timothygao8710",
"tingyangk",
"Tom-Zheng",
"tomeras91",
"tongyuantongyu",
@@ -371,8 +373,8 @@ jobs:
"zerollzeng",
"zhanga5",
"zhangcl",
"zhaoyangwang-nvidia",
"ZhanruiSunCh",
"zhaoyangwang-nvidia",
"zhengd-nv",
"zhenhuaw-me",
"zheyuf",
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ build
colored
cuda-python>=13
diffusers>=0.27.0
ftfy
lark
mpi4py
numpy>=2.0.0,<2.4 # numba 0.63.1 requires numpy<2.4
17 changes: 17 additions & 0 deletions tensorrt_llm/_torch/async_llm.py
@@ -36,6 +36,7 @@ def __init__(

super().__init__(*args, **kwargs)
self._async_initialized = False
self._paused = False

async def setup_async(self):
"""Setup the LLM asynchronously."""
@@ -94,6 +95,22 @@ async def collective_rpc(
method, args, kwargs, unique_reply_rank=unique_reply_rank, target_ranks=target_ranks
)

def generate_async(self, *args, **kwargs):
if self._paused:
raise RuntimeError(
"AsyncLLM is paused. Call resume_generation() before submitting new requests."
)
return super().generate_async(*args, **kwargs)

async def pause_generation(self) -> None:
"""Abort all in-flight requests and block new ones until resume_generation() is called."""
self._paused = True
self._executor.abort_all_requests()

async def resume_generation(self) -> None:
"""Allow new generation requests after a pause_generation() call."""
self._paused = False

def __await__(self):
return self.setup_async().__await__()

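A minimal usage sketch of the new pause/resume API. The constructor arguments and prompt-style `generate_async` call are illustrative assumptions; only `pause_generation`, `resume_generation`, and the paused-state `RuntimeError` come from this diff.

```python
import asyncio

from tensorrt_llm._torch.async_llm import AsyncLLM

async def main():
    llm = AsyncLLM(model="/path/to/checkpoint")  # constructor args illustrative
    await llm                        # __await__ delegates to setup_async()

    await llm.pause_generation()     # aborts all in-flight requests
    try:
        llm.generate_async("hello")  # rejected while paused
    except RuntimeError:
        pass

    await llm.resume_generation()    # new requests are accepted again
    llm.generate_async("hello")

asyncio.run(main())
```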
7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -670,10 +670,11 @@ def forward_impl(
attention_mask_data=attention_mask_data,
)
wrapper = metadata.get_ragged_prefill_wrapper(plan_params)
# cuDNN's ragged prefill kernel assumes contiguous NHD tensors.
wrapper.run(
q,
k,
v,
q.contiguous(),
k.contiguous(),
v.contiguous(),
out=output.view(-1, self.num_heads, self.head_dim),
)
return
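As context for the `.contiguous()` calls, a small self-contained demonstration of the layout hazard (shapes illustrative): a transposed view has the right NHD shape but non-contiguous strides, which a kernel expecting packed NHD memory would misread.

```python
import torch

x = torch.randn(8, 128, 64)    # [heads, tokens, dim] (HND) storage
q = x.transpose(0, 1)          # NHD-shaped view, but non-contiguous
assert q.shape == (128, 8, 64) and not q.is_contiguous()
q = q.contiguous()             # packed NHD copy, safe to pass to the kernel
assert q.is_contiguous()
```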
42 changes: 38 additions & 4 deletions tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import multiprocessing
import os
@@ -8,13 +22,14 @@
import safetensors
import torch
import tqdm
from mpi4py import MPI as _MPI

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import (
BaseWeightLoader, ConsumableWeightsDict)
from tensorrt_llm._torch.models.modeling_utils import (
register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
local_mpi_size)
from tensorrt_llm._utils import (ENABLE_MULTI_DEVICE, local_mpi_barrier,
local_mpi_comm, local_mpi_rank, local_mpi_size)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

@@ -26,6 +41,24 @@ class HfWeightLoader(BaseWeightLoader):
Loads weights from SafeTensors/bin/pth files.
"""

@staticmethod
def _get_local_available_host_memory() -> int:
"""Determine the minimum available memory observed on the local node
and distribute it to all local ranks.

Because psutil.virtual_memory().available is just a snapshot in time,
it is possible for the local ranks to get different numbers due to
timing differences. This can lead to disagreement among the local ranks
as to whether prefetch should be enabled, which causes a deadlock,
because the ranks that think prefetch is enabled will wait at a local
mpi barrier indefinitely for the ranks that do not.
"""
available_host_memory = psutil.virtual_memory().available
if ENABLE_MULTI_DEVICE:
return local_mpi_comm().allreduce(available_host_memory,
op=_MPI.MIN)
return available_host_memory

def load_weights(self, checkpoint_dir: str,
mapping: Mapping) -> dict[str, Any]:
weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
@@ -44,8 +77,9 @@ def load_weights(self, checkpoint_dir: str,
# If the layer number is overridden, it indicates that only a subset of layers are loaded.
# Prefetching all layers is unnecessary.
num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
enable_prefetch = prefetch_size < psutil.virtual_memory(
).available * 0.9 and num_layers == 0
enable_prefetch = (prefetch_size
< self._get_local_available_host_memory() * 0.9
and num_layers == 0)
if enable_prefetch:
logger.info(
f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
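To make the deadlock described in the `_get_local_available_host_memory` docstring concrete, a standalone mpi4py sketch of the agreement pattern (the memory numbers are fabricated for illustration):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
# Each rank snapshots "available memory" at a slightly different moment.
local_snapshot = 100 - comm.Get_rank()  # illustrative: ranks see 100, 99, ...

# Without the reduction, a threshold between two snapshots splits the ranks:
# those that enable prefetch wait at the local barrier forever for those
# that do not. Reducing with MIN makes every rank decide identically.
agreed = comm.allreduce(local_snapshot, op=MPI.MIN)
enable_prefetch = agreed >= 99

print(f"rank {comm.Get_rank()}: enable_prefetch={enable_prefetch}")
```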
10 changes: 9 additions & 1 deletion tensorrt_llm/_torch/models/modeling_radio.py
@@ -745,11 +745,19 @@ def __init__(

self.metadata_cls = attention_utils.get_attention_backend(
model_config.attn_backend).Metadata
self.attn_metadata = self.metadata_cls(
metadata_kwargs = dict(
max_num_requests=8192, # TODO: Make this dynamic
max_num_tokens=model_config.max_num_tokens,
kv_cache_manager=None,
)
if model_config.attn_backend == "FLASHINFER":
# FlashInfer's original default kv_layout is "NHD". TRT-LLM changed
# the default to "HND" for paged KV cache paths (see PR #6917).
# For ModelingRadio ragged prefill (kv_cache_manager=None), we
# explicitly use "NHD" because ragged k/v tensors computed directly
# from input are always in NHD format ([tokens, heads, dim]).
metadata_kwargs["kv_layout"] = "NHD"
self.attn_metadata = self.metadata_cls(**metadata_kwargs)

def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int],
attn_metadata: AttentionMetadata):
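For reference, the two KV layouts named in the comment above, with illustrative shapes:

```python
import torch

tokens, heads, dim = 128, 8, 64
kv_nhd = torch.empty(tokens, heads, dim)  # NHD: [tokens, heads, dim]
kv_hnd = kv_nhd.permute(1, 0, 2)          # HND: [heads, tokens, dim]
assert kv_hnd.shape == (heads, tokens, dim)
```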
45 changes: 30 additions & 15 deletions tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
@@ -45,40 +45,51 @@ class TrtllmAttentionMetadata:
max_batch_size: Initial batch size hint. Will grow automatically if exceeded.
max_seq_len: Initial sequence length hint. Will grow automatically if exceeded.
device: Target device for tensors.
attention_metadata_state: Mutable model-scoped state shared by all
attention layers in one model instance.
"""

def __init__(
self,
max_batch_size: int = 16,
max_seq_len: int = 4096,
device: Optional[torch.device] = None,
attention_metadata_state: Optional[dict] = None,
):
# These are initial hints, not hard limits - capacity grows as needed
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.device = device or torch.device("cuda")
if attention_metadata_state is None:
raise ValueError(
"TRTLLM attention requires `attention_metadata_state` to be provided "
"by visual-gen config for model-scoped metadata sharing."
)
self._metadata_state = attention_metadata_state

# Lazily created BaseTrtllmAttentionMetadata
self._metadata: Optional[BaseTrtllmAttentionMetadata] = None

# Track allocated capacity
self._allocated_batch_size = 0
self._allocated_max_seq_len = 0
self._metadata: Optional[BaseTrtllmAttentionMetadata] = self._metadata_state["metadata"]

# Track prepared state
self._cached_seq_lens: Optional[torch.Tensor] = None
self._prepared = False

def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool:
"""Check if we need to create new metadata (capacity change)."""
metadata = self._metadata_state["metadata"]
allocated_batch_size, allocated_max_seq_len = self._metadata_state["capacity"]
return (
self._metadata is None
or batch_size > self._allocated_batch_size
or max_seq_len > self._allocated_max_seq_len
metadata is None
or batch_size > allocated_batch_size
or max_seq_len > allocated_max_seq_len
)

def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
"""Check if we need to call prepare() (seq_lens changed)."""
"""Check if we need to call prepare() (seq_lens changed).

Assumes a uniform sequence length per batch; if per-sample lengths vary,
we may need to compare the full seq_lens tensor instead.
"""
if not self._prepared:
return True
if self._cached_seq_lens is None:
@@ -89,9 +100,9 @@ def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:

def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
"""Create new metadata with given capacity."""
# Allocate with some headroom to avoid frequent reallocation
alloc_batch = max(batch_size, self._allocated_batch_size)
alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len)
prev_batch, prev_seq = self._metadata_state["capacity"]
alloc_batch = max(batch_size, prev_batch)
alloc_seq_len = max(max_seq_len, prev_seq)

self._metadata = BaseTrtllmAttentionMetadata(
max_num_requests=alloc_batch,
@@ -102,8 +113,8 @@
runtime_features=AttentionRuntimeFeatures(),
)

self._allocated_batch_size = alloc_batch
self._allocated_max_seq_len = alloc_seq_len
self._metadata_state["metadata"] = self._metadata
self._metadata_state["capacity"] = (alloc_batch, alloc_seq_len)
self._prepared = False # Reset prepare state on new metadata

def prepare(
@@ -116,7 +127,7 @@ def prepare(

Lazy behavior:
- Creates metadata only when the required capacity increases
- Calls prepare() only when seq_lens actually change
- Calls prepare() only when (batch_size, max_seq_len) actually change
"""
if isinstance(seq_lens, int):
seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32)
@@ -127,6 +138,8 @@

if self._needs_new_metadata(batch_size, max_seq_len):
self._create_metadata(batch_size, max_seq_len)
else:
self._metadata = self._metadata_state["metadata"]

if self._needs_prepare(batch_size, seq_lens_tensor):
self._metadata.seq_lens = seq_lens_tensor
@@ -165,6 +178,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
max_batch_size: int = 16,
max_seq_len: int = 4096,
attention_metadata_state: Optional[dict] = None,
):
num_kv_heads = num_kv_heads or num_heads

@@ -183,6 +197,7 @@
self.metadata = TrtllmAttentionMetadata(
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
attention_metadata_state=attention_metadata_state,
)

# Needed to work with torch compile cause of attention metadata
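A pure-Python sketch of the model-scoped sharing contract, with `object()` standing in for `BaseTrtllmAttentionMetadata` (the dict shape matches `create_attention_metadata_state()` in config.py below):

```python
state = {"metadata": None, "capacity": (0, 0)}

def ensure_capacity(state: dict, batch_size: int, seq_len: int):
    """Publish a (re)allocation into the shared dict only when capacity grows."""
    cap_bs, cap_sl = state["capacity"]
    if state["metadata"] is None or batch_size > cap_bs or seq_len > cap_sl:
        state["metadata"] = object()  # placeholder for the real metadata object
        state["capacity"] = (max(batch_size, cap_bs), max(seq_len, cap_sl))
    return state["metadata"]

# Two layers driving the same state reuse a single allocation:
assert ensure_capacity(state, 8, 1024) is ensure_capacity(state, 8, 1024)
```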
14 changes: 14 additions & 0 deletions tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
@@ -26,6 +26,7 @@

from tensorrt_llm.models.modeling_utils import QuantConfig

from ..config import AttentionConfig
from .interface import AttentionBackend


@@ -77,6 +78,8 @@ def create_attention(
dtype: Optional[torch.dtype] = None,
max_batch_size: int = 16,
max_seq_len: int = 4096,
attention_config: Optional[AttentionConfig] = None,
attention_metadata_state: Optional[dict] = None,
**kwargs,
) -> AttentionBackend:
"""
@@ -97,13 +100,24 @@
will automatically reallocate if larger batches are encountered.
max_seq_len: Initial sequence length for metadata pre-allocation. The backend
will automatically reallocate if longer sequences are encountered.
attention_config: Optional AttentionConfig carrying backend-specific attention settings
attention_metadata_state: Optional model-scoped metadata state from
visual-gen config. Required for TRTLLM backend.
**kwargs: Additional backend-specific arguments

Returns:
AttentionBackend instance
"""
attn_cls = get_visual_gen_attention_backend(backend)

if backend.upper() == "TRTLLM":
if attention_metadata_state is None:
raise ValueError(
"TRTLLM backend requires `attention_metadata_state` from "
"DiffusionModelConfig; creation path must not allocate metadata implicitly."
)
kwargs["attention_metadata_state"] = attention_metadata_state

return attn_cls(
layer_idx=layer_idx,
num_heads=num_heads,
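A hedged usage sketch of the new contract. `backend`, `layer_idx`, `num_heads`, `attention_metadata_state`, and the `ValueError` are confirmed by this diff; `head_dim` and the argument values are assumptions.

```python
from tensorrt_llm._torch.visual_gen.attention_backend.utils import create_attention
from tensorrt_llm._torch.visual_gen.config import create_attention_metadata_state

state = create_attention_metadata_state()  # {"metadata": None, "capacity": (0, 0)}
attn = create_attention(
    backend="TRTLLM",
    layer_idx=0,
    num_heads=16,
    head_dim=64,                     # assumed parameter name
    attention_metadata_state=state,  # omitting this raises ValueError for TRTLLM
)
```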
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/visual_gen/config.py
@@ -536,6 +536,11 @@ def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]:
return components


def create_attention_metadata_state() -> Dict[str, Any]:
"""Create model-scoped attention metadata state for TRTLLM visual-gen backend."""
return {"metadata": None, "capacity": (0, 0)}


# =============================================================================
# DiffusionModelConfig - Internal configuration (merged/parsed)
# =============================================================================
@@ -579,6 +584,7 @@ class DiffusionModelConfig(BaseModel):
cuda_graph: CudaGraphConfig = PydanticField(default_factory=CudaGraphConfig)
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
attention_metadata_state: Optional[Dict[str, Any]] = None
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
cache: Optional[CacheConfig] = None

@@ -935,6 +941,10 @@ def from_pretrained(

NVFP4LinearMethod.use_tunable_quantize = True

attention_metadata_state = (
create_attention_metadata_state() if attention_cfg.backend == "TRTLLM" else None
)

return cls(
pretrained_config=pretrained_config,
quant_config=quant_config,
Expand All @@ -947,6 +957,7 @@ def from_pretrained(
cuda_graph=cuda_graph_cfg,
pipeline=pipeline_cfg,
attention=attention_cfg,
attention_metadata_state=attention_metadata_state,
parallel=parallel_cfg,
cache=cache_cfg,
skip_create_weights_in_init=True,