22import os
33import weakref
44from dataclasses import dataclass , field
5- from typing import Dict , Literal , Optional
5+ from typing import Any , Dict , Literal , Optional
66
77import flashinfer
88import torch
1212from tensorrt_llm .functional import AttentionMaskType
1313from tensorrt_llm .models .modeling_utils import QuantConfig
1414
15+ from ..metadata import KVCacheParams
1516from ..utils import get_global_attrs , get_model_extra_attrs
1617from .interface import (AttentionBackend , AttentionMask , AttentionMetadata ,
1718 CustomAttentionMask , PredefinedAttentionMask )
2526 arch_list = f"{ capability [0 ]} .{ capability [1 ]} "
2627 os .environ ["TORCH_CUDA_ARCH_LIST" ] = arch_list
2728
29+ from tensorrt_llm ._utils import prefer_pinned
30+
2831
2932@dataclass (kw_only = True , frozen = True )
3033class PlanParams :
@@ -46,10 +49,13 @@ class PlanParams:
4649
@dataclass(kw_only=True)
class FlashInferWrappers:
    """Bundle of FlashInfer wrapper objects created for one ``PlanParams``.

    Only the wrappers required by the chosen code path are instantiated;
    the remaining fields stay ``None``.
    """

    # True once plan() has been executed for the current batch.
    is_planned: bool
    # Paged-KV decode wrapper (presumably the generation phase — see decode_plan).
    decode_wrapper: Optional[
        flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None
    # Paged-KV prefill wrapper (context phase with a KV cache manager).
    prefill_wrapper: Optional[
        flashinfer.BatchPrefillWithPagedKVCacheWrapper] = None
    # Ragged prefill wrapper, used when there is no KV cache manager
    # (e.g. ViT-style encoder-only runs).
    ragged_prefill_wrapper: Optional[
        flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = None
5359
5460
5561@dataclass (kw_only = True )
@@ -94,6 +100,15 @@ def get_decode_wrapper(
94100 result = self ._plan_params_to_wrappers [plan_params ].decode_wrapper
95101 return result
96102
103+ def get_ragged_prefill_wrapper (
104+ self , plan_params : PlanParams
105+ ) -> flashinfer .BatchPrefillWithRaggedKVCacheWrapper :
106+ assert plan_params in self ._plan_params_to_wrappers , "Plan params not found, make sure to call plan()"
107+ result = self ._plan_params_to_wrappers [
108+ plan_params ].ragged_prefill_wrapper
109+ assert result is not None , "Ragged prefill wrapper was not created in plan()"
110+ return result
111+
97112 @property
98113 def paged_kv_indices (self ) -> torch .Tensor :
99114 return self ._paged_kv_indices [:self .num_generation_blocks +
@@ -201,8 +216,81 @@ def page_size(self) -> int:
201216 """
202217 Number of tokens per cache page
203218 """
219+ assert self .kv_cache_manager is not None , (
220+ "page_size is undefined without a KV cache manager; use the "
221+ "ragged prefill path instead." )
204222 return self .kv_cache_manager .tokens_per_block
205223
224+ def _plan_ragged_cudnn_no_kv (
225+ self ,
226+ plan_params : PlanParams ,
227+ ragged_prefill_wrapper : Any ,
228+ ) -> None :
229+ is_causal = plan_params .attention_mask_type == AttentionMaskType .causal
230+ if plan_params .attention_mask_data is not None :
231+ window_left = - 1
232+ else :
233+ window_left = plan_params .window_left
234+
235+ # Lengths are already on GPU via AttentionMetadata (seq_lens setter -> _seq_lens_cuda).
236+ assert self .seq_lens_cuda is not None
237+ assert self .seq_lens is not None
238+
239+ # NOTE: When kv_cache_manager is None (e.g. ViT), ragged prefill runs only for the context phase.
240+ # Restrict seq_lens to the first num_contexts entries accordingly.
241+ q_seqlens = self .seq_lens [:self .num_contexts ]
242+ kv_seqlens = q_seqlens
243+
244+ max_query_tokens_per_sequence = int (
245+ self .seq_lens [:self .num_contexts ].max ().item ())
246+ max_key_value_tokens_per_sequence = max_query_tokens_per_sequence
247+
248+ # cuDNN ragged prefill uses *element* offsets in qo/kv indptr, not token indptr.
249+ num_context_sequences = int (q_seqlens .shape [0 ])
250+ query_output_element_indptr = torch .zeros (
251+ num_context_sequences + 1 ,
252+ dtype = torch .int32 ,
253+ pin_memory = prefer_pinned (),
254+ )
255+ key_value_element_indptr = torch .zeros (
256+ num_context_sequences + 1 ,
257+ dtype = torch .int32 ,
258+ pin_memory = prefer_pinned (),
259+ )
260+ if num_context_sequences > 0 :
261+ num_query_output_heads = plan_params .num_heads
262+ num_key_value_heads = plan_params .num_kv_heads
263+ attention_head_dim = plan_params .head_dim
264+ query_output_element_indptr [1 :].copy_ (
265+ torch .cumsum (q_seqlens , dim = 0 ).mul_ (num_query_output_heads *
266+ attention_head_dim ))
267+ key_value_element_indptr [1 :].copy_ (
268+ torch .cumsum (kv_seqlens , dim = 0 ).mul_ (num_key_value_heads *
269+ attention_head_dim ))
270+
271+ q_seqlens_cuda = self .seq_lens_cuda [:self .num_contexts ]
272+ kv_seqlens_cuda = q_seqlens_cuda [:self .num_contexts ]
273+
274+ ragged_prefill_wrapper .plan (
275+ qo_indptr = query_output_element_indptr ,
276+ kv_indptr = key_value_element_indptr ,
277+ num_qo_heads = plan_params .num_heads ,
278+ num_kv_heads = plan_params .num_kv_heads ,
279+ head_dim_qk = plan_params .head_dim ,
280+ custom_mask = plan_params .attention_mask_data ,
281+ causal = is_causal ,
282+ sm_scale = plan_params .sm_scale ,
283+ window_left = window_left ,
284+ q_data_type = plan_params .q_dtype ,
285+ kv_data_type = plan_params .kv_dtype ,
286+ seq_lens = kv_seqlens_cuda ,
287+ seq_lens_q = q_seqlens_cuda ,
288+ max_token_per_sequence = max_query_tokens_per_sequence ,
289+ max_sequence_kv = max_key_value_tokens_per_sequence ,
290+ v_indptr = key_value_element_indptr ,
291+ o_indptr = query_output_element_indptr ,
292+ )
293+
206294 def prepare (self ) -> None :
207295 super ().prepare ()
208296 extra_attrs = get_model_extra_attrs ()
@@ -214,6 +302,27 @@ def prepare(self) -> None:
214302 dtype = torch .int32 ,
215303 out = self ._qo_indptr [1 :self .seq_lens_cuda .size (0 ) + 1 ])
216304
305+ if self .kv_cache_manager is None :
306+ assert self .request_ids is not None
307+ assert self .num_generations == 0 , (
308+ "FlashInfer without a KV cache manager only supports context-only "
309+ "batches (num_generations == 0) in TRT-LLM." )
310+ if self .is_cross :
311+ raise NotImplementedError (
312+ "FlashInfer without a KV cache manager is not tested for cross attention."
313+ )
314+ self .kv_cache_params = KVCacheParams (use_cache = False )
315+ n = self .num_seqs
316+ self ._cached_token_lens [:n ].zero_ ()
317+ for plan_params in list (self ._plan_params_to_wrappers .keys ()):
318+ if plan_params .attention_mask_data is None :
319+ self ._plan_params_to_wrappers [
320+ plan_params ].is_planned = False
321+ self ._plan_with_params (plan_params )
322+ else :
323+ del self ._plan_params_to_wrappers [plan_params ]
324+ return
325+
217326 # indices of used cache blocks for each sequence
218327 assert self .request_ids is not None
219328 block_ids_per_seq = self .kv_cache_manager .get_batch_cache_indices (
@@ -371,6 +480,33 @@ def _plan_with_params(self, plan_params: PlanParams) -> PlanParams:
371480 "Make sure you run a few warmup runs before capturing the graph!"
372481 )
373482
483+ if self .kv_cache_manager is None :
484+ if self .is_cuda_graph :
485+ raise NotImplementedError (
486+ "FlashInfer without a KV cache manager does not support "
487+ "CUDA graph capture; use the TRTLLM attention backend." )
488+ if plan_params in self ._plan_params_to_wrappers :
489+ ragged_prefill_wrapper = self ._plan_params_to_wrappers [
490+ plan_params ].ragged_prefill_wrapper
491+ else :
492+ ragged_prefill_wrapper = (
493+ flashinfer .BatchPrefillWithRaggedKVCacheWrapper (
494+ self .workspace_buffer ,
495+ self .kv_layout ,
496+ backend = "cudnn" ,
497+ ))
498+ torch .cuda .current_stream ().synchronize ()
499+ if self .num_contexts <= 0 :
500+ raise ValueError (
501+ "FlashInfer ragged prefill without KV cache requires "
502+ "num_contexts >= 1." )
503+ self ._plan_ragged_cudnn_no_kv (plan_params , ragged_prefill_wrapper )
504+ self ._plan_params_to_wrappers [plan_params ] = FlashInferWrappers (
505+ is_planned = True ,
506+ ragged_prefill_wrapper = ragged_prefill_wrapper ,
507+ )
508+ return plan_params
509+
374510 if plan_params in self ._plan_params_to_wrappers :
375511 prefill_wrapper = self ._plan_params_to_wrappers [
376512 plan_params ].prefill_wrapper
@@ -437,6 +573,7 @@ def decode_plan():
437573 dtype = torch .int32 ,
438574 dim = 0 ,
439575 )
576+ assert decode_wrapper is not None
440577 decode_wrapper .plan (
441578 paged_kv_indptr ,
442579 self .paged_kv_indices [self .num_context_blocks :],
@@ -511,6 +648,36 @@ def forward_impl(
511648 # Query
512649 q = q .view (- 1 , self .num_heads , self .head_dim )
513650
651+ if metadata .kv_cache_manager is None :
652+ assert k is not None and v is not None , (
653+ "FlashInfer without a KV cache manager requires key/value tensors."
654+ )
655+ if self .has_fp8_kv_cache :
656+ raise NotImplementedError (
657+ "FP8 KV cache is not supported for FlashInfer without a "
658+ "KV cache manager." )
659+ k = k .view (- 1 , self .num_kv_heads , self .head_dim )
660+ v = v .view (- 1 , self .num_kv_heads , self .head_dim )
661+ plan_params = metadata .plan (
662+ self .num_heads ,
663+ self .num_kv_heads ,
664+ self .head_dim ,
665+ q_dtype = q .dtype ,
666+ kv_dtype = k .dtype ,
667+ q_scaling = self .q_scaling ,
668+ attention_window_size = attention_window_size ,
669+ attention_mask_type = attention_mask_type ,
670+ attention_mask_data = attention_mask_data ,
671+ )
672+ wrapper = metadata .get_ragged_prefill_wrapper (plan_params )
673+ wrapper .run (
674+ q ,
675+ k ,
676+ v ,
677+ out = output .view (- 1 , self .num_heads , self .head_dim ),
678+ )
679+ return
680+
514681 # Key and Value
515682 kv_cache = metadata .kv_cache_manager .get_buffers (
516683 self .layer_idx , kv_layout = metadata .kv_layout )
0 commit comments