Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions csrc/flashinfer_topk_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ using tvm::ffi::Optional;

void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k, bool sorted_output,
bool deterministic, int64_t tie_break);
bool deterministic, int64_t tie_break, bool dsa_graph_safe);

void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
TensorView src_page_table,
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
bool deterministic, int64_t tie_break);
bool deterministic, int64_t tie_break, bool dsa_graph_safe,
Optional<TensorView> maybe_row_starts);

void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
int64_t top_k, bool deterministic, int64_t tie_break);
int64_t top_k, bool deterministic, int64_t tie_break,
bool dsa_graph_safe, Optional<TensorView> maybe_row_starts);

bool can_implement_filtered_topk();

Expand Down
41 changes: 32 additions & 9 deletions csrc/topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ inline sampling::TopKTieBreak ParseTopKTieBreak(int64_t tie_break) {

void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k, bool sorted_output,
bool deterministic, int64_t tie_break) {
bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
CHECK_INPUT(input);
CHECK_INPUT(output_indices);
CHECK_INPUT(output_values);
Expand Down Expand Up @@ -72,7 +72,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v
status = sampling::TopKDispatch<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
static_cast<c_type*>(output_values.data_ptr()), batch_size, static_cast<uint32_t>(top_k), d,
row_states_ptr, sorted_output, deterministic, tie_break_mode, stream);
row_states_ptr, sorted_output, deterministic, tie_break_mode, stream, dsa_graph_safe);
return true;
});

Expand All @@ -84,7 +84,8 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
TensorView src_page_table,
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
bool deterministic, int64_t tie_break) {
bool deterministic, int64_t tie_break, bool dsa_graph_safe,
Optional<TensorView> maybe_row_starts) {
CHECK_INPUT(input);
CHECK_INPUT(output_page_table);
CHECK_INPUT(src_page_table);
Expand All @@ -93,6 +94,10 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
CHECK_DIM(2, output_page_table); // output_page_table: (num_rows, top_k)
CHECK_DIM(2, src_page_table); // src_page_table: (batch_size, max_len)
CHECK_DIM(1, lengths); // lengths: (num_rows,)
if (maybe_row_starts.has_value()) {
CHECK_INPUT(maybe_row_starts.value());
CHECK_DIM(1, maybe_row_starts.value());
}

unsigned int num_rows = input.size(0);
unsigned int max_len = input.size(1);
Expand All @@ -118,14 +123,21 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
if (maybe_row_to_batch.has_value()) {
row_to_batch_ptr = static_cast<int32_t*>(maybe_row_to_batch.value().data_ptr());
}
int32_t* row_starts_ptr = nullptr;
if (maybe_row_starts.has_value()) {
TVM_FFI_ICHECK(static_cast<unsigned int>(maybe_row_starts.value().size(0)) == num_rows)
<< "row_starts must have shape (num_rows,)";
row_starts_ptr = static_cast<int32_t*>(maybe_row_starts.value().data_ptr());
}

// Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
status = sampling::TopKPageTableTransformDispatch<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
row_states_ptr, deterministic, tie_break_mode, stream);
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride,
static_cast<int32_t*>(lengths.data_ptr()), row_starts_ptr, row_to_batch_ptr, num_rows,
static_cast<uint32_t>(top_k), max_len, row_states_ptr, deterministic, tie_break_mode,
stream, dsa_graph_safe);
return true;
});

Expand All @@ -135,7 +147,8 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta

void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
int64_t top_k, bool deterministic, int64_t tie_break) {
int64_t top_k, bool deterministic, int64_t tie_break,
bool dsa_graph_safe, Optional<TensorView> maybe_row_starts) {
CHECK_INPUT(input);
CHECK_INPUT(output_indices);
CHECK_INPUT(offsets);
Expand All @@ -144,6 +157,10 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
CHECK_DIM(2, output_indices); // output_indices: (num_rows, top_k)
CHECK_DIM(1, offsets); // offsets: (num_rows,)
CHECK_DIM(1, lengths); // lengths: (num_rows,)
if (maybe_row_starts.has_value()) {
CHECK_INPUT(maybe_row_starts.value());
CHECK_DIM(1, maybe_row_starts.value());
}

unsigned int num_rows = input.size(0);
unsigned int max_len = input.size(1);
Expand All @@ -163,14 +180,20 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
row_states_ptr =
static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
}
int32_t* row_starts_ptr = nullptr;
if (maybe_row_starts.has_value()) {
TVM_FFI_ICHECK(static_cast<unsigned int>(maybe_row_starts.value().size(0)) == num_rows)
<< "row_starts must have shape (num_rows,)";
row_starts_ptr = static_cast<int32_t*>(maybe_row_starts.value().data_ptr());
}

// Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
status = sampling::TopKRaggedTransformDispatch<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, deterministic,
tie_break_mode, stream);
row_starts_ptr, num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr,
deterministic, tie_break_mode, stream, dsa_graph_safe);
return true;
});

Expand Down
64 changes: 59 additions & 5 deletions flashinfer/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def radix_topk(
tie_break: int,
row_states_buffer: Optional[torch.Tensor],
output_values: torch.Tensor,
dsa_graph_safe: bool = False,
) -> torch.Tensor:
device = input.device
# Supports float32, float16, bfloat16
Expand All @@ -91,6 +92,7 @@ def radix_topk(
sorted_output,
deterministic,
tie_break,
dsa_graph_safe,
)
return output_indices

Expand All @@ -103,6 +105,7 @@ def _fake_radix_topk(
tie_break: int,
row_states_buffer: Optional[torch.Tensor],
output_values: torch.Tensor,
dsa_graph_safe: bool = False,
) -> torch.Tensor:
batch_size = input.size(0)
return torch.empty(batch_size, top_k, dtype=torch.int32, device=input.device)
Expand Down Expand Up @@ -250,6 +253,8 @@ def radix_topk_page_table_transform(
top_k: int,
deterministic: bool,
tie_break: int,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> None:
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
Expand All @@ -264,6 +269,8 @@ def radix_topk_page_table_transform(
top_k,
deterministic,
tie_break,
dsa_graph_safe,
row_starts,
)

@register_fake_op("flashinfer::radix_topk_page_table_transform")
Expand All @@ -277,6 +284,8 @@ def _fake_radix_topk_page_table_transform(
top_k: int,
deterministic: bool,
tie_break: int,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> None:
pass

Expand All @@ -293,6 +302,8 @@ def radix_topk_ragged_transform(
top_k: int,
deterministic: bool,
tie_break: int,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> None:
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
Expand All @@ -306,6 +317,8 @@ def radix_topk_ragged_transform(
top_k,
deterministic,
tie_break,
dsa_graph_safe,
row_starts,
)

@register_fake_op("flashinfer::radix_topk_ragged_transform")
Expand All @@ -318,6 +331,8 @@ def _fake_radix_topk_ragged_transform(
top_k: int,
deterministic: bool,
tie_break: int,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> None:
pass

Expand Down Expand Up @@ -477,7 +492,9 @@ def topk_clusters_ragged_transform(logits, seq_lens, offsets, top_k, pdl=False):
return indices


def can_use_clusters_topk(device, deterministic):
def can_use_clusters_topk(device, deterministic, dsa_graph_safe):
if dsa_graph_safe:
return False
algo = os.environ.get("FLASHINFER_TOPK_ALGO")
cap = get_compute_capability(device)
return (algo is None or algo == "clusters") and not deterministic and cap[0] == 10
Expand All @@ -490,6 +507,7 @@ def top_k(
sorted: bool = False,
deterministic: bool = False,
tie_break: int = TopKTieBreak.NONE,
dsa_graph_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Radix-based Top-K selection.

Expand Down Expand Up @@ -525,6 +543,9 @@ def top_k(
- ``2``: prefer larger indices

Default is ``0``.
dsa_graph_safe : bool, optional
If True, force the FilteredTopK path and use graph-safe vectorization (VEC_SIZE=1).
Setting this also disables the clusters top-k path. Default is False.

Returns
-------
Expand Down Expand Up @@ -578,7 +599,7 @@ def top_k(
if tie_break != TopKTieBreak.NONE:
deterministic = True

if can_use_clusters_topk(input.device, deterministic):
if can_use_clusters_topk(input.device, deterministic, dsa_graph_safe):
indices, output_values = topk_clusters_exact(
input, k, output_values=True, out_dtype=torch.int64
)
Expand Down Expand Up @@ -613,6 +634,7 @@ def top_k(
tie_break,
row_states_buffer,
output_values,
dsa_graph_safe,
)

# Convert to int64 for compatibility
Expand Down Expand Up @@ -642,6 +664,8 @@ def top_k_page_table_transform(
row_to_batch: Optional[torch.Tensor] = None,
deterministic: bool = False,
tie_break: int = TopKTieBreak.NONE,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Fused Top-K selection + Page Table Transform for sparse attention.

Expand Down Expand Up @@ -682,6 +706,13 @@ def top_k_page_table_transform(
- ``2``: prefer larger indices

Default is ``0``.
dsa_graph_safe : bool, optional
If True, force the FilteredTopK path and use graph-safe vectorization (VEC_SIZE=1).
Setting this also disables the clusters top-k path. Default is False.
row_starts : Optional[torch.Tensor], optional
Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``.
Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``.
Default is None (equivalent to all zeros).


Returns
Expand All @@ -694,7 +725,9 @@ def top_k_page_table_transform(
Note
----
- This is specifically designed for sparse attention's second stage.
- If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]]
- If lengths[i] <= k, the output simply contains
``src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]]``
(``row_starts[i]`` is treated as 0 when ``row_starts`` is None),
with remaining positions set to -1.

Examples
Expand All @@ -718,7 +751,11 @@ def top_k_page_table_transform(
if tie_break != TopKTieBreak.NONE:
deterministic = True

if can_use_clusters_topk(input.device, deterministic) and row_to_batch is None:
if (
can_use_clusters_topk(input.device, deterministic, dsa_graph_safe)
and row_to_batch is None
and row_starts is None
):
return topk_clusters_page_table_transform(input, lengths, src_page_table, k)

# Allocate row_states buffer for multi-CTA path
Expand All @@ -742,6 +779,8 @@ def top_k_page_table_transform(
k,
deterministic,
tie_break,
dsa_graph_safe,
row_starts=row_starts,
)

return output_page_table
Expand All @@ -755,6 +794,8 @@ def top_k_ragged_transform(
k: int,
deterministic: bool = False,
tie_break: int = TopKTieBreak.NONE,
dsa_graph_safe: bool = False,
row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Fused Top-K selection + Ragged Index Transform for sparse attention.

Expand Down Expand Up @@ -788,6 +829,14 @@ def top_k_ragged_transform(
- ``2``: prefer larger indices

Default is ``0``.
dsa_graph_safe : bool, optional
If True, force the FilteredTopK path and use graph-safe vectorization (VEC_SIZE=1).
Setting this also disables the clusters top-k path. Default is False.
row_starts : Optional[torch.Tensor], optional
Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``.
Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``.
Output indices remain ``local_topk + offsets[i]`` where ``local_topk`` is relative to
``row_starts[i]``. Default is None (equivalent to all zeros).


Returns
Expand Down Expand Up @@ -825,7 +874,10 @@ def top_k_ragged_transform(
if tie_break != TopKTieBreak.NONE:
deterministic = True

if can_use_clusters_topk(input.device, deterministic):
if (
can_use_clusters_topk(input.device, deterministic, dsa_graph_safe)
and row_starts is None
):
return topk_clusters_ragged_transform(input, lengths, offsets, k)

# Allocate row_states buffer for multi-CTA path
Expand All @@ -848,6 +900,8 @@ def top_k_ragged_transform(
k,
deterministic,
tie_break,
dsa_graph_safe,
row_starts=row_starts,
)

return output_indices
Loading
Loading