feat: Add row_starts and dsa_graph_safe to topk #3133
kahyunnam merged 4 commits into flashinfer-ai:main
Conversation
📝 Walkthrough
Threads a new boolean flag (`dsa_graph_safe`) and an optional `row_starts` tensor through the Python API, the C++ FFI binding, the top-k dispatch, and the CUDA kernels.
Sequence Diagram(s)
sequenceDiagram
participant Py as Python API
participant FFI as C++ FFI Binding
participant Dispatch as TopK Dispatch
participant Kernel as CUDA Kernel
Py->>FFI: call radix_topk(..., dsa_graph_safe, maybe_row_starts)
FFI->>Dispatch: forward tensors, dsa_graph_safe, row_starts_ptr
Dispatch->>Dispatch: choose path (FilteredTopK vs Radix) using dsa_graph_safe, tie_break
Dispatch->>Kernel: launch kernel with row_starts, dsa_graph_safe
Kernel-->>Dispatch: return results/status
Dispatch-->>FFI: propagate results
FFI-->>Py: return tensors
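For orientation, a minimal call sketch at the Python level (a sketch, not code from the PR): the function name and the `row_starts` / `dsa_graph_safe` keywords match the final signatures shown later in this thread, while the shapes, dtypes, and ragged layout are illustrative assumptions chosen only to show the call shape.

```python
import torch
import flashinfer

# Two ragged rows laid out back-to-back in one score buffer (layout assumed for illustration).
scores = torch.randn(16, device="cuda")
offsets = torch.tensor([0, 4], dtype=torch.int32, device="cuda")     # per-row start offset into `scores`
lengths = torch.tensor([4, 4], dtype=torch.int32, device="cuda")     # valid window length per row
row_starts = torch.tensor([0, 2], dtype=torch.int32, device="cuda")  # new: per-row shift of the score window

# dsa_graph_safe=True requests the CUDA-graph-safe path (which forces VEC_SIZE=1 per the review notes).
indices = flashinfer.top_k_ragged_transform(
    scores, offsets, lengths, 3,
    row_starts=row_starts,
    dsa_graph_safe=True,
)
```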
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Code Review
This pull request introduces deterministic tie-breaking support for top-k operations, enabling users to specify whether to prefer smaller or larger indices for equal values at the selection boundary. The changes include the addition of a TopKTieBreak enum, updates to the CUDA kernels and Python API, and the implementation of a DeterministicContiguousCollect helper for contiguous index traversal. Benchmarking and testing utilities have also been expanded to cover these new modes. Review feedback highlights opportunities to improve performance by ensuring coalesced memory reads in the collection helper and suggests reusing shared memory buffers to stay within hardware limits.
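For reference, a plausible Python-side shape of that enum, inferred from `TopKTieBreak.NONE` in the signature diffs below, the C++ `TopKTieBreak::None/Small/Large` specializations, and the tests passing `tie_break=1/2`; the member names SMALL/LARGE and their integer values are assumptions, not quoted from the PR.

```python
from enum import IntEnum


class TopKTieBreak(IntEnum):
    """Tie-break preference for equal values at the top-k boundary (sketch; values assumed)."""

    NONE = 0   # no deterministic preference (existing behavior)
    SMALL = 1  # prefer smaller indices at the selection boundary
    LARGE = 2  # prefer larger indices at the selection boundary
```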
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
include/flashinfer/topk.cuh (1)
3218-3226: ⚠️ Potential issue | 🟠 Major
Let tie-break requests override the benchmark algorithm override.
With `FLASHINFER_TOPK_ALGO=multi_cta`, Line 3221 returns `false` before the tie-break check, so `tie_break=Small/Large` silently falls back to radix even though the comment says tie-break is only supported by FilteredTopK.

Proposed fix
```diff
-  // Check for algorithm override
-  const TopKAlgoOverride algo_override = GetTopKAlgoOverride();
-  if (algo_override == TopKAlgoOverride::FILTERED) return true;
-  if (algo_override == TopKAlgoOverride::MULTI_CTA) return false;
-
   // Tie-break modes are only supported by FilteredTopK
   if (tie_break != TopKTieBreak::None) {
     return true;
   }
+
+  // Check for algorithm override
+  const TopKAlgoOverride algo_override = GetTopKAlgoOverride();
+  if (algo_override == TopKAlgoOverride::FILTERED) return true;
+  if (algo_override == TopKAlgoOverride::MULTI_CTA) return false;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3218 - 3226, The current logic checks GetTopKAlgoOverride() before considering tie_break, which lets TopKAlgoOverride::MULTI_CTA override requested tie-breaks; change the branch order so tie-break requests take precedence: first check if tie_break != TopKTieBreak::None and return true (support FilteredTopK), then query GetTopKAlgoOverride() and handle TopKAlgoOverride::FILTERED / MULTI_CTA; update the function containing these checks (refer to GetTopKAlgoOverride, TopKAlgoOverride, and TopKTieBreak) so tie-break modes always force the FilteredTopK path.

benchmarks/bench_topk.py (1)
883-889: ⚠️ Potential issue | 🟡 Minor
The `"sglang_error"` key is never populated; this branch is dead and inconsistent with other sections.
Line 888 checks `"sglang_error" in result`, but `sglang_error` is not set anywhere in the codebase. The sglang block (lines 208–212) only writes `sglang_us`, and failures surface as `RuntimeError` exceptions caught at lines 891–899. This makes the `elif` at line 888 unreachable.
Additionally, the analogous fallback branches in page_table (line 1126) and ragged (line 1231) still use the original `k == 2048` check. This inconsistency suggests incomplete refactoring: either restore the `k == 2048` check in the top_k section or populate `result["sglang_error"]` and mirror the change in page_table and ragged sections.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 883 - 889, The branch checking "sglang_error" is dead because result["sglang_error"] is never set; fix by either (A) when catching the RuntimeError in the top_k benchmark code path (the block that currently writes result["sglang_us"]) set result["sglang_error"]=True (or an error message) so the existing display branch can detect failures, and update the analogous page_table and ragged sections to populate the same key for consistency; or (B) revert the refactor and restore the original k == 2048 fallback checks in the top_k, page_table and ragged reporting code so the fallback branches behave the same across all sections—choose one approach and apply it consistently to result handling for sglang.
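If option (A) is chosen, the change is small. A sketch under the assumption that the sglang timing block looks roughly like the description above; `result`, `sglang_us`, and the `RuntimeError` handling come from the review text, while `run_sglang_topk` and the wrapper function are hypothetical names, not code from the benchmark.

```python
def bench_sglang_section(run_sglang_topk) -> dict:
    """Sketch of option (A): populate the key that the display branch checks."""
    result: dict[str, float | str] = {}
    try:
        result["sglang_us"] = run_sglang_topk()  # existing success path: timing only
    except RuntimeError as exc:
        # New: record the failure so `"sglang_error" in result` is no longer dead code.
        result["sglang_error"] = str(exc)
    return result
```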
🧹 Nitpick comments (4)
include/flashinfer/topk.cuh (1)
234-236: Document the hot-path tradeoffs.
`ITEMS_PER_THREAD = 4` and forcing `vec_size = 1` for `dsa_graph_safe` are special performance-sensitive choices. Please add a short rationale and note the alternative considered, especially because Line 234 already leaves this as a TODO.
As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."
Also applies to: 2843-2846
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 234 - 236, Add a concise comment in the hot path explaining why ITEMS_PER_THREAD is set to 4 and CHUNK_ITEMS derived from it (e.g., memory/register pressure vs. occupancy tradeoff, cache/vectorization limits) and document the decision to force vec_size = 1 for dsa_graph_safe (e.g., alignment/unaligned memory access, divergent control flow, or correctness constraints) along with the primary alternative(s) considered (e.g., ITEMS_PER_THREAD=8 or using vectorized loads) and why they were rejected (impact on shared memory, register usage, or branch divergence). Place this justification adjacent to the ITEMS_PER_THREAD/CHUNK_ITEMS definitions and mirror a similar explanatory note where vec_size is set for dsa_graph_safe so future maintainers can understand the performance tradeoffs and tuning rationale.

tests/utils/test_topk.py (1)
1931-2050: Add coverage for `dsa_graph_safe=True`.
These new tests cover tie-break behavior, but the PR's graph-safe flag can regress independently through routing and `VEC_SIZE=1` dispatch. Please add at least one `top_k` and one transform API case with `dsa_graph_safe=True`; ideally include a CUDA graph capture/replay smoke test.
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1931 - 2050, Add tests that exercise the dsa_graph_safe=True path: in test_top_k_tie_break_modes add a case that calls flashinfer.top_k(logits, k, tie_break=1/2, dsa_graph_safe=True) (use the same seed/generator and skip logic with can_implement_filtered_topk() and set_topk_algo), and in test_top_k_tie_break_modes_transform_apis add calls to flashinfer.top_k_page_table_transform(..., tie_break=1/2, dsa_graph_safe=True) and flashinfer.top_k_ragged_transform(..., tie_break=1/2, dsa_graph_safe=True) validating expected indices/values as done for the non-graph-safe variants; optionally wrap one of these calls in a simple CUDA graph capture/replay smoke test to ensure graph capture works.
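For concreteness, one possible shape of such a smoke test (a sketch, not the project's test code): `flashinfer.top_k(logits, k, dsa_graph_safe=True)` follows the call form named in the prompt above, while the test name, shapes, the single-tensor return assumption, and the omission of a dedicated warm-up stream are simplifications.

```python
import pytest
import torch
import flashinfer


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_top_k_dsa_graph_safe_graph_capture_smoke():
    torch.manual_seed(0)
    logits = torch.randn(4, 1024, device="cuda")
    k = 32

    # Warm up once so lazy initialization happens outside capture
    # (a real test may want the side-stream warm-up pattern from the CUDA graphs docs).
    out = flashinfer.top_k(logits, k, dsa_graph_safe=True)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = flashinfer.top_k(logits, k, dsa_graph_safe=True)

    # The graph-safe path should capture and replay without errors.
    graph.replay()
    torch.cuda.synchronize()
    assert out.shape[-1] == k  # assumes a single tensor whose last dim is the top-k selection
```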
flashinfer/topk.py (1)
499-540: Optional: annotate `tie_break` with the enum type.
Since `TopKTieBreak` is now a first-class public enum and the default is a `TopKTieBreak` member, consider typing the parameter as `TopKTieBreak` (or `Union[TopKTieBreak, int]`) across all three public APIs (`top_k`, `top_k_page_table_transform`, `top_k_ragged_transform`). `IntEnum` values still satisfy the FFI int conversion, so runtime behavior is unchanged, but callers get enum-level type checking and IDE completion instead of a bare `int`.
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 499 - 540, Update the tie_break parameter annotations to use the TopKTieBreak enum (or Union[TopKTieBreak, int]) in the public APIs so callers get enum-level typing and IDE completion: change the type on function signatures for top_k, top_k_page_table_transform, and top_k_ragged_transform to TopKTieBreak (or Union[TopKTieBreak, int]) while leaving default values and runtime behavior unchanged; ensure imports/typing references for TopKTieBreak are added where needed and run typechecks to confirm no FFI/int conversion assumptions are broken.
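One way the annotated signature could look, sketched against the ragged-transform parameter order shown later in this thread; the `TopKTieBreak` import path is an assumption, and only the `tie_break` annotation changes while defaults and runtime behavior stay the same.

```python
from typing import Optional, Union

import torch

from flashinfer import TopKTieBreak  # import path assumed; the enum is described as public in this PR


def top_k_ragged_transform(
    input: torch.Tensor,
    offsets: torch.Tensor,
    lengths: torch.Tensor,
    k: int,
    deterministic: bool = False,
    tie_break: Union[TopKTieBreak, int] = TopKTieBreak.NONE,  # was: tie_break: int
    row_starts: Optional[torch.Tensor] = None,
    dsa_graph_safe: bool = False,
) -> torch.Tensor:
    ...
```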
benchmarks/bench_topk.py (1)
89-103: Nit: bind `tie_break` via a default argument to silence B023 and harden against future refactors.
Ruff flags B023 on line 95. Today this is a false positive: `bench_median_ms` consumes the lambda synchronously before the loop advances, so the late-binding hazard does not actually trigger. It's still cheap to make the capture explicit in case the lambda is ever deferred (e.g., scheduled, stored, or passed to an async benchmarker):

Proposed defensive fix
```diff
-    for suffix, tie_break in TIE_BREAK_VARIANTS:
-        try:
-            tie_ms = bench_median_ms(lambda: run_flashinfer_with_tie_break(tie_break))
+    for suffix, tie_break in TIE_BREAK_VARIANTS:
+        try:
+            tie_ms = bench_median_ms(
+                lambda tb=tie_break: run_flashinfer_with_tie_break(tb)
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 89 - 103, The loop in bench_tie_break_variants closes over tie_break causing a potential late-binding issue flagged by Ruff B023; change the lambda passed to bench_median_ms to capture tie_break as a default argument (e.g., lambda tb=tie_break: run_flashinfer_with_tie_break(tb)) so the current tie_break value is bound immediately; update the invocation around bench_median_ms(...) and leave the rest of the logic (metrics keys using suffix, error handling/classify_benchmark_runtime_error) unchanged.
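As a standalone illustration (unrelated to the project code) of the late-binding behavior B023 guards against: a plain closure sees the loop variable's final value once iteration has finished, while the default-argument form snapshots it per iteration.

```python
# Plain lambdas capture the variable itself, not its value at definition time.
deferred = [lambda: tb for tb in (1, 2, 3)]
print([f() for f in deferred])  # [3, 3, 3]: every closure sees the final value of tb

# Binding via a default argument evaluates tb immediately, per iteration.
bound = [lambda tb=tb: tb for tb in (1, 2, 3)]
print([f() for f in bound])  # [1, 2, 3]
```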
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 88e1832b-cadc-43e2-b4c6-4c84155aaf21
📒 Files selected for processing (7)
benchmarks/bench_topk.py
csrc/flashinfer_topk_binding.cu
csrc/topk.cu
flashinfer/__init__.py
flashinfer/topk.py
include/flashinfer/topk.cuh
tests/utils/test_topk.py
6bbd1da to 30d7210
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/topk.cuh (1)
3070-3104: ⚠️ Potential issue | 🟡 Minor
Make `tie_break` imply deterministic mode inside the filtered launcher.
`LaunchFilteredTopKUnified` exposes `tie_break`, but direct callers passing `tie_break != None` with `deterministic=false` still launch the non-deterministic `TopKTieBreak::None` specialization. The higher-level dispatchers normalize this today, but this wrapper should enforce its own API contract.

Suggested fix
```diff
 cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_output,
                                       const IdType* aux_input, int64_t aux_stride,
                                       const IdType* row_to_batch, const IdType* lengths,
                                       uint32_t num_rows, uint32_t top_k_val, uint32_t max_len,
                                       bool deterministic = false,
                                       TopKTieBreak tie_break = TopKTieBreak::None,
                                       cudaStream_t stream = 0, bool dsa_graph_safe = false) {
   constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
   constexpr int MAX_VEC = 16 / sizeof(DType);
+  const bool effective_deterministic = deterministic || tie_break != TopKTieBreak::None;
@@
 #define DISPATCH_VEC_SIZE(VS)                                  \
   if (vec_size == VS) {                                        \
-    if (!deterministic) {                                      \
+    if (!effective_deterministic) {                            \
       LAUNCH_FILTERED_KERNEL(VS, false, TopKTieBreak::None);   \
     } else {                                                   \
       if (tie_break == TopKTieBreak::Small) {                  \
         LAUNCH_FILTERED_KERNEL(VS, true, TopKTieBreak::Small); \
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3070 - 3104, The launcher currently ignores a non-None tie_break when deterministic==false; change the dispatch to compute an effective deterministic flag (e.g., bool effective_det = deterministic || (tie_break != TopKTieBreak::None)) and use effective_det in DISPATCH_VEC_SIZE/launch logic so that any tie_break != TopKTieBreak::None forces the deterministic specialization via LAUNCH_FILTERED_KERNEL(..., true, tie_break) while preserving the existing non-deterministic path only when effective_det is false; update references to deterministic in the DISPATCH_VEC_SIZE block to use this effective_det and select the correct TopKTieBreak template parameter accordingly.
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)
234-236: Document the fixed chunking choice or remove the TODO.
`ITEMS_PER_THREAD = 4` is now part of a performance-sensitive deterministic tie-break path. Please either justify why `4` is the intended trade-off here or link this TODO to a tracked tuning task so the algorithmic choice is explicit. As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 234 - 236, Replace the TODO by either (a) adding a short justification comment next to ITEMS_PER_THREAD = 4 explaining why 4 was chosen (trade-offs tested, microbenchmarks summary, sensitivity in the deterministic tie-break hot path, interaction with BLOCK_THREADS and CHUNK_ITEMS, and why vectorization wasn't chosen), or (b) if the number is provisional, remove the TODO and add a one-line reference to a tracked tuning task/issue ID that contains the benchmarking results and alternative values tested; ensure the comment mentions the symbols ITEMS_PER_THREAD, CHUNK_ITEMS and BLOCK_THREADS and that this choice affects the deterministic tie-break/performance-critical path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 89-103: In bench_tie_break_variants, the lambda passed to
bench_median_ms captures the loop variable tie_break by reference and the
metrics dict is annotated too narrowly as dict[str, float]; fix by binding the
loop variable in the lambda (e.g., make it a default arg so you call
run_flashinfer_with_tie_break(tie_break=tie_break) inside the lambda) and widen
the return type to allow string error labels (e.g., change the annotation from
dict[str, float] to dict[str, float | str] or dict[str, Any]); keep references
to TIE_BREAK_VARIANTS, run_flashinfer_with_tie_break,
classify_benchmark_runtime_error, and metrics when making the edits.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 51145eb8-8c31-4857-a35e-958b56de4bbb
📒 Files selected for processing (5)
benchmarks/bench_topk.py
csrc/flashinfer_topk_binding.cu
csrc/topk.cu
flashinfer/topk.py
include/flashinfer/topk.cuh
5432f6d to e5f4eb0
e5f4eb0 to 63cfd7f
911aca7 to 20061c2
Title changed: "dsa_graph_safe flag to topk" → "row_starts and dsa_graph_safe to topk"
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/topk.py (2)
789-799: ⚠️ Potential issue | 🟠 Major
Preserve positional compatibility for `deterministic`.
`row_starts` is inserted before the existing `deterministic` parameter, so existing calls like `top_k_ragged_transform(scores, offsets, lengths, k, True)` now pass `True` as `row_starts`.

Proposed fix
```diff
 def top_k_ragged_transform(
     input: torch.Tensor,
     offsets: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
-    row_starts: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     tie_break: int = TopKTieBreak.NONE,
+    row_starts: Optional[torch.Tensor] = None,
     dsa_graph_safe: bool = False,
 ) -> torch.Tensor:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 789 - 799, The signature change to top_k_ragged_transform moved row_starts before deterministic, breaking positional callers; restore positional compatibility by ensuring deterministic remains the positional parameter before row_starts (i.e., place deterministic as the parameter immediately after k and make row_starts either follow deterministic or be keyword-only), update the function signature accordingly and adjust any internal usage of row_starts/deterministic inside top_k_ragged_transform to match the restored parameter order.
658-669: ⚠️ Potential issue | 🟠 Major
Preserve positional compatibility for `row_to_batch`.
`row_starts` is inserted before the existing `row_to_batch` parameter, so existing calls like `top_k_page_table_transform(scores, table, lengths, k, row_to_batch)` now bind that tensor as `row_starts` and silently compute the wrong mapping.

Proposed fix
```diff
 def top_k_page_table_transform(
     input: torch.Tensor,
     src_page_table: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
-    row_starts: Optional[torch.Tensor] = None,
     row_to_batch: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     tie_break: int = TopKTieBreak.NONE,
+    row_starts: Optional[torch.Tensor] = None,
     dsa_graph_safe: bool = False,
 ) -> torch.Tensor:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 658 - 669, The function signature change in top_k_page_table_transform broke positional compatibility by inserting row_starts before the existing row_to_batch parameter; restore compatibility by reordering the parameters so row_to_batch appears before row_starts (i.e., keep the original positional order: ..., k, row_to_batch: Optional[torch.Tensor]=None, row_starts: Optional[torch.Tensor]=None, deterministic=..., tie_break=..., dsa_graph_safe=...), update any internal references to use the renamed parameters accordingly, and run tests that call top_k_page_table_transform(positionally) to confirm behavior is unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: be26a571-3378-41eb-9773-499092e2e2f0
📒 Files selected for processing (5)
csrc/flashinfer_topk_binding.cu
csrc/topk.cu
flashinfer/topk.py
include/flashinfer/topk.cuh
tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/topk.cuh
/bot run
jiahanc left a comment
LGTM, thanks for contribution!
kahyunnam left a comment
LGTM overall except for one comment! Will approve after this is updated + /bot run CICD passes
    src_page_table: torch.Tensor,
    lengths: torch.Tensor,
    k: int,
    row_starts: Optional[torch.Tensor] = None,
Can we move these new optional arguments to the end? This breaks positional ordering for existing callers and may break backwards compatibility for the API definition.
Hi @kahyunnam, I have made the requested changes. All Python APIs and top-level C++ bindings now have both args at the end. Internal implementations still use the previous ordering for better readability. Thank you!
🧹 Nitpick comments (4)
flashinfer/topk.py (3)
849-854: Minor doc gap: clarify the `row_starts` interaction with the trivial ragged path.
For `top_k_ragged_transform`, the "If lengths[i] <= k" note still reads as if `row_starts` has no trivial-case role, but callers may reasonably assume symmetry with `top_k_page_table_transform` (which now documents the row-shifted slice). A short clarification avoids ambiguity, e.g.:

📝 Suggested wording
```diff
-    - If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]
-      with remaining positions set to -1.
+    - If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]
+      with remaining positions set to -1. ``row_starts`` only shifts the score window used for
+      top-k selection; it does not shift these local indices in the trivial case.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 849 - 854, The docstring for top_k_ragged_transform is ambiguous about how row_starts shifts indices in the trivial ragged path (lengths[i] <= k); update the Note to explicitly state that when lengths[i] <= k the returned indices are the sequence [row_starts[i]+offsets[i], row_starts[i]+offsets[i]+1, ..., row_starts[i]+offsets[i]+lengths[i]-1] with remaining positions set to -1, mirroring the documented behavior/symmetry of top_k_page_table_transform; reference top_k_ragged_transform, row_starts, offsets, lengths, and top_k_page_table_transform in the docstring so callers aren’t confused about whether indices are row-shifted.
495-500: Default `dsa_graph_safe` to preserve backward compatibility.
`can_use_clusters_topk` is a module-level (non-underscore) helper. Adding a required third positional parameter is technically a breaking change for any external caller. Given the PR's explicit "Keep API backward compatibility" intent (and the prior review feedback about positional ordering), consider defaulting it:

♻️ Suggested default
```diff
-def can_use_clusters_topk(device, deterministic, dsa_graph_safe):
+def can_use_clusters_topk(device, deterministic, dsa_graph_safe=False):
     if dsa_graph_safe:
         return False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 495 - 500, can_use_clusters_topk currently requires a third positional parameter dsa_graph_safe which is a breaking API change; make dsa_graph_safe optional with a default value (e.g., False) so existing callers keep current behavior, update the function signature for can_use_clusters_topk to set dsa_graph_safe=False and ensure the function body still uses the parameter as before, and scan for external uses of can_use_clusters_topk to confirm none rely on a mandatory third argument.
728-731: Nit: tighten the trivial-case wording for readability.
The inline parenthetical splits an RST code reference across lines and reads awkwardly. A small rewrite keeps the code literal intact:
📝 Suggested wording
```diff
-    - If lengths[i] <= k, the output simply contains
-      ``src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]]`` (or start 0 when
-      ``row_starts`` is None)
-      with remaining positions set to -1.
+    - If lengths[i] <= k, the output simply contains
+      ``src_page_table[batch_idx, s:s + lengths[i]]`` where ``s = row_starts[i]`` (or 0 when
+      ``row_starts`` is None), with remaining positions set to -1.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 728 - 731, The docstring sentence describing the trivial case splits the RST code reference across lines and reads awkwardly; update the sentence in topk.py that starts "If lengths[i] <= k" to keep the code literal intact by making it a single clear clause referencing the symbols lengths, k, src_page_table, row_starts and batch_idx, e.g. state that when lengths[i] <= k the output contains the entries of src_page_table for batch_idx from row_starts[i] to row_starts[i] + lengths[i], with row_starts treated as starting at 0 when row_starts is None, and any remaining positions set to -1.

tests/utils/test_topk.py (1)
459-476: Consider adding trivial-length coverage for `row_starts`.
In the ragged reference, `row_start` is read but intentionally unused in the trivial branch (length <= k), matching the documented semantics (output is `local_topk + offsets[i]`). The new `test_top_k_transform_with_row_starts` forces `lengths >= k+1`, so the kernel's trivial-length behavior under non-zero `row_starts` is not validated against this reference. A small additional case (e.g., one row with `lengths[i] <= k` and `row_starts[i] > 0`) would close that gap for both transforms.
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 459 - 476, Add a trivial-length test case to exercise the branch where length <= k while row_starts is non-zero: in tests/utils/test_topk.py (the test_top_k_transform_with_row_starts setup), append or insert one row whose lengths[i] <= k and row_starts[i] > 0 (ensure offsets[i] is set) and verify output[i, :length] equals torch.arange(offset, offset+length) so the reference path that ignores row_start is validated; update any generated scores/slices accordingly so that this single-row case triggers the trivial branch alongside the existing longer rows.
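A sketch of the extra case this asks for, assuming the documented trivial-case semantics above (output is `[offsets[i], ..., offsets[i]+lengths[i]-1]` padded with -1 and unshifted by `row_starts`); the call form matches the ragged signature diff earlier in this thread, while shapes, dtypes, and the `(num_rows, k)` output layout are assumptions.

```python
import torch
import flashinfer

# One trivially short row (length <= k) with a non-zero row_start.
k = 8
scores = torch.randn(32, device="cuda")
offsets = torch.tensor([0], dtype=torch.int32, device="cuda")
lengths = torch.tensor([5], dtype=torch.int32, device="cuda")     # 5 <= k triggers the trivial branch
row_starts = torch.tensor([3], dtype=torch.int32, device="cuda")  # non-zero shift; should not move the local indices

out = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k, row_starts=row_starts)

# Documented trivial-case output: [offsets[0], ..., offsets[0] + lengths[0] - 1], then -1 padding.
expected = torch.full((1, k), -1, dtype=out.dtype, device="cuda")
expected[0, :5] = torch.arange(0, 5, device="cuda")
assert torch.equal(out, expected)
```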
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 849-854: The docstring for top_k_ragged_transform is ambiguous
about how row_starts shifts indices in the trivial ragged path (lengths[i] <=
k); update the Note to explicitly state that when lengths[i] <= k the returned
indices are the sequence [row_starts[i]+offsets[i], row_starts[i]+offsets[i]+1,
..., row_starts[i]+offsets[i]+lengths[i]-1] with remaining positions set to -1,
mirroring the documented behavior/symmetry of top_k_page_table_transform;
reference top_k_ragged_transform, row_starts, offsets, lengths, and
top_k_page_table_transform in the docstring so callers aren’t confused about
whether indices are row-shifted.
- Around line 495-500: can_use_clusters_topk currently requires a third
positional parameter dsa_graph_safe which is a breaking API change; make
dsa_graph_safe optional with a default value (e.g., False) so existing callers
keep current behavior, update the function signature for can_use_clusters_topk
to set dsa_graph_safe=False and ensure the function body still uses the
parameter as before, and scan for external uses of can_use_clusters_topk to
confirm none rely on a mandatory third argument.
- Around line 728-731: The docstring sentence describing the trivial case splits
the RST code reference across lines and reads awkwardly; update the sentence in
topk.py that starts "If lengths[i] <= k" to keep the code literal intact by
making it a single clear clause referencing the symbols lengths, k,
src_page_table, row_starts and batch_idx — e.g. state that when lengths[i] <= k
the output contains the entries of src_page_table for batch_idx from
row_starts[i] to row_starts[i] + lengths[i], with row_starts treated as starting
at 0 when row_starts is None, and any remaining positions set to -1.
In `@tests/utils/test_topk.py`:
- Around line 459-476: Add a trivial-length test case to exercise the branch
where length <= k while row_starts is non-zero: in tests/utils/test_topk.py (the
test_top_k_transform_with_row_starts setup), append or insert one row whose
lengths[i] <= k and row_starts[i] > 0 (ensure offsets[i] is set) and verify
output[i, :length] equals torch.arange(offset, offset+length) so the reference
path that ignores row_start is validated; update any generated scores/slices
accordingly so that this single-row case triggers the trivial branch alongside
the existing longer rows.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f2e2f667-ff32-401b-a377-dac5ae258868
📒 Files selected for processing (4)
csrc/flashinfer_topk_binding.cu
csrc/topk.cu
flashinfer/topk.py
tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/topk.cu
📌 Description
@HumansAnd
Parent PR: #3095
SGLang PR: sgl-project/sglang#22851
Add `row_starts` and `dsa_graph_safe` for SGLang DSA integration.

🔍 Related Issues
sgl-project/sglang#22851 (comment)
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Behavior
Tests