Skip to content

Commit 2a0bcb1

Browse files
[TRTLLM-11794][feat] Optimize ViT Attention kernel on Nemotron (NVIDIA#12911)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 91c1c14 commit 2a0bcb1

File tree

7 files changed

+305
-21
lines changed

7 files changed

+305
-21
lines changed

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 171 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import weakref
44
from dataclasses import dataclass, field
5-
from typing import Dict, Literal, Optional
5+
from typing import Any, Dict, Literal, Optional
66

77
import flashinfer
88
import torch
@@ -12,6 +12,7 @@
1212
from tensorrt_llm.functional import AttentionMaskType
1313
from tensorrt_llm.models.modeling_utils import QuantConfig
1414

15+
from ..metadata import KVCacheParams
1516
from ..utils import get_global_attrs, get_model_extra_attrs
1617
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
1718
CustomAttentionMask, PredefinedAttentionMask)
@@ -25,6 +26,8 @@
2526
arch_list = f"{capability[0]}.{capability[1]}"
2627
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
2728

29+
from tensorrt_llm._utils import prefer_pinned
30+
2831

2932
@dataclass(kw_only=True, frozen=True)
3033
class PlanParams:
@@ -46,10 +49,13 @@ class PlanParams:
4649

4750
@dataclass(kw_only=True)
4851
class FlashInferWrappers:
49-
decode_wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper
50-
prefill_wrapper: Optional[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
51-
5252
is_planned: bool
53+
decode_wrapper: Optional[
54+
flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None
55+
prefill_wrapper: Optional[
56+
flashinfer.BatchPrefillWithPagedKVCacheWrapper] = None
57+
ragged_prefill_wrapper: Optional[
58+
flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = None
5359

5460

5561
@dataclass(kw_only=True)
@@ -94,6 +100,15 @@ def get_decode_wrapper(
94100
result = self._plan_params_to_wrappers[plan_params].decode_wrapper
95101
return result
96102

103+
def get_ragged_prefill_wrapper(
104+
self, plan_params: PlanParams
105+
) -> flashinfer.BatchPrefillWithRaggedKVCacheWrapper:
106+
assert plan_params in self._plan_params_to_wrappers, "Plan params not found, make sure to call plan()"
107+
result = self._plan_params_to_wrappers[
108+
plan_params].ragged_prefill_wrapper
109+
assert result is not None, "Ragged prefill wrapper was not created in plan()"
110+
return result
111+
97112
@property
98113
def paged_kv_indices(self) -> torch.Tensor:
99114
return self._paged_kv_indices[:self.num_generation_blocks +
@@ -201,8 +216,81 @@ def page_size(self) -> int:
201216
"""
202217
Number of tokens per cache page
203218
"""
219+
assert self.kv_cache_manager is not None, (
220+
"page_size is undefined without a KV cache manager; use the "
221+
"ragged prefill path instead.")
204222
return self.kv_cache_manager.tokens_per_block
205223

224+
def _plan_ragged_cudnn_no_kv(
225+
self,
226+
plan_params: PlanParams,
227+
ragged_prefill_wrapper: Any,
228+
) -> None:
229+
is_causal = plan_params.attention_mask_type == AttentionMaskType.causal
230+
if plan_params.attention_mask_data is not None:
231+
window_left = -1
232+
else:
233+
window_left = plan_params.window_left
234+
235+
# Lengths are already on GPU via AttentionMetadata (seq_lens setter -> _seq_lens_cuda).
236+
assert self.seq_lens_cuda is not None
237+
assert self.seq_lens is not None
238+
239+
# NOTE: When kv_cache_manager is None (e.g. ViT), ragged prefill runs only for the context phase.
240+
# Restrict seq_lens to the first num_contexts entries accordingly.
241+
q_seqlens = self.seq_lens[:self.num_contexts]
242+
kv_seqlens = q_seqlens
243+
244+
max_query_tokens_per_sequence = int(
245+
self.seq_lens[:self.num_contexts].max().item())
246+
max_key_value_tokens_per_sequence = max_query_tokens_per_sequence
247+
248+
# cuDNN ragged prefill uses *element* offsets in qo/kv indptr, not token indptr.
249+
num_context_sequences = int(q_seqlens.shape[0])
250+
query_output_element_indptr = torch.zeros(
251+
num_context_sequences + 1,
252+
dtype=torch.int32,
253+
pin_memory=prefer_pinned(),
254+
)
255+
key_value_element_indptr = torch.zeros(
256+
num_context_sequences + 1,
257+
dtype=torch.int32,
258+
pin_memory=prefer_pinned(),
259+
)
260+
if num_context_sequences > 0:
261+
num_query_output_heads = plan_params.num_heads
262+
num_key_value_heads = plan_params.num_kv_heads
263+
attention_head_dim = plan_params.head_dim
264+
query_output_element_indptr[1:].copy_(
265+
torch.cumsum(q_seqlens, dim=0).mul_(num_query_output_heads *
266+
attention_head_dim))
267+
key_value_element_indptr[1:].copy_(
268+
torch.cumsum(kv_seqlens, dim=0).mul_(num_key_value_heads *
269+
attention_head_dim))
270+
271+
q_seqlens_cuda = self.seq_lens_cuda[:self.num_contexts]
272+
kv_seqlens_cuda = q_seqlens_cuda[:self.num_contexts]
273+
274+
ragged_prefill_wrapper.plan(
275+
qo_indptr=query_output_element_indptr,
276+
kv_indptr=key_value_element_indptr,
277+
num_qo_heads=plan_params.num_heads,
278+
num_kv_heads=plan_params.num_kv_heads,
279+
head_dim_qk=plan_params.head_dim,
280+
custom_mask=plan_params.attention_mask_data,
281+
causal=is_causal,
282+
sm_scale=plan_params.sm_scale,
283+
window_left=window_left,
284+
q_data_type=plan_params.q_dtype,
285+
kv_data_type=plan_params.kv_dtype,
286+
seq_lens=kv_seqlens_cuda,
287+
seq_lens_q=q_seqlens_cuda,
288+
max_token_per_sequence=max_query_tokens_per_sequence,
289+
max_sequence_kv=max_key_value_tokens_per_sequence,
290+
v_indptr=key_value_element_indptr,
291+
o_indptr=query_output_element_indptr,
292+
)
293+
206294
def prepare(self) -> None:
207295
super().prepare()
208296
extra_attrs = get_model_extra_attrs()
@@ -214,6 +302,27 @@ def prepare(self) -> None:
214302
dtype=torch.int32,
215303
out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1])
216304

305+
if self.kv_cache_manager is None:
306+
assert self.request_ids is not None
307+
assert self.num_generations == 0, (
308+
"FlashInfer without a KV cache manager only supports context-only "
309+
"batches (num_generations == 0) in TRT-LLM.")
310+
if self.is_cross:
311+
raise NotImplementedError(
312+
"FlashInfer without a KV cache manager is not tested for cross attention."
313+
)
314+
self.kv_cache_params = KVCacheParams(use_cache=False)
315+
n = self.num_seqs
316+
self._cached_token_lens[:n].zero_()
317+
for plan_params in list(self._plan_params_to_wrappers.keys()):
318+
if plan_params.attention_mask_data is None:
319+
self._plan_params_to_wrappers[
320+
plan_params].is_planned = False
321+
self._plan_with_params(plan_params)
322+
else:
323+
del self._plan_params_to_wrappers[plan_params]
324+
return
325+
217326
# indices of used cache blocks for each sequence
218327
assert self.request_ids is not None
219328
block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(
@@ -371,6 +480,33 @@ def _plan_with_params(self, plan_params: PlanParams) -> PlanParams:
371480
"Make sure you run a few warmup runs before capturing the graph!"
372481
)
373482

483+
if self.kv_cache_manager is None:
484+
if self.is_cuda_graph:
485+
raise NotImplementedError(
486+
"FlashInfer without a KV cache manager does not support "
487+
"CUDA graph capture; use the TRTLLM attention backend.")
488+
if plan_params in self._plan_params_to_wrappers:
489+
ragged_prefill_wrapper = self._plan_params_to_wrappers[
490+
plan_params].ragged_prefill_wrapper
491+
else:
492+
ragged_prefill_wrapper = (
493+
flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
494+
self.workspace_buffer,
495+
self.kv_layout,
496+
backend="cudnn",
497+
))
498+
torch.cuda.current_stream().synchronize()
499+
if self.num_contexts <= 0:
500+
raise ValueError(
501+
"FlashInfer ragged prefill without KV cache requires "
502+
"num_contexts >= 1.")
503+
self._plan_ragged_cudnn_no_kv(plan_params, ragged_prefill_wrapper)
504+
self._plan_params_to_wrappers[plan_params] = FlashInferWrappers(
505+
is_planned=True,
506+
ragged_prefill_wrapper=ragged_prefill_wrapper,
507+
)
508+
return plan_params
509+
374510
if plan_params in self._plan_params_to_wrappers:
375511
prefill_wrapper = self._plan_params_to_wrappers[
376512
plan_params].prefill_wrapper
@@ -437,6 +573,7 @@ def decode_plan():
437573
dtype=torch.int32,
438574
dim=0,
439575
)
576+
assert decode_wrapper is not None
440577
decode_wrapper.plan(
441578
paged_kv_indptr,
442579
self.paged_kv_indices[self.num_context_blocks:],
@@ -511,6 +648,36 @@ def forward_impl(
511648
# Query
512649
q = q.view(-1, self.num_heads, self.head_dim)
513650

651+
if metadata.kv_cache_manager is None:
652+
assert k is not None and v is not None, (
653+
"FlashInfer without a KV cache manager requires key/value tensors."
654+
)
655+
if self.has_fp8_kv_cache:
656+
raise NotImplementedError(
657+
"FP8 KV cache is not supported for FlashInfer without a "
658+
"KV cache manager.")
659+
k = k.view(-1, self.num_kv_heads, self.head_dim)
660+
v = v.view(-1, self.num_kv_heads, self.head_dim)
661+
plan_params = metadata.plan(
662+
self.num_heads,
663+
self.num_kv_heads,
664+
self.head_dim,
665+
q_dtype=q.dtype,
666+
kv_dtype=k.dtype,
667+
q_scaling=self.q_scaling,
668+
attention_window_size=attention_window_size,
669+
attention_mask_type=attention_mask_type,
670+
attention_mask_data=attention_mask_data,
671+
)
672+
wrapper = metadata.get_ragged_prefill_wrapper(plan_params)
673+
wrapper.run(
674+
q,
675+
k,
676+
v,
677+
out=output.view(-1, self.num_heads, self.head_dim),
678+
)
679+
return
680+
514681
# Key and Value
515682
kv_cache = metadata.kv_cache_manager.get_buffers(
516683
self.layer_idx, kv_layout=metadata.kv_layout)

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ class AttentionMetadata:
6565
# The max number of sequences in a single batch.
6666
max_num_sequences: Optional[int] = None
6767
# The KV cache manager.
68-
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2]
68+
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2, None] = None
6969
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
70-
draft_kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2] = None
70+
draft_kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2, None] = None
7171
mapping: Optional[Mapping] = None
7272

7373
enable_flash_mla: bool = False

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ def _cache_multimodal_embeddings(
113113
def get_multimodal_embeddings(
114114
encoder_forward_fn: Callable[
115115
[List[MultimodalParams]],
116-
Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]],
116+
Union[torch.Tensor, Tuple[torch.Tensor, Any]],
117117
],
118118
multimodal_params: List[MultimodalParams],
119119
encoder_kwargs: Optional[Dict[str, Any]] = None,
120-
) -> List[torch.Tensor]:
120+
) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], Any]]:
121121
"""
122122
High-level utility to get multimodal embeddings from encoder or cached embeddings.
123123
@@ -130,11 +130,15 @@ def get_multimodal_embeddings(
130130
Args:
131131
encoder_forward_fn: Callable that performs encoder forward pass.
132132
Should accept List[MultimodalParams] and return List[torch.Tensor] or
133-
Tuple[List[torch.Tensor], Dict[str, Any]] for models with auxiliary outputs.
133+
Tuple[List[torch.Tensor], aux_data] for models with auxiliary outputs.
134+
When returning a tuple, the first element must be a List[torch.Tensor]
135+
(one tensor per multimodal param), and aux_data is passed through to
136+
the caller unchanged.
134137
multimodal_params: All multimodal parameters in the batch.
135138
encoder_kwargs: Optional kwargs to pass to encoder_forward_fn.
136139
Returns:
137-
List of multimodal embeddings for all multimodal params in the batch.
140+
List of multimodal embeddings for all multimodal params in the batch, or a
141+
(List[torch.Tensor], aux_data) tuple if encoder_forward_fn returned auxiliary data.
138142
"""
139143
if not multimodal_params:
140144
return []
@@ -143,11 +147,26 @@ def get_multimodal_embeddings(
143147
uncached_multimodal_params = _get_uncached_multimodal_params(
144148
multimodal_params)
145149

150+
aux_data = None
151+
146152
# Step 2: Run encoder forward only on uncached parameters
147153
if uncached_multimodal_params:
148154
kwargs = encoder_kwargs or {}
149-
encoder_embeddings = encoder_forward_fn(uncached_multimodal_params,
150-
**kwargs)
155+
encoder_output = encoder_forward_fn(uncached_multimodal_params,
156+
**kwargs)
157+
158+
# Handle encoder returning (embeddings, aux_data) tuple.
159+
# In this case the first element is a List[torch.Tensor] with one tensor per
160+
# multimodal param (not yet concatenated), which we concatenate before caching.
161+
if isinstance(encoder_output, tuple):
162+
encoder_embeddings, aux_data = encoder_output
163+
# Concatenate per-param tensors into a single tensor for the caching path
164+
if isinstance(encoder_embeddings,
165+
list) and encoder_embeddings and isinstance(
166+
encoder_embeddings[0], torch.Tensor):
167+
encoder_embeddings = [torch.cat(encoder_embeddings, dim=0)]
168+
else:
169+
encoder_embeddings = encoder_output
151170

152171
# TODO: support multiple multimodal modalities per request
153172
if len(encoder_embeddings) > 1:
@@ -168,6 +187,8 @@ def get_multimodal_embeddings(
168187
logger.warning(
169188
"Multimodal runtime data missing or incomplete, will not cache embeddings."
170189
)
190+
if aux_data is not None:
191+
return encoder_embeddings, aux_data
171192
return encoder_embeddings
172193

173194
# Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
@@ -190,6 +211,8 @@ def get_multimodal_embeddings(
190211
param.multimodal_data["multimodal_embedding"] for param in valid_params
191212
],
192213
dim=0)
214+
if aux_data is not None:
215+
return [all_embeddings], aux_data
193216
return [all_embeddings]
194217

195218

tensorrt_llm/_torch/models/modeling_nemotron_nano.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,11 +1649,13 @@ def __init__(self, model_config: ModelConfig):
16491649
super().__init__(config)
16501650

16511651
self.model_config = model_config
1652+
llm_model_config = copy.deepcopy(model_config)
1653+
vision_model_config = copy.deepcopy(model_config)
16521654
if hasattr(self, "llm"):
16531655
return
16541656

16551657
if not _is_disagg():
1656-
self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
1658+
self.vision_encoder = NanoV2VLVisionEncoder(vision_model_config).eval()
16571659

16581660
self.sound_encoder: ProjectedParakeet | None = None
16591661
sound_config = getattr(config, "sound_config", None)
@@ -1664,7 +1666,6 @@ def __init__(self, model_config: ModelConfig):
16641666
dtype=getattr(config, "torch_dtype", torch.bfloat16),
16651667
).eval()
16661668

1667-
llm_model_config = copy.deepcopy(model_config)
16681669
llm_model_config.pretrained_config = llm_model_config.pretrained_config.llm_config
16691670
self._update_config_for_quantization(llm_model_config)
16701671

0 commit comments

Comments
 (0)