@@ -72,6 +72,7 @@ def radix_topk(
7272 tie_break : int ,
7373 row_states_buffer : Optional [torch .Tensor ],
7474 output_values : torch .Tensor ,
75+ dsa_graph_safe : bool = False ,
7576 ) -> torch .Tensor :
7677 device = input .device
7778 # Supports float32, float16, bfloat16
@@ -91,6 +92,7 @@ def radix_topk(
9192 sorted_output ,
9293 deterministic ,
9394 tie_break ,
95+ dsa_graph_safe ,
9496 )
9597 return output_indices
9698
@@ -103,6 +105,7 @@ def _fake_radix_topk(
103105 tie_break : int ,
104106 row_states_buffer : Optional [torch .Tensor ],
105107 output_values : torch .Tensor ,
108+ dsa_graph_safe : bool = False ,
106109 ) -> torch .Tensor :
107110 batch_size = input .size (0 )
108111 return torch .empty (batch_size , top_k , dtype = torch .int32 , device = input .device )
@@ -250,6 +253,7 @@ def radix_topk_page_table_transform(
250253 top_k : int ,
251254 deterministic : bool ,
252255 tie_break : int ,
256+ dsa_graph_safe : bool = False ,
253257 ) -> None :
254258 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], (
255259 f"Unsupported dtype { input .dtype } , expected float32, float16, or bfloat16"
@@ -264,6 +268,7 @@ def radix_topk_page_table_transform(
264268 top_k ,
265269 deterministic ,
266270 tie_break ,
271+ dsa_graph_safe ,
267272 )
268273
269274 @register_fake_op ("flashinfer::radix_topk_page_table_transform" )
@@ -277,6 +282,7 @@ def _fake_radix_topk_page_table_transform(
277282 top_k : int ,
278283 deterministic : bool ,
279284 tie_break : int ,
285+ dsa_graph_safe : bool = False ,
280286 ) -> None :
281287 pass
282288
@@ -293,6 +299,7 @@ def radix_topk_ragged_transform(
293299 top_k : int ,
294300 deterministic : bool ,
295301 tie_break : int ,
302+ dsa_graph_safe : bool = False ,
296303 ) -> None :
297304 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], (
298305 f"Unsupported dtype { input .dtype } , expected float32, float16, or bfloat16"
@@ -306,6 +313,7 @@ def radix_topk_ragged_transform(
306313 top_k ,
307314 deterministic ,
308315 tie_break ,
316+ dsa_graph_safe ,
309317 )
310318
311319 @register_fake_op ("flashinfer::radix_topk_ragged_transform" )
@@ -318,6 +326,7 @@ def _fake_radix_topk_ragged_transform(
318326 top_k : int ,
319327 deterministic : bool ,
320328 tie_break : int ,
329+ dsa_graph_safe : bool = False ,
321330 ) -> None :
322331 pass
323332
@@ -490,6 +499,7 @@ def top_k(
490499 sorted : bool = False ,
491500 deterministic : bool = False ,
492501 tie_break : int = TopKTieBreak .NONE ,
502+ dsa_graph_safe : bool = False ,
493503) -> Tuple [torch .Tensor , torch .Tensor ]:
494504 r"""Radix-based Top-K selection.
495505
@@ -525,6 +535,9 @@ def top_k(
525535 - ``2``: prefer larger indices
526536
527537 Default is ``0``.
538+ dsa_graph_safe : bool, optional
539+ If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
540+ Default is False.
528541
529542 Returns
530543 -------
@@ -578,7 +591,7 @@ def top_k(
578591 if tie_break != TopKTieBreak .NONE :
579592 deterministic = True
580593
581- if can_use_clusters_topk (input .device , deterministic ):
594+ if not dsa_graph_safe and can_use_clusters_topk (input .device , deterministic ):
582595 indices , output_values = topk_clusters_exact (
583596 input , k , output_values = True , out_dtype = torch .int64
584597 )
@@ -613,6 +626,7 @@ def top_k(
613626 tie_break ,
614627 row_states_buffer ,
615628 output_values ,
629+ dsa_graph_safe ,
616630 )
617631
618632 # Convert to int64 for compatibility
@@ -642,6 +656,7 @@ def top_k_page_table_transform(
642656 row_to_batch : Optional [torch .Tensor ] = None ,
643657 deterministic : bool = False ,
644658 tie_break : int = TopKTieBreak .NONE ,
659+ dsa_graph_safe : bool = False ,
645660) -> torch .Tensor :
646661 r"""Fused Top-K selection + Page Table Transform for sparse attention.
647662
@@ -682,6 +697,9 @@ def top_k_page_table_transform(
682697 - ``2``: prefer larger indices
683698
684699 Default is ``0``.
700+ dsa_graph_safe : bool, optional
701+ If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
702+ Default is False.
685703
686704
687705 Returns
@@ -718,7 +736,11 @@ def top_k_page_table_transform(
718736 if tie_break != TopKTieBreak .NONE :
719737 deterministic = True
720738
721- if can_use_clusters_topk (input .device , deterministic ) and row_to_batch is None :
739+ if (
740+ not dsa_graph_safe
741+ and can_use_clusters_topk (input .device , deterministic )
742+ and row_to_batch is None
743+ ):
722744 return topk_clusters_page_table_transform (input , lengths , src_page_table , k )
723745
724746 # Allocate row_states buffer for multi-CTA path
@@ -742,6 +764,7 @@ def top_k_page_table_transform(
742764 k ,
743765 deterministic ,
744766 tie_break ,
767+ dsa_graph_safe ,
745768 )
746769
747770 return output_page_table
@@ -755,6 +778,7 @@ def top_k_ragged_transform(
755778 k : int ,
756779 deterministic : bool = False ,
757780 tie_break : int = TopKTieBreak .NONE ,
781+ dsa_graph_safe : bool = False ,
758782) -> torch .Tensor :
759783 r"""Fused Top-K selection + Ragged Index Transform for sparse attention.
760784
@@ -788,6 +812,9 @@ def top_k_ragged_transform(
788812 - ``2``: prefer larger indices
789813
790814 Default is ``0``.
815+ dsa_graph_safe : bool, optional
816+ If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
817+ Default is False.
791818
792819
793820 Returns
@@ -825,7 +852,7 @@ def top_k_ragged_transform(
825852 if tie_break != TopKTieBreak .NONE :
826853 deterministic = True
827854
828- if can_use_clusters_topk (input .device , deterministic ):
855+ if not dsa_graph_safe and can_use_clusters_topk (input .device , deterministic ):
829856 return topk_clusters_ragged_transform (input , lengths , offsets , k )
830857
831858 # Allocate row_states buffer for multi-CTA path
@@ -848,6 +875,7 @@ def top_k_ragged_transform(
848875 k ,
849876 deterministic ,
850877 tie_break ,
878+ dsa_graph_safe ,
851879 )
852880
853881 return output_indices