Skip to content

Commit 30d7210

Browse files
committed
Add dsa_graph_safe
1 parent: 85d7b7e · commit: 30d7210

4 files changed

Lines changed: 87 additions & 42 deletions

File tree

csrc/flashinfer_topk_binding.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,17 +19,18 @@ using tvm::ffi::Optional;
1919

2020
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
2121
Optional<TensorView> maybe_row_states_buffer, int64_t top_k, bool sorted_output,
22-
bool deterministic, int64_t tie_break);
22+
bool deterministic, int64_t tie_break, bool dsa_graph_safe);
2323

2424
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
2525
TensorView src_page_table,
2626
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
2727
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
28-
bool deterministic, int64_t tie_break);
28+
bool deterministic, int64_t tie_break, bool dsa_graph_safe);
2929

3030
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
3131
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
32-
int64_t top_k, bool deterministic, int64_t tie_break);
32+
int64_t top_k, bool deterministic, int64_t tie_break,
33+
bool dsa_graph_safe);
3334

3435
bool can_implement_filtered_topk();
3536

csrc/topk.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ inline sampling::TopKTieBreak ParseTopKTieBreak(int64_t tie_break) {
4040

4141
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
4242
Optional<TensorView> maybe_row_states_buffer, int64_t top_k, bool sorted_output,
43-
bool deterministic, int64_t tie_break) {
43+
bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
4444
CHECK_INPUT(input);
4545
CHECK_INPUT(output_indices);
4646
CHECK_INPUT(output_values);
@@ -72,7 +72,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v
7272
status = sampling::TopKDispatch<c_type, int32_t>(
7373
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
7474
static_cast<c_type*>(output_values.data_ptr()), batch_size, static_cast<uint32_t>(top_k), d,
75-
row_states_ptr, sorted_output, deterministic, tie_break_mode, stream);
75+
row_states_ptr, sorted_output, deterministic, tie_break_mode, stream, dsa_graph_safe);
7676
return true;
7777
});
7878

@@ -84,7 +84,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
8484
TensorView src_page_table,
8585
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
8686
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
87-
bool deterministic, int64_t tie_break) {
87+
bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
8888
CHECK_INPUT(input);
8989
CHECK_INPUT(output_page_table);
9090
CHECK_INPUT(src_page_table);
@@ -125,7 +125,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
125125
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
126126
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
127127
static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
128-
row_states_ptr, deterministic, tie_break_mode, stream);
128+
row_states_ptr, deterministic, tie_break_mode, stream, dsa_graph_safe);
129129
return true;
130130
});
131131

@@ -135,7 +135,8 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
135135

136136
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
137137
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
138-
int64_t top_k, bool deterministic, int64_t tie_break) {
138+
int64_t top_k, bool deterministic, int64_t tie_break,
139+
bool dsa_graph_safe) {
139140
CHECK_INPUT(input);
140141
CHECK_INPUT(output_indices);
141142
CHECK_INPUT(offsets);
@@ -170,7 +171,7 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
170171
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
171172
static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
172173
num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, deterministic,
173-
tie_break_mode, stream);
174+
tie_break_mode, stream, dsa_graph_safe);
174175
return true;
175176
});
176177

flashinfer/topk.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def radix_topk(
7272
tie_break: int,
7373
row_states_buffer: Optional[torch.Tensor],
7474
output_values: torch.Tensor,
75+
dsa_graph_safe: bool = False,
7576
) -> torch.Tensor:
7677
device = input.device
7778
# Supports float32, float16, bfloat16
@@ -91,6 +92,7 @@ def radix_topk(
9192
sorted_output,
9293
deterministic,
9394
tie_break,
95+
dsa_graph_safe,
9496
)
9597
return output_indices
9698

@@ -103,6 +105,7 @@ def _fake_radix_topk(
103105
tie_break: int,
104106
row_states_buffer: Optional[torch.Tensor],
105107
output_values: torch.Tensor,
108+
dsa_graph_safe: bool = False,
106109
) -> torch.Tensor:
107110
batch_size = input.size(0)
108111
return torch.empty(batch_size, top_k, dtype=torch.int32, device=input.device)
@@ -250,6 +253,7 @@ def radix_topk_page_table_transform(
250253
top_k: int,
251254
deterministic: bool,
252255
tie_break: int,
256+
dsa_graph_safe: bool = False,
253257
) -> None:
254258
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
255259
f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
@@ -264,6 +268,7 @@ def radix_topk_page_table_transform(
264268
top_k,
265269
deterministic,
266270
tie_break,
271+
dsa_graph_safe,
267272
)
268273

269274
@register_fake_op("flashinfer::radix_topk_page_table_transform")
@@ -277,6 +282,7 @@ def _fake_radix_topk_page_table_transform(
277282
top_k: int,
278283
deterministic: bool,
279284
tie_break: int,
285+
dsa_graph_safe: bool = False,
280286
) -> None:
281287
pass
282288

@@ -293,6 +299,7 @@ def radix_topk_ragged_transform(
293299
top_k: int,
294300
deterministic: bool,
295301
tie_break: int,
302+
dsa_graph_safe: bool = False,
296303
) -> None:
297304
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
298305
f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
@@ -306,6 +313,7 @@ def radix_topk_ragged_transform(
306313
top_k,
307314
deterministic,
308315
tie_break,
316+
dsa_graph_safe,
309317
)
310318

311319
@register_fake_op("flashinfer::radix_topk_ragged_transform")
@@ -318,6 +326,7 @@ def _fake_radix_topk_ragged_transform(
318326
top_k: int,
319327
deterministic: bool,
320328
tie_break: int,
329+
dsa_graph_safe: bool = False,
321330
) -> None:
322331
pass
323332

@@ -490,6 +499,7 @@ def top_k(
490499
sorted: bool = False,
491500
deterministic: bool = False,
492501
tie_break: int = TopKTieBreak.NONE,
502+
dsa_graph_safe: bool = False,
493503
) -> Tuple[torch.Tensor, torch.Tensor]:
494504
r"""Radix-based Top-K selection.
495505
@@ -525,6 +535,9 @@ def top_k(
525535
- ``2``: prefer larger indices
526536
527537
Default is ``0``.
538+
dsa_graph_safe : bool, optional
539+
If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
540+
Default is False.
528541
529542
Returns
530543
-------
@@ -578,7 +591,7 @@ def top_k(
578591
if tie_break != TopKTieBreak.NONE:
579592
deterministic = True
580593

581-
if can_use_clusters_topk(input.device, deterministic):
594+
if not dsa_graph_safe and can_use_clusters_topk(input.device, deterministic):
582595
indices, output_values = topk_clusters_exact(
583596
input, k, output_values=True, out_dtype=torch.int64
584597
)
@@ -613,6 +626,7 @@ def top_k(
613626
tie_break,
614627
row_states_buffer,
615628
output_values,
629+
dsa_graph_safe,
616630
)
617631

618632
# Convert to int64 for compatibility
@@ -642,6 +656,7 @@ def top_k_page_table_transform(
642656
row_to_batch: Optional[torch.Tensor] = None,
643657
deterministic: bool = False,
644658
tie_break: int = TopKTieBreak.NONE,
659+
dsa_graph_safe: bool = False,
645660
) -> torch.Tensor:
646661
r"""Fused Top-K selection + Page Table Transform for sparse attention.
647662
@@ -682,6 +697,9 @@ def top_k_page_table_transform(
682697
- ``2``: prefer larger indices
683698
684699
Default is ``0``.
700+
dsa_graph_safe : bool, optional
701+
If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
702+
Default is False.
685703
686704
687705
Returns
@@ -718,7 +736,11 @@ def top_k_page_table_transform(
718736
if tie_break != TopKTieBreak.NONE:
719737
deterministic = True
720738

721-
if can_use_clusters_topk(input.device, deterministic) and row_to_batch is None:
739+
if (
740+
not dsa_graph_safe
741+
and can_use_clusters_topk(input.device, deterministic)
742+
and row_to_batch is None
743+
):
722744
return topk_clusters_page_table_transform(input, lengths, src_page_table, k)
723745

724746
# Allocate row_states buffer for multi-CTA path
@@ -742,6 +764,7 @@ def top_k_page_table_transform(
742764
k,
743765
deterministic,
744766
tie_break,
767+
dsa_graph_safe,
745768
)
746769

747770
return output_page_table
@@ -755,6 +778,7 @@ def top_k_ragged_transform(
755778
k: int,
756779
deterministic: bool = False,
757780
tie_break: int = TopKTieBreak.NONE,
781+
dsa_graph_safe: bool = False,
758782
) -> torch.Tensor:
759783
r"""Fused Top-K selection + Ragged Index Transform for sparse attention.
760784
@@ -788,6 +812,9 @@ def top_k_ragged_transform(
788812
- ``2``: prefer larger indices
789813
790814
Default is ``0``.
815+
dsa_graph_safe : bool, optional
816+
If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1).
817+
Default is False.
791818
792819
793820
Returns
@@ -825,7 +852,7 @@ def top_k_ragged_transform(
825852
if tie_break != TopKTieBreak.NONE:
826853
deterministic = True
827854

828-
if can_use_clusters_topk(input.device, deterministic):
855+
if not dsa_graph_safe and can_use_clusters_topk(input.device, deterministic):
829856
return topk_clusters_ragged_transform(input, lengths, offsets, k)
830857

831858
# Allocate row_states buffer for multi-CTA path
@@ -848,6 +875,7 @@ def top_k_ragged_transform(
848875
k,
849876
deterministic,
850877
tie_break,
878+
dsa_graph_safe,
851879
)
852880

853881
return output_indices

0 commit comments

Comments
 (0)