feat: implement deterministic topk #2661
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. Use the following commands to manage reviews:
📝 Walkthrough

Adds an opt-in deterministic mode across the top-k stack: Python APIs, FFI bindings, C++ dispatch, and CUDA kernels; implements deterministic multi-CTA collection and stable tie-breaking, updates benchmarks/CLI for deterministic comparisons and DSA workloads, and adds deterministic-focused tests and helpers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python API
    participant Bind as FFI Binding
    participant Dispatch as C++ Dispatcher
    participant Kernel as CUDA Kernel
    Py->>Bind: call top_k(..., deterministic=True)
    Bind->>Dispatch: radix_topk(..., sorted_output=..., deterministic=...)
    Dispatch->>Kernel: launch deterministic-aware kernel (det scratch, DETERMINISTIC)
    Kernel->>Kernel: deterministic collect / stable tie-breaking / optional stable sort
    Kernel-->>Dispatch: return indices & values
    Dispatch-->>Bind: propagate results
    Bind-->>Py: deliver deterministic outputs
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant feature by enabling deterministic behavior for all top-k related operations within FlashInfer. It provides users with fine-grained control over determinism, which is crucial for reproducibility in machine learning and scientific computing. The changes span Python APIs, CUDA kernels, and benchmarking tools, ensuring robust implementation and verification of this new capability.
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
Code Review
This pull request introduces a significant feature: deterministic top-k selection. The changes are extensive, adding new execution paths to both the radix and filtered top-k algorithms to ensure reproducible results, which is particularly important for handling ties. The implementation includes backward compatibility for existing APIs by adding new optional parameters. The benchmarks and tests have been updated comprehensively to cover the new deterministic modes. The overall implementation is well-designed and robust. I have one suggestion to improve code clarity and remove a minor redundancy in the CUDA kernel.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 51-56: The benchmark currently enters the
torch_deterministic_algorithms context inside run_torch_topk on every iteration,
adding overhead; instead enable deterministic mode once before the timing loop
and restore the prior state afterwards, removing the per-iteration context from
run_torch_topk (and the analogous per-iteration context in the other benchmark
at lines 136-145); specifically, call the global deterministic enable API once
(save the previous value), run the repeated torch.topk calls normally inside
run_torch_topk, then restore the saved deterministic setting after the loop so
the timing measures only torch.topk cost.
In `@tests/utils/test_topk.py`:
- Around line 1492-1514: The BF16 reproducibility test
(test_top_k_reproducible_mode_repeatability_multi_cta) runs unconditionally but
must be skipped on GPUs with compute capability < SM80; add a guard at the start
of the test that calls flashinfer.utils.get_compute_capability() (or the project
helper like flashinfer.utils.is_sm90a_supported/is_sm80_supported) and use
pytest.skip(...) when the capability is below 80 to avoid running BF16 on
unsupported hardware; also add an import for pytest if it's not present.
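The guard this comment asks for can be sketched as a pure predicate, assuming the project helper returns a `(major, minor)` compute-capability tuple as described:

```python
def should_skip_bf16(capability):
    """BF16 top-k kernels require SM80 (Ampere) or newer; skip otherwise."""
    major, minor = capability
    return major * 10 + minor < 80
```

In the test body this would feed `pytest.skip("BF16 requires SM80+")` at the top of the function when the predicate is true.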
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- benchmarks/bench_topk.py
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
- flashinfer/__init__.py
- flashinfer/topk.py
- include/flashinfer/topk.cuh
- tests/utils/test_topk.py
Force-pushed: 4358ff1 → 9e88bc8 → 42a86f9 → 7679c40
Actionable comments posted: 2
♻️ Duplicate comments (3)
tests/utils/test_topk.py (1)
1487-1499: ⚠️ Potential issue | 🟡 Minor: Rename `sorted`; Ruff still flags this helper.

Line 1488 shadows the Python builtin, so this helper keeps tripping A002. `sorted_output` avoids the lint with no behavior change.

🧹 Minimal rename

```diff
 def _assert_top_k_matches_torch(
-    logits: torch.Tensor, k: int, *, deterministic: bool = False, sorted: bool = True
+    logits: torch.Tensor,
+    k: int,
+    *,
+    deterministic: bool = False,
+    sorted_output: bool = True,
 ):
     """Assert FlashInfer top_k matches torch.topk for exact-order cases."""
     values, indices = flashinfer.top_k(
-        logits, k, deterministic=deterministic, sorted=sorted
+        logits, k, deterministic=deterministic, sorted=sorted_output
     )
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=sorted)
+    ref_values, ref_indices = torch.topk(
+        logits, k, dim=-1, sorted=sorted_output
+    )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1487-1499: rename the parameter `sorted` in the helper function `_assert_top_k_matches_torch` to avoid shadowing the built-in; change the parameter name to `sorted_output` and update all uses inside the function (the `flashinfer.top_k` call and `torch.topk` call) to pass `sorted=sorted_output`, leaving the behavior and the variable names `values`, `indices`, `ref_values`, `ref_indices` unchanged.

include/flashinfer/topk.cuh (2)
232-240: ⚠️ Potential issue | 🔴 Critical: Synchronize the CTA before publishing the radix-group arrival.

`AdvanceRadixGroupBarrier()` still lets line 235 advance `arrival_counter` before the rest of the block is forced to finish its preceding histogram/output writes. The current callers at lines 468, 648, and 851 hit it immediately after per-thread atomics/stores, so another CTA can observe partially updated state and break correctness/determinism again.

🔧 Minimal fix

```diff
 __device__ __forceinline__ void AdvanceRadixGroupBarrier(RadixRowState* state, int& barrier_phase,
                                                          uint32_t ctas_per_group, uint32_t tx) {
+  __syncthreads();
   if (tx == 0) {
     red_release(&state->arrival_counter, 1);
   }
   int target = (barrier_phase + 1) * ctas_per_group;
   wait_ge(&state->arrival_counter, target, tx);
```

Expected result: either the helper owns the CTA sync, or every releasing call site shows an immediate `__syncthreads()` before it.

```bash
#!/bin/bash
set -euo pipefail
sed -n '232,240p' include/flashinfer/topk.cuh
sed -n '452,470p' include/flashinfer/topk.cuh
sed -n '635,650p' include/flashinfer/topk.cuh
sed -n '835,855p' include/flashinfer/topk.cuh
sed -n '1256,1263p' include/flashinfer/topk.cuh
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 232 - 240, AdvanceRadixGroupBarrier currently releases the radix-group arrival (red_release(&state->arrival_counter, 1)) before the CTA is synchronized, allowing other CTAs to observe partially written per-thread state; fix it by owning the CTA sync inside AdvanceRadixGroupBarrier: add a __syncthreads() immediately before the tx==0 release path so the block finishes all histogram/output stores before calling red_release, leaving the existing wait_ge(&state->arrival_counter, target, tx), barrier_phase++, and trailing __syncthreads() intact.
3241-3258: ⚠️ Potential issue | 🟠 Major: Canonicalize radix ties before the stable value sort.

Line 3246 index-sorts only the filtered deterministic path. When line 3251 routes deterministic work through radix, `StableSortTopKByValue()` on line 3256 preserves the deterministic collection order from `RadixCollectIndicesDeterministic`, so `sorted=True, deterministic=True` still returns a different tie order depending on which algorithm was selected.

🔧 Suggested fix

```diff
 } else {
   FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
       input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
       row_states_buffer, deterministic, stream)));
+  if (deterministic && sorted_output) {
+    FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+        output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+        stream)));
+  }
 }
 if (sorted_output) {
   FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
       output_indices, output_values, num_rows, top_k_val, max_len, stream)));
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3241 - 3258, The deterministic canonicalization (index-sort via LaunchSortTopKByIndex) is only applied in the filtered path; ensure radix-based deterministic results are canonicalized the same way before the stable value sort. After calling RadixTopKMultiCTA in the else branch, if deterministic is true call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the same arguments used in the filtered branch (output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so that StableSortTopKByValue sees a canonical tie order regardless of which algorithm ran; keep the existing filtered-path LaunchSortTopKByIndex and the final StableSortTopKByValue intact.
🧹 Nitpick comments (2)
benchmarks/bench_topk.py (1)
209-223: Consider using `-float('inf')` consistently for the `neg_inf` fallback.

For fp16/bf16, using `torch.finfo(dtype).min` instead of `-inf` means values at the minimum representable float could still be selected over the masked positions. If the intent is to fully exclude masked positions from top-k selection, `-inf` (which is representable in fp16/bf16) would be more robust.

🔧 Suggested fix

```diff
- neg_inf = -torch.inf if dtype == torch.float32 else torch.finfo(dtype).min
+ neg_inf = float('-inf')
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 209-223: the masking uses `torch.finfo(dtype).min` for `neg_inf` when dtype is fp16/bf16, which can still be chosen; change the `neg_inf` computation in the causal_chunk block (where `start_pos`, `lengths`, `q_len`, `dtype` are used) to use a true negative infinity constant (e.g. `-float('inf')`) for the `masked_fill` value so masked positions are fully excluded when you call `scores = scores.masked_fill(invalid, neg_inf)`.

flashinfer/topk.py (1)
176-182: Docstring could clarify the tie-breaking strategy for deterministic mode.

The PR objectives and issue #2584 mention that deterministic mode uses "lower element index wins" for tie-breaking. Consider adding this detail to the docstring so users understand the expected behavior when values are equal.

📝 Suggested docstring enhancement

```diff
 deterministic : bool, optional
     If True, uses deterministic mode. Default is False (non-deterministic,
     which is faster).
     Deterministic mode guarantees repeatable FlashInfer output ordering for
-    the selected top-k set on a fixed input and system.
+    the selected top-k set on a fixed input and system. When values are equal,
+    elements with lower indices are selected first (stable tie-breaking).
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 176 - 182, Update the docstring for the top_k function to explicitly state the tie-breaking rule used when deterministic=True: when values are equal the element with the lower index is chosen ("lower element index wins"). Mention this behavior near the deterministic parameter description in the top_k docstring so callers know how ties are resolved and that ordering may differ from non-deterministic behavior.
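For illustration, the documented contract ("lower element index wins" on ties) can be written as a pure-Python reference model. This is a sketch of the expected ordering only, not of the CUDA kernels:

```python
def top_k_deterministic_ref(values, k):
    """Return (values, indices): value-descending, index-ascending on ties."""
    order = sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]
    return [values[i] for i in order], order
```

With input `[1.0, 3.0, 3.0, 2.0]` and `k=2`, both tied `3.0` entries are candidates and the lower index (1) is emitted before index 2.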
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 213-218: The cached row-state workspace currently doesn't include
space for the deterministic scratch tail used by
MaybeGetRadixDeterministicCollectScratchBuffer, so when deterministic &&
!single_cta the pointer (row_states_buffer + num_groups) can walk past the
allocated buffer; update all allocation sites that create
radix_topk_row_states_* (in flashinfer/topk.py and any C++/CUDA allocs) to
reserve room for both RadixRowState[num_groups] and
RadixDeterministicCollectScratch[num_groups] (i.e. allocate num_groups of
RadixRowState plus num_groups of RadixDeterministicCollectScratch, or
equivalently adjust byte-size to
num_groups*(sizeof(RadixRowState)+sizeof(RadixDeterministicCollectScratch))),
and ensure any cached size calculations and related comments reflect this change
so deterministic multi-CTA no longer overruns the buffer.
In `@tests/utils/test_topk.py`:
- Around line 1896-1937: The tests
test_top_k_deterministic_sorted_large_k_matches_torch_by_algo and
test_top_k_deterministic_trivial_k_equals_length_by_algo currently parametrize
over "filtered" but use k values (4096 and vocab_size) larger than
FILTERED_TOPK_MAX_K (defined in include/flashinfer/topk.cuh as 2048), so they
never exercise FilteredTopK; update the parametrization to only use ["auto",
"multi_cta"] for these two tests, or alternatively add a separate test case that
explicitly uses set_topk_algo("filtered") with k <= FILTERED_TOPK_MAX_K (e.g.,
k=2048) to validate the filtered path.
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
@coderabbitai review

/gemini review

Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

✅ Actions performed: Full review triggered.
♻️ Duplicate comments (1)
include/flashinfer/topk.cuh (1)
3259-3267: ⚠️ Potential issue | 🟠 Major: Radix deterministic path omits index canonicalization; tie order differs from filtered path.

The filtered path calls `SortTopKByIndex` when `deterministic=true` (lines 3255-3257), but the radix path (lines 3260-3262) does not. This means:

- Filtered + deterministic: ties ordered by ascending index
- Radix + deterministic: ties ordered by thread-strided emission

When `sorted_output=true`, `StableSortTopKByValue` preserves prior ordering for ties, so the final tie-breaking strategy depends on which algorithm was selected, breaking the contract that deterministic mode produces consistent, index-ordered ties.

🔧 Suggested fix

```diff
 } else {
   FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
       input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
       row_states_buffer, deterministic, stream)));
+  if (deterministic) {
+    FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+        output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+        stream)));
+  }
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3259 - 3267, The radix deterministic path omits the canonical index tie-break used by the filtered path; after RadixTopKMultiCTA completes and when deterministic==true, call SortTopKByIndex with the same arguments used by the filtered path (e.g., output_indices, output_values, num_rows, top_k_val, max_len, stream) before any StableSortTopKByValue call so ties are canonicalized by ascending index; ensure this conditional mirrors the filtered path's deterministic branch around SortTopKByIndex so both algorithms produce identical tie order.
🧹 Nitpick comments (2)
flashinfer/topk.py (1)
63-73: Ruff flags `input` shadowing a Python builtin.

The static analysis tool flags line 65 for shadowing Python's built-in `input`. However, this pattern is consistent with the existing codebase conventions for tensor parameter naming in this file. Given the "Chill" review mode and that this is a widespread pattern, this can be addressed in a separate cleanup if desired.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 63-73: the parameter name `input` in the `_fake_radix_topk` function (registered as "flashinfer::radix_topk") shadows Python's builtin and should be renamed to avoid the Ruff warning; use a non-builtin name (e.g., `tensor`, `inp`, or `src_tensor`) and replace all uses inside the function (`input.size` and `input.device`) accordingly so behavior is unchanged.

include/flashinfer/topk.cuh (1)
3174-3198: Consider documenting the heuristic rationale.

The deterministic-mode algorithm selection heuristics (lines 3174-3184) differ significantly from the non-deterministic heuristics (lines 3186-3197). Consider adding a brief comment explaining the trade-off (e.g., filtered deterministic overhead vs. radix multi-CTA coordination cost).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3174 - 3198, Add a short explanatory comment immediately above the block that branches on deterministic and DType size (referencing variables/conditions: deterministic, sizeof(DType), max_len, num_rows, and batch_threshold) that explains why deterministic-mode thresholds differ from non-deterministic ones — e.g., deterministic implementation favors simpler per-row filtered scans to avoid non-deterministic cross-CTA radix coordination (hence lower thresholds like 16384 and the special 256 divisor), while non-deterministic heuristics accept radix/multi-CTA strategies for larger max_len (notice thresholds 16384/32768 and the use of max_len/4096 or /16384 to compute batch_threshold); keep the comment concise (1–3 lines) describing the trade-off and pointing to the key constants so future readers understand the rationale.
@yzh119 @Linda-Stadter This PR is ready for review now.
```cpp
uint32_t cta_local_eq_count = 0;
OrderedType ordered_pivot =
    RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DETERMINISTIC, DType>(
        input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars,
```
This doesn't contain my overflow fix by casting to size_t. I will create another commit on top of this :)
Can you take a look at this and cherry-pick? Linda-Stadter@674161b

Looks like it could be submitted as a standalone bug-fix PR, as it seems orthogonal to the deterministic top-k implementation. I am not sure if the FlashInfer maintainers would accept merging two commits with different objectives into a single PR.

Yes, I agree, it is an additional bug fix. But due to time constraints and because it is only a small change, I wanted to merge it with this PR. Let me know if you are strictly against it.

I don't mind cherry-picking this small commit, up to @yzh119
also add cub stable radix sort and overflow handling

Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
I have now put the overflow fix in an extra PR.

/bot run

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>

/bot stop

The GitLab CI pipeline #47450958 has been cancelled.

/bot run

[CANCELING] Pipeline #47452629: canceled
```python
assert torch.equal(out, ref)

def test_top_k_uint32_pointer_overflow():
```
Should we add more parameter combinations to this test case using pytest decorators, such as:

- deterministic/non-deterministic
- plain/ragged/page-table

so that we can ensure the overflow issue is covered across all kinds of modes?
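One way to sketch the requested coverage; in the test file this would be stacked `pytest.mark.parametrize` decorators, and the mode names below are assumptions:

```python
import itertools

MODES = ["plain", "ragged", "page_table"]
DETERMINISTIC = [False, True]

def overflow_test_cases():
    """All (mode, deterministic) pairs the overflow test should exercise."""
    return list(itertools.product(MODES, DETERMINISTIC))
```

Each pair would become one parametrized invocation of `test_top_k_uint32_pointer_overflow`, so the overflow path is checked in every mode, with and without determinism.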
```diff
@@ -1154,7 +1154,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
 if (chunk_start + i < k) {
   row_output[chunk_start + i] = static_cast<IdType>(chunk_start + i);
   output_values[row_idx * top_k_val + chunk_start + i] =
```
Could there be more overflow issues in the current topk.cuh file? For example, in code like `output_values[row_idx * top_k_val + chunk_start + i] =`? We may need to review topk.cuh more thoroughly, or strengthen the test cases to catch such issues.
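For illustration, the wraparound this comment warns about is easy to model in Python: the `size_t` cast in the fix corresponds to doing the multiply in 64-bit instead of 32-bit arithmetic.

```python
U32_MASK = 0xFFFFFFFF

def offset_u32(row_idx, top_k_val, off):
    # Emulates computing the flat offset in 32-bit integer math: wraps.
    return (row_idx * top_k_val + off) & U32_MASK

def offset_sizet(row_idx, top_k_val, off):
    # Emulates the widened (size_t) computation: no wrap for these sizes.
    return row_idx * top_k_val + off
```

With `row_idx = 2**20` and `top_k_val = 2**13`, the product is `2**33`; the 32-bit version wraps to 0 while the widened version keeps the correct offset.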
📌 Description
Deterministic Mode for Top-K Kernels
FilteredTopK Kernel
`FilteredTopKKernel` implements deterministic mode as follows:
- `collect_gt_and_nondet_eq_threshold` gathers elements above the threshold together with threshold ties in unspecified order; their final order is determined by the post-sort kernel.
- `collect_det_eq_pivot` writes the selected tie elements into `s_indices` in deterministic thread-strided order.
- `SortTopKByIndexKernel` is applied to produce index-ascending output and make the final ordering deterministic (we use atomicAdd to collect > pivot at stage 1).
- `StableSortTopKByValueKernel` is applied afterward to produce value-descending output.
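The staging above can be modeled in a few lines of pure Python. This is a sketch of the idea under the assumption of an already-known pivot, not the CUDA implementation:

```python
def filtered_topk_deterministic_ref(values, k, pivot):
    """Reference model: > pivot elements, then == pivot ties, index-ascending."""
    gt = [i for i, v in enumerate(values) if v > pivot]          # stage 1
    need = k - len(gt)
    eq = [i for i, v in enumerate(values) if v == pivot][:need]  # deterministic ties
    idx = sorted(gt + eq)  # SortTopKByIndexKernel analogue: index-ascending output
    return idx, [values[i] for i in idx]
```

For example, with `values = [5, 1, 3, 3, 3, 7]`, `k = 4`, and pivot 3, only two of the three tied 3s are needed, and the lower-index pair (positions 2 and 3) is selected.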
RadixTopK Kernel

- Stage 1 produces `ordered_pivot`, which Stage 2 uses to determine whether an element is >= `ordered_pivot`.
- Stage 1 also produces `cta_local_eq_count` and `cta_local_gt_count`, which Stage 2 uses to determine how many elements the current CTA may emit and where each emitted element should be placed.
- Deterministic collection (`RadixCollectIndicesDeterministic`): after the pivot is known, each CTA is assigned a fixed output range; all > pivot elements are written first, followed by the required == pivot elements in a deterministic order.
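The fixed-output-range idea can be sketched as an exclusive prefix sum over per-CTA emit counts (illustrative Python, not the kernel code): once each CTA's count is known, its output slice is fully determined, independent of atomic ordering between CTAs.

```python
def cta_output_ranges(gt_counts, eq_needed_counts):
    """Assign each CTA a fixed [start, end) output range from its emit counts."""
    ranges, start = [], 0
    for gt, eq in zip(gt_counts, eq_needed_counts):
        ranges.append((start, start + gt + eq))  # non-overlapping by construction
        start += gt + eq
    return ranges
```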
Order definition: among equal values, the element with the lower index wins.
Benchmarks
machine: NVIDIA A100-PCIE-40GB
command: (fp32/fp16/bf16)
raw results:
output.txt
Summary
NOTE: FlashInfer deterministic underperforms PyTorch mainly on short-sequence workloads. Importantly, this is not unique to the deterministic path: FlashInfer non-deterministic top-k is also slower than PyTorch in the same short-sequence regime. This suggests the gap is primarily a short-sequence top-k issue rather than a deterministic-specific regression. Optimizing short-sequence top-k, for both non-deterministic and deterministic modes, is better treated as future work.
🔍 Related Issues
close: #2584
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Benchmarks
Tests
Bug Fixes