Skip to content

Commit 20061c2

Browse files
committed
Add row_starts
1 parent f5453c5 commit 20061c2

5 files changed

Lines changed: 244 additions & 83 deletions

File tree

csrc/flashinfer_topk_binding.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,13 +24,14 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v
2424
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
2525
TensorView src_page_table,
2626
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
27+
Optional<TensorView> maybe_row_starts,
2728
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
2829
bool deterministic, int64_t tie_break, bool dsa_graph_safe);
2930

3031
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
31-
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
32-
int64_t top_k, bool deterministic, int64_t tie_break,
33-
bool dsa_graph_safe);
32+
TensorView lengths, Optional<TensorView> maybe_row_starts,
33+
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
34+
bool deterministic, int64_t tie_break, bool dsa_graph_safe);
3435

3536
bool can_implement_filtered_topk();
3637

csrc/topk.cu

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v
8383
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
8484
TensorView src_page_table,
8585
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
86+
Optional<TensorView> maybe_row_starts,
8687
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
8788
bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
8889
CHECK_INPUT(input);
@@ -93,6 +94,10 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
9394
CHECK_DIM(2, output_page_table); // output_page_table: (num_rows, top_k)
9495
CHECK_DIM(2, src_page_table); // src_page_table: (batch_size, max_len)
9596
CHECK_DIM(1, lengths); // lengths: (num_rows,)
97+
if (maybe_row_starts.has_value()) {
98+
CHECK_INPUT(maybe_row_starts.value());
99+
CHECK_DIM(1, maybe_row_starts.value());
100+
}
96101

97102
unsigned int num_rows = input.size(0);
98103
unsigned int max_len = input.size(1);
@@ -118,14 +123,21 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
118123
if (maybe_row_to_batch.has_value()) {
119124
row_to_batch_ptr = static_cast<int32_t*>(maybe_row_to_batch.value().data_ptr());
120125
}
126+
int32_t* row_starts_ptr = nullptr;
127+
if (maybe_row_starts.has_value()) {
128+
TVM_FFI_ICHECK(static_cast<unsigned int>(maybe_row_starts.value().size(0)) == num_rows)
129+
<< "row_starts must have shape (num_rows,)";
130+
row_starts_ptr = static_cast<int32_t*>(maybe_row_starts.value().data_ptr());
131+
}
121132

122133
// Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
123134
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
124135
status = sampling::TopKPageTableTransformDispatch<c_type, int32_t>(
125136
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
126-
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
127-
static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
128-
row_states_ptr, deterministic, tie_break_mode, stream, dsa_graph_safe);
137+
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride,
138+
static_cast<int32_t*>(lengths.data_ptr()), row_starts_ptr, row_to_batch_ptr, num_rows,
139+
static_cast<uint32_t>(top_k), max_len, row_states_ptr, deterministic, tie_break_mode,
140+
stream, dsa_graph_safe);
129141
return true;
130142
});
131143

@@ -134,9 +146,9 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
134146
}
135147

136148
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
137-
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
138-
int64_t top_k, bool deterministic, int64_t tie_break,
139-
bool dsa_graph_safe) {
149+
TensorView lengths, Optional<TensorView> maybe_row_starts,
150+
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
151+
bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
140152
CHECK_INPUT(input);
141153
CHECK_INPUT(output_indices);
142154
CHECK_INPUT(offsets);
@@ -145,6 +157,10 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
145157
CHECK_DIM(2, output_indices); // output_indices: (num_rows, top_k)
146158
CHECK_DIM(1, offsets); // offsets: (num_rows,)
147159
CHECK_DIM(1, lengths); // lengths: (num_rows,)
160+
if (maybe_row_starts.has_value()) {
161+
CHECK_INPUT(maybe_row_starts.value());
162+
CHECK_DIM(1, maybe_row_starts.value());
163+
}
148164

149165
unsigned int num_rows = input.size(0);
150166
unsigned int max_len = input.size(1);
@@ -164,14 +180,20 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
164180
row_states_ptr =
165181
static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
166182
}
183+
int32_t* row_starts_ptr = nullptr;
184+
if (maybe_row_starts.has_value()) {
185+
TVM_FFI_ICHECK(static_cast<unsigned int>(maybe_row_starts.value().size(0)) == num_rows)
186+
<< "row_starts must have shape (num_rows,)";
187+
row_starts_ptr = static_cast<int32_t*>(maybe_row_starts.value().data_ptr());
188+
}
167189

168190
// Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
169191
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
170192
status = sampling::TopKRaggedTransformDispatch<c_type, int32_t>(
171193
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
172194
static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
173-
num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, deterministic,
174-
tie_break_mode, stream, dsa_graph_safe);
195+
row_starts_ptr, num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr,
196+
deterministic, tie_break_mode, stream, dsa_graph_safe);
175197
return true;
176198
});
177199

flashinfer/topk.py

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,7 @@ def radix_topk_page_table_transform(
249249
src_page_table: torch.Tensor,
250250
row_to_batch: Optional[torch.Tensor],
251251
lengths: torch.Tensor,
252+
row_starts: Optional[torch.Tensor],
252253
row_states_buffer: Optional[torch.Tensor],
253254
top_k: int,
254255
deterministic: bool,
@@ -264,6 +265,7 @@ def radix_topk_page_table_transform(
264265
src_page_table,
265266
row_to_batch,
266267
lengths,
268+
row_starts,
267269
row_states_buffer,
268270
top_k,
269271
deterministic,
@@ -278,6 +280,7 @@ def _fake_radix_topk_page_table_transform(
278280
src_page_table: torch.Tensor,
279281
row_to_batch: Optional[torch.Tensor],
280282
lengths: torch.Tensor,
283+
row_starts: Optional[torch.Tensor],
281284
row_states_buffer: Optional[torch.Tensor],
282285
top_k: int,
283286
deterministic: bool,
@@ -295,6 +298,7 @@ def radix_topk_ragged_transform(
295298
output_indices: torch.Tensor,
296299
offsets: torch.Tensor,
297300
lengths: torch.Tensor,
301+
row_starts: Optional[torch.Tensor],
298302
row_states_buffer: Optional[torch.Tensor],
299303
top_k: int,
300304
deterministic: bool,
@@ -309,6 +313,7 @@ def radix_topk_ragged_transform(
309313
output_indices,
310314
offsets,
311315
lengths,
316+
row_starts,
312317
row_states_buffer,
313318
top_k,
314319
deterministic,
@@ -322,6 +327,7 @@ def _fake_radix_topk_ragged_transform(
322327
output_indices: torch.Tensor,
323328
offsets: torch.Tensor,
324329
lengths: torch.Tensor,
330+
row_starts: Optional[torch.Tensor],
325331
row_states_buffer: Optional[torch.Tensor],
326332
top_k: int,
327333
deterministic: bool,
@@ -655,6 +661,7 @@ def top_k_page_table_transform(
655661
src_page_table: torch.Tensor,
656662
lengths: torch.Tensor,
657663
k: int,
664+
row_starts: Optional[torch.Tensor] = None,
658665
row_to_batch: Optional[torch.Tensor] = None,
659666
deterministic: bool = False,
660667
tie_break: int = TopKTieBreak.NONE,
@@ -683,6 +690,10 @@ def top_k_page_table_transform(
683690
Actual KV lengths per row of shape ``(num_rows,)`` with dtype ``int32``.
684691
k : int
685692
Number of top elements to select from each row.
693+
row_starts : Optional[torch.Tensor], optional
694+
Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``.
695+
Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``.
696+
Default is None (equivalent to all zeros).
686697
row_to_batch : Optional[torch.Tensor], optional
687698
Mapping from row index to batch index of shape ``(num_rows,)`` with
688699
dtype ``int32``. If None, uses 1:1 mapping (row_idx == batch_idx).
@@ -714,7 +725,9 @@ def top_k_page_table_transform(
714725
Note
715726
----
716727
- This is specifically designed for sparse attention's second stage.
717-
- If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]]
728+
- If lengths[i] <= k, the output simply contains
729+
``src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]]`` (or start 0 when
730+
``row_starts`` is None)
718731
with remaining positions set to -1.
719732
720733
Examples
@@ -741,6 +754,7 @@ def top_k_page_table_transform(
741754
if (
742755
can_use_clusters_topk(input.device, deterministic, dsa_graph_safe)
743756
and row_to_batch is None
757+
and row_starts is None
744758
):
745759
return topk_clusters_page_table_transform(input, lengths, src_page_table, k)
746760

@@ -761,6 +775,7 @@ def top_k_page_table_transform(
761775
src_page_table,
762776
row_to_batch,
763777
lengths,
778+
row_starts,
764779
row_states_buffer,
765780
k,
766781
deterministic,
@@ -777,6 +792,7 @@ def top_k_ragged_transform(
777792
offsets: torch.Tensor,
778793
lengths: torch.Tensor,
779794
k: int,
795+
row_starts: Optional[torch.Tensor] = None,
780796
deterministic: bool = False,
781797
tie_break: int = TopKTieBreak.NONE,
782798
dsa_graph_safe: bool = False,
@@ -801,6 +817,11 @@ def top_k_ragged_transform(
801817
Actual KV lengths per row of shape ``(num_rows,)`` with dtype ``int32``.
802818
k : int
803819
Number of top elements to select from each row.
820+
row_starts : Optional[torch.Tensor], optional
821+
Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``.
822+
Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``.
823+
Output indices remain ``local_topk + offsets[i]`` where ``local_topk`` is relative to
824+
``row_starts[i]``. Default is None (equivalent to all zeros).
804825
deterministic : bool, optional
805826
If True, uses deterministic mode.
806827
Default is False (non-deterministic, which is faster).
@@ -853,7 +874,10 @@ def top_k_ragged_transform(
853874
if tie_break != TopKTieBreak.NONE:
854875
deterministic = True
855876

856-
if can_use_clusters_topk(input.device, deterministic, dsa_graph_safe):
877+
if (
878+
can_use_clusters_topk(input.device, deterministic, dsa_graph_safe)
879+
and row_starts is None
880+
):
857881
return topk_clusters_ragged_transform(input, lengths, offsets, k)
858882

859883
# Allocate row_states buffer for multi-CTA path
@@ -872,6 +896,7 @@ def top_k_ragged_transform(
872896
output_indices,
873897
offsets,
874898
lengths,
899+
row_starts,
875900
row_states_buffer,
876901
k,
877902
deterministic,

0 commit comments

Comments
 (0)