Skip to content

Commit 6e5a339

Browse files
authored
[TRTLLM-12127][fix] VisualGen metadata updates (NVIDIA#12862)
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
1 parent 421422f commit 6e5a339

File tree

10 files changed

+95
-19
lines changed

10 files changed

+95
-19
lines changed

tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -45,40 +45,51 @@ class TrtllmAttentionMetadata:
4545
max_batch_size: Initial batch size hint. Will grow automatically if exceeded.
4646
max_seq_len: Initial sequence length hint. Will grow automatically if exceeded.
4747
device: Target device for tensors.
48+
attention_metadata_state: Mutable model-scoped state shared by all
49+
attention layers in one model instance.
4850
"""
4951

5052
def __init__(
5153
self,
5254
max_batch_size: int = 16,
5355
max_seq_len: int = 4096,
5456
device: Optional[torch.device] = None,
57+
attention_metadata_state: Optional[dict] = None,
5558
):
5659
# These are initial hints, not hard limits - capacity grows as needed
5760
self.max_batch_size = max_batch_size
5861
self.max_seq_len = max_seq_len
5962
self.device = device or torch.device("cuda")
63+
if attention_metadata_state is None:
64+
raise ValueError(
65+
"TRTLLM attention requires `attention_metadata_state` to be provided "
66+
"by visual-gen config for model-scoped metadata sharing."
67+
)
68+
self._metadata_state = attention_metadata_state
6069

6170
# Lazily created BaseTrtllmAttentionMetadata
62-
self._metadata: Optional[BaseTrtllmAttentionMetadata] = None
63-
64-
# Track allocated capacity
65-
self._allocated_batch_size = 0
66-
self._allocated_max_seq_len = 0
71+
self._metadata: Optional[BaseTrtllmAttentionMetadata] = self._metadata_state["metadata"]
6772

6873
# Track prepared state
6974
self._cached_seq_lens: Optional[torch.Tensor] = None
7075
self._prepared = False
7176

7277
def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool:
7378
"""Check if we need to create new metadata (capacity change)."""
79+
metadata = self._metadata_state["metadata"]
80+
allocated_batch_size, allocated_max_seq_len = self._metadata_state["capacity"]
7481
return (
75-
self._metadata is None
76-
or batch_size > self._allocated_batch_size
77-
or max_seq_len > self._allocated_max_seq_len
82+
metadata is None
83+
or batch_size > allocated_batch_size
84+
or max_seq_len > allocated_max_seq_len
7885
)
7986

8087
def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
81-
"""Check if we need to call prepare() (seq_lens changed)."""
88+
"""Check if we need to call prepare() (seq_lens changed).
89+
90+
Assumes uniform sequence length per batch; if per-sample lengths vary,
91+
we may need to check the seq_lens tensor instead.
92+
"""
8293
if not self._prepared:
8394
return True
8495
if self._cached_seq_lens is None:
@@ -89,9 +100,9 @@ def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
89100

90101
def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
91102
"""Create new metadata with given capacity."""
92-
# Allocate with some headroom to avoid frequent reallocation
93-
alloc_batch = max(batch_size, self._allocated_batch_size)
94-
alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len)
103+
prev_batch, prev_seq = self._metadata_state["capacity"]
104+
alloc_batch = max(batch_size, prev_batch)
105+
alloc_seq_len = max(max_seq_len, prev_seq)
95106

96107
self._metadata = BaseTrtllmAttentionMetadata(
97108
max_num_requests=alloc_batch,
@@ -102,8 +113,8 @@ def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
102113
runtime_features=AttentionRuntimeFeatures(),
103114
)
104115

105-
self._allocated_batch_size = alloc_batch
106-
self._allocated_max_seq_len = alloc_seq_len
116+
self._metadata_state["metadata"] = self._metadata
117+
self._metadata_state["capacity"] = (alloc_batch, alloc_seq_len)
107118
self._prepared = False # Reset prepare state on new metadata
108119

109120
def prepare(
@@ -116,7 +127,7 @@ def prepare(
116127
117128
Lazy behavior:
118129
- Creates metadata only when capacity needs increase
119-
- Calls prepare() only when seq_lens actually change
130+
- Calls prepare() only when (batch_size, max_seq_len) actually change
120131
"""
121132
if isinstance(seq_lens, int):
122133
seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32)
@@ -127,6 +138,8 @@ def prepare(
127138

128139
if self._needs_new_metadata(batch_size, max_seq_len):
129140
self._create_metadata(batch_size, max_seq_len)
141+
else:
142+
self._metadata = self._metadata_state["metadata"]
130143

131144
if self._needs_prepare(batch_size, seq_lens_tensor):
132145
self._metadata.seq_lens = seq_lens_tensor
@@ -165,6 +178,7 @@ def __init__(
165178
dtype: Optional[torch.dtype] = None,
166179
max_batch_size: int = 16,
167180
max_seq_len: int = 4096,
181+
attention_metadata_state: Optional[dict] = None,
168182
):
169183
num_kv_heads = num_kv_heads or num_heads
170184

@@ -183,6 +197,7 @@ def __init__(
183197
self.metadata = TrtllmAttentionMetadata(
184198
max_batch_size=max_batch_size,
185199
max_seq_len=max_seq_len,
200+
attention_metadata_state=attention_metadata_state,
186201
)
187202

188203
# Needed to work with torch.compile because of the attention metadata

tensorrt_llm/_torch/visual_gen/attention_backend/utils.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626

2727
from tensorrt_llm.models.modeling_utils import QuantConfig
2828

29+
from ..config import AttentionConfig
2930
from .interface import AttentionBackend
3031

3132

@@ -77,6 +78,8 @@ def create_attention(
7778
dtype: Optional[torch.dtype] = None,
7879
max_batch_size: int = 16,
7980
max_seq_len: int = 4096,
81+
attention_config: Optional[AttentionConfig] = None,
82+
attention_metadata_state: Optional[dict] = None,
8083
**kwargs,
8184
) -> AttentionBackend:
8285
"""
@@ -97,13 +100,24 @@ def create_attention(
97100
will automatically reallocate if larger batches are encountered.
98101
max_seq_len: Initial sequence length for metadata pre-allocation. The backend
99102
will automatically reallocate if longer sequences are encountered.
103+
attention_config: Optional AttentionConfig
104+
attention_metadata_state: Optional model-scoped metadata state from
105+
visual-gen config. Required for TRTLLM backend.
100106
**kwargs: Additional backend-specific arguments
101107
102108
Returns:
103109
AttentionBackend instance
104110
"""
105111
attn_cls = get_visual_gen_attention_backend(backend)
106112

113+
if backend.upper() == "TRTLLM":
114+
if attention_metadata_state is None:
115+
raise ValueError(
116+
"TRTLLM backend requires `attention_metadata_state` from "
117+
"DiffusionModelConfig; creation path must not allocate metadata implicitly."
118+
)
119+
kwargs["attention_metadata_state"] = attention_metadata_state
120+
107121
return attn_cls(
108122
layer_idx=layer_idx,
109123
num_heads=num_heads,

tensorrt_llm/_torch/visual_gen/config.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -536,6 +536,11 @@ def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]:
536536
return components
537537

538538

539+
def create_attention_metadata_state() -> Dict[str, Any]:
540+
"""Create model-scoped attention metadata state for TRTLLM visual-gen backend."""
541+
return {"metadata": None, "capacity": (0, 0)}
542+
543+
539544
# =============================================================================
540545
# DiffusionModelConfig - Internal configuration (merged/parsed)
541546
# =============================================================================
@@ -579,6 +584,7 @@ class DiffusionModelConfig(BaseModel):
579584
cuda_graph: CudaGraphConfig = PydanticField(default_factory=CudaGraphConfig)
580585
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
581586
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
587+
attention_metadata_state: Optional[Dict[str, Any]] = None
582588
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
583589
cache: Optional[CacheConfig] = None
584590

@@ -935,6 +941,10 @@ def from_pretrained(
935941

936942
NVFP4LinearMethod.use_tunable_quantize = True
937943

944+
attention_metadata_state = (
945+
create_attention_metadata_state() if attention_cfg.backend == "TRTLLM" else None
946+
)
947+
938948
return cls(
939949
pretrained_config=pretrained_config,
940950
quant_config=quant_config,
@@ -947,6 +957,7 @@ def from_pretrained(
947957
cuda_graph=cuda_graph_cfg,
948958
pipeline=pipeline_cfg,
949959
attention=attention_cfg,
960+
attention_metadata_state=attention_metadata_state,
950961
parallel=parallel_cfg,
951962
cache=cache_cfg,
952963
skip_create_weights_in_init=True,

tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,8 @@ def __init__(
131131
num_kv_heads=self.num_key_value_heads,
132132
quant_config=self.quant_config,
133133
dtype=self.dtype,
134+
attention_config=config.attention,
135+
attention_metadata_state=config.attention_metadata_state,
134136
)
135137
self._has_dual_attn = True
136138

tensorrt_llm/_torch/visual_gen/modules/attention.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,8 @@ def __init__(
9595

9696
self._init_qkv_proj()
9797

98+
attention_metadata_state = getattr(config, "attention_metadata_state", None)
99+
98100
if self.qk_norm:
99101
# "full": norm over all heads combined (e.g. WAN, dim=q_dim)
100102
# "per_head": norm over each head independently (e.g. FLUX, dim=head_dim)
@@ -141,6 +143,8 @@ def __init__(
141143
num_kv_heads=backend_num_kv_heads,
142144
quant_config=self.quant_config,
143145
dtype=self.dtype,
146+
attention_config=config.attention,
147+
attention_metadata_state=attention_metadata_state,
144148
)
145149

146150
# Wrap with parallelism strategies (orthogonal to backend choice)

tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
AttentionConfig,
2828
DiffusionModelConfig,
2929
TorchCompileConfig,
30+
create_attention_metadata_state,
3031
)
3132
from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping
3233
from tensorrt_llm._utils import get_free_port
@@ -152,6 +153,9 @@ def _make_model_config(pretrained_dict, ulysses_size=1, backend="VANILLA"):
152153
attention=AttentionConfig(backend=backend),
153154
visual_gen_mapping=vgm,
154155
cache=None,
156+
attention_metadata_state=(
157+
create_attention_metadata_state() if backend.upper() == "TRTLLM" else None
158+
),
155159
skip_create_weights_in_init=False,
156160
)
157161
config.mapping = vgm.to_llm_mapping()

tests/unittest/_torch/visual_gen/test_attention_integration.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,11 @@
1919
# Flash Attention 4 availability
2020
# ============================================================================
2121
from tensorrt_llm._torch.visual_gen.attention_backend.flash_attn4 import _flash_attn_fwd as _fa4_fwd
22-
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
22+
from tensorrt_llm._torch.visual_gen.config import (
23+
AttentionConfig,
24+
DiffusionModelConfig,
25+
create_attention_metadata_state,
26+
)
2327

2428
# Import new integrated versions
2529
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb
@@ -128,6 +132,9 @@ def create_model_config(
128132
attention=AttentionConfig(backend=attn_backend),
129133
skip_create_weights_in_init=False,
130134
)
135+
config.attention_metadata_state = (
136+
create_attention_metadata_state() if attn_backend == "TRTLLM" else None
137+
)
131138
return config
132139

133140

tests/unittest/_torch/visual_gen/test_attention_perf.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,11 @@
4343
from tensorrt_llm._torch.visual_gen.attention_backend.flash_attn4 import (
4444
_flash_attn_fwd_import_error as _fa4_import_error,
4545
)
46-
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
46+
from tensorrt_llm._torch.visual_gen.config import (
47+
AttentionConfig,
48+
DiffusionModelConfig,
49+
create_attention_metadata_state,
50+
)
4751
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
4852

4953
_flash_attn4_available = _fa4_fwd is not None
@@ -155,6 +159,9 @@ def create_model_config(
155159
attention=AttentionConfig(backend=attn_backend),
156160
skip_create_weights_in_init=False,
157161
)
162+
config.attention_metadata_state = (
163+
create_attention_metadata_state() if attn_backend == "TRTLLM" else None
164+
)
158165
return config
159166

160167

tests/unittest/_torch/visual_gen/test_flux_attention.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,11 @@
2020
import torch
2121
import torch.nn.functional as F
2222

23-
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
23+
from tensorrt_llm._torch.visual_gen.config import (
24+
AttentionConfig,
25+
DiffusionModelConfig,
26+
create_attention_metadata_state,
27+
)
2428
from tensorrt_llm.mapping import Mapping
2529
from tensorrt_llm.models.modeling_utils import QuantConfig
2630

@@ -103,6 +107,7 @@ def test_trtllm_backend_sanity(self):
103107

104108
torch.manual_seed(42)
105109
config = self._create_config("TRTLLM")
110+
config.attention_metadata_state = create_attention_metadata_state()
106111

107112
attn = (
108113
FluxJointAttention(
@@ -175,6 +180,7 @@ def test_backend_equivalence(self):
175180
p.normal_(0, 0.02)
176181

177182
config = self._create_config("TRTLLM")
183+
config.attention_metadata_state = create_attention_metadata_state()
178184
trtllm_attn = (
179185
FluxJointAttention(
180186
hidden_size=dim,

tests/unittest/_torch/visual_gen/test_ltx2_attention.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,11 @@
1616
import torch
1717
import torch.nn.functional as F
1818

19-
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
19+
from tensorrt_llm._torch.visual_gen.config import (
20+
AttentionConfig,
21+
DiffusionModelConfig,
22+
create_attention_metadata_state,
23+
)
2024
from tensorrt_llm.mapping import Mapping
2125
from tensorrt_llm.models.modeling_utils import QuantConfig
2226

@@ -102,6 +106,7 @@ def test_trtllm_self_attention_sanity(self):
102106

103107
torch.manual_seed(42)
104108
config = _create_config("TRTLLM")
109+
config.attention_metadata_state = create_attention_metadata_state()
105110

106111
attn = (
107112
LTX2Attention(
@@ -287,6 +292,7 @@ def test_backend_equivalence(self):
287292

288293
# Create TRTLLM attention and copy weights
289294
config_trtllm = _create_config("TRTLLM")
295+
config_trtllm.attention_metadata_state = create_attention_metadata_state()
290296
trtllm_attn = (
291297
LTX2Attention(
292298
query_dim=query_dim,

0 commit comments

Comments
 (0)