
feat: implement deterministic topk #2661

Merged
aleozlx merged 2 commits into flashinfer-ai:main from jiangyinzuo:feat/deterministic-topk
Apr 1, 2026

Conversation

@jiangyinzuo
Contributor

@jiangyinzuo jiangyinzuo commented Mar 1, 2026

📌 Description

Part of the FilteredTopK implementation refers to or is adapted from @Linda-Stadter's work in #2759

Deterministic Mode for Top-K Kernels

FilteredTopK Kernel

FilteredTopKKernel implements deterministic mode as follows:

  1. Build a coarse histogram.
  • Build a coarse histogram over the top 8 bits to locate the coarse threshold bin that contains the k-th largest element.
  • As in non-deterministic mode, elements with bin > threshold_bin are appended to s_indices via atomicAdd (see collect_gt_and_nondet_eq_threshold); their final order is fixed by the post-sort kernel.
  2. Refine with 8-bit radix passes.
  • Run multiple 8-bit refine passes to find the exact pivot.
  • Deterministic == pivot selection is performed by collect_det_eq_pivot, which writes the selected tie elements into s_indices in a deterministic thread-strided order.
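The coarse-histogram-plus-refine search can be sketched on the host in Python. This is an illustrative model of the algorithm, not the kernel code; float_to_ordered_key and find_pivot are hypothetical helper names:

```python
import struct

def float_to_ordered_key(x: float) -> int:
    """Map a float32 bit pattern to an unsigned 32-bit key whose integer
    order matches the float order (the usual radix-select monotone mapping)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Negative floats: flip all bits; non-negative floats: flip only the sign bit.
    return (bits ^ 0xFFFFFFFF) if (bits & 0x80000000) else (bits | 0x80000000)

def find_pivot(values, k):
    """Locate the ordered key of the k-th largest element with four 8-bit
    radix passes: a coarse histogram on the top 8 bits, then refine passes
    that narrow the candidate key prefix 8 bits at a time."""
    keys = [float_to_ordered_key(v) for v in values]
    prefix, prefix_bits, remaining = 0, 0, k
    for shift in (24, 16, 8, 0):
        hist = [0] * 256
        for key in keys:
            # Only histogram keys that still match the prefix found so far.
            if prefix_bits == 0 or (key >> (shift + 8)) == prefix:
                hist[(key >> shift) & 0xFF] += 1
        for bin_id in range(255, -1, -1):  # scan bins from high to low
            if hist[bin_id] >= remaining:  # threshold bin found for this pass
                prefix = (prefix << 8) | bin_id
                prefix_bits += 8
                break
            remaining -= hist[bin_id]
    return prefix  # full 32-bit ordered key of the k-th largest element
```

Elements whose key exceeds the final pivot key are strictly greater than the pivot; elements equal to it are the ties that the deterministic path must collect in a fixed order.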

Thread-strided order means that each thread scans positions starting at its own thread ID and advancing by BLOCK_THREADS. For example, if BLOCK_THREADS = 4, the logical scan order is:

  • thread 0: 0, 4, 8, ...
  • thread 1: 1, 5, 9, ...
  • thread 2: 2, 6, 10, ...
  • thread 3: 3, 7, 11, ...

If the == pivot positions are:

  • thread 0: 0, 8
  • thread 1: 5
  • thread 2: none
  • thread 3: 3, 7

then the deterministic collection order is: [0, 8, 5, 3, 7].
That is, we order elements first by thread ID, and then by each thread's strided traversal order.
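This ordering rule can be modeled on the host in Python. It is an illustrative sketch, not the kernel code; deterministic_collect is a hypothetical name:

```python
def deterministic_collect(eq_pivot_positions, block_threads):
    """Collect == pivot positions in deterministic thread-strided order:
    first by thread ID, then by each thread's strided traversal order."""
    order = []
    for tid in range(block_threads):
        # Thread `tid` scans positions tid, tid + block_threads, ...,
        # so its own hits come out in ascending position order.
        order.extend(p for p in sorted(eq_pivot_positions)
                     if p % block_threads == tid)
    return order
```

The result depends only on the set of tie positions and BLOCK_THREADS, never on atomic timing, which is what makes the collection repeatable.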

  3. Post-sort kernels.
  • After FilteredTopKKernel finishes, SortTopKByIndexKernel is applied to produce index-ascending output and make the final ordering deterministic (stage 1 collects > pivot elements via atomicAdd, so their relative order is not yet fixed).
  • If the Python API is called with sorted=True, StableSortTopKByValueKernel is applied afterward to produce value-descending output.
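The two post-sort stages can be modeled in Python. This is an illustrative sketch, not the kernels themselves; post_sort is a hypothetical name:

```python
def post_sort(indices, values, sorted_output):
    """Model the post-sort stage: first canonicalize to index-ascending
    order, then optionally apply a stable value-descending sort so that
    equal values keep index-ascending order."""
    pairs = sorted(zip(indices, values))            # models SortTopKByIndexKernel
    if sorted_output:                               # models StableSortTopKByValueKernel
        pairs = sorted(pairs, key=lambda p: -p[1])  # Python's sort is stable
    idx, vals = zip(*pairs)
    return list(idx), list(vals)
```

Because the index sort runs first, the subsequent stable value sort breaks ties by ascending index regardless of the nondeterministic order produced by the atomicAdd collection.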

RadixTopK Kernel

  1. RadixSelectFindPivot
  • Finds ordered_pivot, which stage 2 uses to determine whether an element is >= ordered_pivot.
  • Computes cta_local_eq_count and cta_local_gt_count, which stage 2 uses to determine how many elements the current CTA may emit and where each emitted element should be placed.
  2. collect_indices (RadixCollectIndicesDeterministic)
  • After the pivot is known, assigns each CTA a fixed output range, then writes all > pivot elements followed by the required == pivot elements in a deterministic order.

Order definition:

  • Emit > pivot elements first, then == pivot elements.
  • For each category, earlier CTAs write to earlier output positions.
  • Within each CTA, emit elements in thread-strided order.
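The placement rule can be modeled in Python. This is an illustrative host-side sketch; radix_collect_deterministic, per_cta_gt, and per_cta_eq are hypothetical names, with each inner list assumed to already be in that CTA's thread-strided order:

```python
def radix_collect_deterministic(per_cta_gt, per_cta_eq, k):
    """Model the deterministic output placement: > pivot elements first
    (earlier CTAs write earlier positions), then == pivot elements in CTA
    order until exactly k elements have been emitted."""
    out = []
    for cta_elems in per_cta_gt:   # category 1: strictly greater than pivot
        out.extend(cta_elems)
    for cta_elems in per_cta_eq:   # category 2: ties, taken only as needed
        for elem in cta_elems:
            if len(out) == k:
                return out
            out.append(elem)
    return out
```

Since each CTA's output range is fixed before writing begins, no CTA's result depends on how fast the others run.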

Benchmarks

machine: NVIDIA A100-PCIE-40GB

command: (fp32/fp16/bf16)

python -u benchmarks/bench_topk.py \
  --op all \
  --dtype fp32 \
  --deterministic \
  --compare-torch-deterministic \
  --input-pattern random

raw results:

output.txt
Summary

dtype | geomean det slowdown vs non-det | geomean speedup vs torch.det
fp32  | 1.0992x                         | 1.7660x
fp16  | 1.0777x                         | 1.3381x
bf16  | 1.0745x                         | 1.3055x
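The summary columns are geometric means over per-case time ratios. For reference, a geomean can be computed as follows (an illustrative snippet, not the benchmark script's code):

```python
import math

def geomean(ratios):
    """Geometric mean of per-case time ratios: exp(mean(log(r)))."""
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))
```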

NOTE: FlashInfer deterministic underperforms PyTorch mainly on short-sequence workloads. Importantly, this is not unique to the deterministic path: FlashInfer non-deterministic top-k is also slower than PyTorch in the same short-sequence regime. This suggests the gap is primarily a short-sequence top-k issue rather than a deterministic-specific regression. Optimizing short-sequence top-k, for both non-deterministic and deterministic modes, is better treated as future work.

🔍 Related Issues

close: #2584

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).
unittest I ran:
test_topk.py
test_sampling.py
test_logits_processor.py

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Deterministic mode for top‑k and fused transforms (stable, repeatable tie ordering) with API flag to enable deterministic outputs and stable sorting behavior.
  • Benchmarks

    • Expanded benchmarking to compare deterministic vs nondeterministic runs, pre-generated input patterns, DSA workload cases, and richer CLI output.
  • Tests

    • Large suite of determinism and correctness tests (ties, multi‑CTA, streams, sorted behavior, cache transitions).
  • Bug Fixes

    • Improved runtime-error labeling and benchmark cache handling.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds an opt-in deterministic mode across the top-k stack: Python APIs, FFI bindings, C++ dispatch, and CUDA kernels; implements deterministic multi-CTA collection and stable tie‑breaking, updates benchmarks/CLI for deterministic comparisons and DSA workloads, and adds deterministic-focused tests and helpers.

Changes

Cohort / File(s) Summary
Benchmarks & CLI
benchmarks/bench_topk.py
Refactor benchmark flow to pre-generate scores, add deterministic benchmarking infrastructure (deterministic vs nondeterministic timings, torch-deterministic comparison), DSA workload generation, and new CLI flags/options.
Python API surface
flashinfer/topk.py
Add deterministic: bool to top_k, top_k_page_table_transform, top_k_ragged_transform; forward sorted_output/deterministic into CUDA bindings; adjust kernel selection and stable-sort fallback behavior.
FFI binding layer
csrc/flashinfer_topk_binding.cu
Extend exported bindings radix_topk, radix_topk_page_table_transform, radix_topk_ragged_transform to accept new sorted_output/deterministic boolean parameters.
C++ dispatcher & glue
csrc/topk.cu
Thread new sorted_output and deterministic flags into TopKDispatch/fused dispatch calls and propagate to kernel launch paths.
CUDA kernels & headers
include/flashinfer/topk.cuh
Major deterministic additions: ordered SMEM sizing helper, deterministic multi‑CTA scratch/barrier primitives, deterministic collection and pivot eq-count tracking, deterministic-aware FilteredTopK and stable post-sort transforms, and updated dispatch/heuristics with deterministic guards.
Tests
tests/utils/test_topk.py
Add deterministic repeatability/tie/stability tests, cached radix row-states buffer inspection/eviction helpers, parameterize tests over deterministic mode, and expand transform/regression coverage to validate deterministic behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API
    participant Bind as FFI Binding
    participant Dispatch as C++ Dispatcher
    participant Kernel as CUDA Kernel
    Py->>Bind: call top_k(..., deterministic=True)
    Bind->>Dispatch: radix_topk(..., sorted_output=..., deterministic=...)
    Dispatch->>Kernel: launch deterministic-aware kernel (det scratch, DETERMINISTIC)
    Kernel->>Kernel: deterministic collect / stable tie-breaking / optional stable sort
    Kernel-->>Dispatch: return indices & values
    Dispatch-->>Bind: propagate results
    Bind-->>Py: deliver deterministic outputs

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • jiahanc
  • kahyunnam
  • cyx-6
  • yzh119
  • IwakuraRein
  • jimmyzho
  • bkryu
  • nvmbreughe

Poem

🐰 I hopped through kernels, threads aligned,
I made each tie yield the smaller mind.
Runs now repeat, indices hold true,
Determinism carrot—crisp and new.
🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.10%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Linked Issues check ✅ Passed: The PR implements all key acceptance criteria from issue #2584: opt-in deterministic mode, bitwise reproducibility via thread-strided collection and CTA-assigned ranges, well-defined tie-breaking by thread-strided order, and comprehensive benchmarking showing a ~1.08-1.10x slowdown and a 1.43-2.06x speedup vs PyTorch.
  • Out of Scope Changes check ✅ Passed: All changes are scoped to the deterministic top-k implementation: kernel modifications (radix and filtered paths), Python API extensions, benchmark infrastructure updates, and comprehensive test additions. No unrelated changes detected.
  • Title check ✅ Passed: The title 'feat: implement deterministic topk' directly and clearly summarizes the primary change of the pull request: adding deterministic mode support for top-k kernels.
  • Description check ✅ Passed: The PR description is comprehensive and well-structured, following the repository's template with complete Description, Related Issues, and Pull Request Checklist sections.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant feature by enabling deterministic behavior for all top-k related operations within FlashInfer. It provides users with fine-grained control over determinism, which is crucial for reproducibility in machine learning and scientific computing. The changes span Python APIs, CUDA kernels, and benchmarking tools, ensuring robust implementation and verification of this new capability.

Highlights

  • Deterministic Top-K Operations: Implemented deterministic behavior for all top-k operations (top_k, top_k_page_table_transform, top_k_ragged_transform) to ensure repeatable results, especially when ties are present in input scores.
  • Flexible Deterministic Mode Control: Introduced a DeterministicMode enum (NON_DETERMINISTIC, REPRODUCIBLE) and corresponding parameters (deterministic, deterministic_mode) in Python APIs, allowing users to explicitly control the level of determinism. The deterministic=True flag is maintained for backward compatibility, mapping to REPRODUCIBLE mode.
  • Enhanced Benchmarking and Testing: Updated the benchmarking script (bench_topk.py) with new arguments for deterministic mode and various input patterns (random, tie_heavy, pivot_tie). Comprehensive unit tests (test_topk.py) were added to verify the repeatability and correctness of deterministic top-k operations across different scenarios and algorithms.
  • CUDA Kernel Modifications: Modified underlying CUDA kernels (topk.cu, topk.cuh) to incorporate deterministic logic, including changes to radix selection, index collection (with new RadixBlockExclusivePrefix and RadixCollectIndicesReproducible functions), and filtered top-k sorting (FilteredTopKBitonicSortIndices). Heuristics for algorithm selection (ShouldUseFilteredTopKDeterministicAware) were also updated to consider deterministic requirements.


Changelog
  • benchmarks/bench_topk.py
    • Added contextlib.contextmanager for temporary PyTorch deterministic algorithm mode.
    • Implemented run_torch_topk to wrap torch.topk with optional deterministic mode.
    • Introduced generate_scores to create benchmark input with various tie patterns.
    • Extended bench_top_k, bench_page_table_transform, and bench_ragged_transform functions with input_pattern, deterministic, and deterministic_mode parameters.
    • Added command-line arguments --deterministic, --deterministic-mode, --compare-torch-deterministic, and --input-pattern to main function.
    • Updated benchmark output formatting to reflect new deterministic options.
  • csrc/flashinfer_topk_binding.cu
    • Modified C++ function signatures for radix_topk, radix_topk_page_table_transform, and radix_topk_ragged_transform to include a deterministic_mode parameter.
  • csrc/topk.cu
    • Added ParseDeterministicMode helper to convert integer mode to sampling::DeterministicMode enum.
    • Passed deterministic_mode to sampling::TopKDispatch functions.
    • Updated RadixSelectFromSharedMemory to optionally track equal counts for deterministic tie-breaking.
    • Implemented RadixBlockExclusivePrefix and RadixCollectIndicesReproducible for deterministic index collection.
    • Introduced RadixCollectIndicesDispatch to select between deterministic and non-deterministic collection paths.
    • Modified kernel launch logic to use the DETERMINISTIC template parameter.
  • flashinfer/__init__.py
    • Exported DeterministicMode enum from flashinfer.topk.
  • flashinfer/topk.py
    • Defined DeterministicMode as an IntEnum for NON_DETERMINISTIC and REPRODUCIBLE.
    • Added _DETERMINISTIC_MODE_ALIASES for string-based mode selection.
    • Implemented _resolve_deterministic_mode to parse and validate deterministic mode parameters.
    • Added deterministic and deterministic_mode parameters to top_k, top_k_page_table_transform, and top_k_ragged_transform Python APIs.
    • Modified top_k to use stable=True for torch.sort when sorted=True and in reproducible mode.
    • Passed the resolved deterministic mode to the underlying C++ kernel calls.
  • include/flashinfer/topk.cuh
    • Defined DeterministicMode enum and IsDeterministicMode helper.
    • Added GetReproducibleTargetCTAsPerGroup and MaybeBoostReproducibleCTAsPerGroup for dynamic CTA adjustment in reproducible mode.
    • Updated RadixSelectFromSharedMemory to optionally track eq_count.
    • Introduced RadixBlockExclusivePrefix for block-level exclusive prefix sums.
    • Implemented RadixCollectIndicesReproducible for deterministic index collection with tie-breaking.
    • Added RadixCollectIndicesDispatch to conditionally use deterministic collection.
    • Modified RadixTopKKernel_Unified to accept a DETERMINISTIC template parameter.
    • Added FilteredTopKBitonicSortIndices for sorting indices in deterministic filtered top-k.
    • Updated FilteredTopKUnifiedKernel to accept DETERMINISTIC template parameter and use bitonic sort.
    • Introduced SelectFilteredTopKBlockThreads for dynamic block size selection in filtered top-k.
    • Added LaunchFilteredTopKUnified to centralize filtered top-k kernel launches.
    • Modified ShouldUseFilteredTopKDeterministicAware to include deterministic mode heuristics for algorithm selection.
    • Updated TopKPageTableTransformDispatch, TopKRaggedTransformDispatch, and TopKDispatch to pass deterministic_mode and use the new deterministic-aware heuristics.
  • tests/utils/test_topk.py
    • Imported DeterministicMode enum.
    • Added test_top_k_deterministic_mode_bool_compatibility to verify backward compatibility.
    • Included test_top_k_reproducible_mode_repeatability and test_top_k_reproducible_mode_repeatability_multi_cta for repeatable results.
    • Added test_top_k_invalid_deterministic_mode to check error handling for invalid modes.
    • Implemented test_top_k_deterministic_bitwise_repeatability for strict bitwise repeatability.
    • Added repeatability tests for top_k_page_table_transform and top_k_ragged_transform in both deterministic=True and DeterministicMode.REPRODUCIBLE modes.
Activity
  • The pull request is currently a Work In Progress (WIP).
  • Pre-commit checks have been completed.
  • Tests are not yet marked as added/updated or passing, indicating ongoing development and verification.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant feature: deterministic top-k selection. The changes are extensive, adding new execution paths to both the radix and filtered top-k algorithms to ensure reproducible results, which is particularly important for handling ties. The implementation includes backward compatibility for existing APIs by adding new optional parameters. The benchmarks and tests have been updated comprehensively to cover the new deterministic modes. The overall implementation is well-designed and robust. I have one suggestion to improve code clarity and remove a minor redundancy in the CUDA kernel.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 51-56: The benchmark currently enters the
torch_deterministic_algorithms context inside run_torch_topk on every iteration,
adding overhead; instead enable deterministic mode once before the timing loop
and restore the prior state afterwards, removing the per-iteration context from
run_torch_topk (and the analogous per-iteration context in the other benchmark
at lines 136-145); specifically, call the global deterministic enable API once
(save the previous value), run the repeated torch.topk calls normally inside
run_torch_topk, then restore the saved deterministic setting after the loop so
the timing measures only torch.topk cost.

In `@tests/utils/test_topk.py`:
- Around line 1492-1514: The BF16 reproducibility test
(test_top_k_reproducible_mode_repeatability_multi_cta) runs unconditionally but
must be skipped on GPUs with compute capability < SM80; add a guard at the start
of the test that calls flashinfer.utils.get_compute_capability() (or the project
helper like flashinfer.utils.is_sm90a_supported/is_sm80_supported) and use
pytest.skip(...) when the capability is below 80 to avoid running BF16 on
unsupported hardware; also add an import for pytest if it's not present.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f521fe1 and a595ced.

📒 Files selected for processing (7)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/__init__.py
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo jiangyinzuo marked this pull request as draft March 1, 2026 09:31
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch 12 times, most recently from 4358ff1 to 9e88bc8 Compare March 5, 2026 16:48
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch from 9e88bc8 to 42a86f9 Compare March 8, 2026 12:00
@jiangyinzuo jiangyinzuo marked this pull request as ready for review March 8, 2026 12:01
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch from 42a86f9 to 7679c40 Compare March 8, 2026 12:02
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (3)
tests/utils/test_topk.py (1)

1487-1499: ⚠️ Potential issue | 🟡 Minor

Rename sorted; Ruff still flags this helper.

Line 1488 shadows the Python builtin, so this helper keeps tripping A002. sorted_output avoids the lint with no behavior change.

🧹 Minimal rename
 def _assert_top_k_matches_torch(
-    logits: torch.Tensor, k: int, *, deterministic: bool = False, sorted: bool = True
+    logits: torch.Tensor,
+    k: int,
+    *,
+    deterministic: bool = False,
+    sorted_output: bool = True,
 ):
     """Assert FlashInfer top_k matches torch.topk for exact-order cases."""
     values, indices = flashinfer.top_k(
-        logits, k, deterministic=deterministic, sorted=sorted
+        logits, k, deterministic=deterministic, sorted=sorted_output
     )
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=sorted)
+    ref_values, ref_indices = torch.topk(
+        logits, k, dim=-1, sorted=sorted_output
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1487 - 1499, Rename the parameter
named sorted in the helper function _assert_top_k_matches_torch to avoid
shadowing the built-in; change the parameter name to sorted_output and update
all uses inside the function (the flashinfer.top_k call and torch.topk call) to
pass sorted=sorted_output (and any internal references if present), leaving the
behavior and variable names values, indices, ref_values, ref_indices unchanged.
include/flashinfer/topk.cuh (2)

232-240: ⚠️ Potential issue | 🔴 Critical

Synchronize the CTA before publishing the radix-group arrival.

AdvanceRadixGroupBarrier() still lets Line 235 advance arrival_counter before the rest of the block is forced to finish its preceding histogram/output writes. The current callers at Line 468, Line 648, and Line 851 hit it immediately after per-thread atomics/stores, so another CTA can observe partially updated state and break correctness/determinism again.

🔧 Minimal fix
 __device__ __forceinline__ void AdvanceRadixGroupBarrier(RadixRowState* state, int& barrier_phase,
                                                          uint32_t ctas_per_group, uint32_t tx) {
+  __syncthreads();
   if (tx == 0) {
     red_release(&state->arrival_counter, 1);
   }
   int target = (barrier_phase + 1) * ctas_per_group;
   wait_ge(&state->arrival_counter, target, tx);

Expected result: either the helper owns the CTA sync, or every releasing call site shows an immediate __syncthreads() before it.

#!/bin/bash
set -euo pipefail
sed -n '232,240p' include/flashinfer/topk.cuh
sed -n '452,470p' include/flashinfer/topk.cuh
sed -n '635,650p' include/flashinfer/topk.cuh
sed -n '835,855p' include/flashinfer/topk.cuh
sed -n '1256,1263p' include/flashinfer/topk.cuh
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 232 - 240, AdvanceRadixGroupBarrier
currently releases the radix-group arrival (red_release(&state->arrival_counter,
1)) before the CTA is synchronized, allowing other CTAs to observe partially
written per-thread state; fix it by owning the CTA sync inside
AdvanceRadixGroupBarrier: add a __syncthreads() immediately before the tx==0
release path so the block finishes all histogram/output stores before calling
red_release, leaving the existing wait_ge(&state->arrival_counter, target, tx),
barrier_phase++, and trailing __syncthreads() intact.

3241-3258: ⚠️ Potential issue | 🟠 Major

Canonicalize radix ties before the stable value sort.

Line 3246 index-sorts only the filtered deterministic path. When Line 3251 routes deterministic work through radix, StableSortTopKByValue() on Line 3256 preserves the deterministic collection order from RadixCollectIndicesDeterministic, so sorted=True, deterministic=True still returns a different tie order depending on which algorithm was selected.

🔧 Suggested fix
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic && sorted_output) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
   if (sorted_output) {
     FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
         output_indices, output_values, num_rows, top_k_val, max_len, stream)));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3241 - 3258, The deterministic
canonicalization (index-sort via LaunchSortTopKByIndex) is only applied in the
filtered path; ensure radix-based deterministic results are canonicalized the
same way before the stable value sort. After calling RadixTopKMultiCTA in the
else branch, if deterministic is true call
LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the same
arguments used in the filtered branch (output_indices, output_values, nullptr,
0, nullptr, num_rows, top_k_val, max_len, stream) so that StableSortTopKByValue
sees a canonical tie order regardless of which algorithm ran; keep the existing
filtered-path LaunchSortTopKByIndex and the final StableSortTopKByValue intact.
🧹 Nitpick comments (2)
benchmarks/bench_topk.py (1)

209-223: Consider using -float('inf') consistently for neg_inf fallback.

For fp16/bf16, using torch.finfo(dtype).min instead of -inf means values at the minimum representable float could still be selected over the masked positions. If the intent is to fully exclude masked positions from top-k selection, -inf (which is representable in fp16/bf16) would be more robust.

🔧 Suggested fix
-        neg_inf = -torch.inf if dtype == torch.float32 else torch.finfo(dtype).min
+        neg_inf = float('-inf')
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 209 - 223, The masking uses
torch.finfo(dtype).min for neg_inf when dtype is fp16/bf16 which can still be
chosen; change the neg_inf computation in the causal_chunk block (where
start_pos, lengths, q_len, dtype are used) to use a true negative infinity
constant (e.g. -float('inf')) for the masked_fill value so masked positions are
fully excluded when you call scores = scores.masked_fill(invalid, neg_inf).
flashinfer/topk.py (1)

176-182: Docstring could clarify tie-breaking strategy for deterministic mode.

The PR objectives and issue #2584 mention that deterministic mode uses "lower element index wins" for tie-breaking. Consider adding this detail to the docstring so users understand the expected behavior when values are equal.

📝 Suggested docstring enhancement
     deterministic : bool, optional
         If True, uses deterministic mode.
         Default is False (non-deterministic, which is faster).
 
         Deterministic mode guarantees repeatable FlashInfer output ordering for
-        the selected top-k set on a fixed input and system.
+        the selected top-k set on a fixed input and system. When values are equal,
+        elements with lower indices are selected first (stable tie-breaking).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 176 - 182, Update the docstring for the
top_k function to explicitly state the tie-breaking rule used when
deterministic=True: when values are equal the element with the lower index is
chosen ("lower element index wins"). Mention this behavior near the
deterministic parameter description in the top_k docstring so callers know how
ties are resolved and that ordering may differ from non-deterministic behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 213-218: The cached row-state workspace currently doesn't include
space for the deterministic scratch tail used by
MaybeGetRadixDeterministicCollectScratchBuffer, so when deterministic &&
!single_cta the pointer (row_states_buffer + num_groups) can walk past the
allocated buffer; update all allocation sites that create
radix_topk_row_states_* (in flashinfer/topk.py and any C++/CUDA allocs) to
reserve room for both RadixRowState[num_groups] and
RadixDeterministicCollectScratch[num_groups] (i.e. allocate num_groups of
RadixRowState plus num_groups of RadixDeterministicCollectScratch, or
equivalently adjust byte-size to
num_groups*(sizeof(RadixRowState)+sizeof(RadixDeterministicCollectScratch))),
and ensure any cached size calculations and related comments reflect this change
so deterministic multi-CTA no longer overruns the buffer.

In `@tests/utils/test_topk.py`:
- Around line 1896-1937: The tests
test_top_k_deterministic_sorted_large_k_matches_torch_by_algo and
test_top_k_deterministic_trivial_k_equals_length_by_algo currently parametrize
over "filtered" but use k values (4096 and vocab_size) larger than
FILTERED_TOPK_MAX_K (defined in include/flashinfer/topk.cuh as 2048), so they
never exercise FilteredTopK; update the parametrization to only use ["auto",
"multi_cta"] for these two tests, or alternatively add a separate test case that
explicitly uses set_topk_algo("filtered") with k <= FILTERED_TOPK_MAX_K (e.g.,
k=2048) to validate the filtered path.

---

Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 232-240: AdvanceRadixGroupBarrier currently releases the
radix-group arrival (red_release(&state->arrival_counter, 1)) before the CTA is
synchronized, allowing other CTAs to observe partially written per-thread state;
fix it by owning the CTA sync inside AdvanceRadixGroupBarrier: add a
__syncthreads() immediately before the tx==0 release path so the block finishes
all histogram/output stores before calling red_release, leaving the existing
wait_ge(&state->arrival_counter, target, tx), barrier_phase++, and trailing
__syncthreads() intact.
- Around line 3241-3258: The deterministic canonicalization (index-sort via
LaunchSortTopKByIndex) is only applied in the filtered path; ensure radix-based
deterministic results are canonicalized the same way before the stable value
sort. After calling RadixTopKMultiCTA in the else branch, if deterministic is
true call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the
same arguments used in the filtered branch (output_indices, output_values,
nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so that
StableSortTopKByValue sees a canonical tie order regardless of which algorithm
ran; keep the existing filtered-path LaunchSortTopKByIndex and the final
StableSortTopKByValue intact.
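The AdvanceRadixGroupBarrier ordering issue above has a simple CPU analogy (a hypothetical Python sketch, with `threading.Barrier` standing in for `__syncthreads()` and an `Event` standing in for `red_release`): the cross-group release must happen only after every thread in the block has finished its stores.

```python
import threading

class GroupBarrier:
    """CPU analogy of the fixed barrier: synchronize the whole block first,
    then let thread 0 release the cross-group arrival signal."""

    def __init__(self, nthreads: int):
        self.block_sync = threading.Barrier(nthreads)  # __syncthreads() analogue
        self.released = threading.Event()              # red_release analogue

    def advance(self, tid: int):
        self.block_sync.wait()   # all per-thread stores are done past this point
        if tid == 0:
            self.released.set()  # safe: no partially written state is visible

data = [None, None]
barrier = GroupBarrier(2)

def worker(tid):
    data[tid] = tid * 10     # per-thread store (histogram/output analogue)
    barrier.advance(tid)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
barrier.released.wait(timeout=5)
# Because the release happens after the block barrier, both stores are visible.
assert data == [0, 10]
for t in threads:
    t.join()
```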

In `@tests/utils/test_topk.py`:
- Around line 1487-1499: Rename the parameter named sorted in the helper
function _assert_top_k_matches_torch to avoid shadowing the built-in; change the
parameter name to sorted_output and update all uses inside the function (the
flashinfer.top_k call and torch.topk call) to pass sorted=sorted_output (and any
internal references if present), leaving the behavior and variable names values,
indices, ref_values, ref_indices unchanged.

---

Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Around line 209-223: The masking uses torch.finfo(dtype).min for neg_inf when
dtype is fp16/bf16 which can still be chosen; change the neg_inf computation in
the causal_chunk block (where start_pos, lengths, q_len, dtype are used) to use
a true negative infinity constant (e.g. -float('inf')) for the masked_fill value
so masked positions are fully excluded when you call scores =
scores.masked_fill(invalid, neg_inf).
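The hazard is easy to reproduce with NumPy fp16 as a stand-in for the torch `masked_fill` call (a hypothetical repro, not the benchmark code): `finfo.min` is a representable score, so a masked slot can tie with a real element, while `-inf` never wins.

```python
import numpy as np

# -65504.0 is exactly np.finfo(np.float16).min, so a real score can equal it.
scores = np.array([-65504.0, -1.0, -2.0], dtype=np.float16)
mask = np.array([False, False, True])  # position 2 should be excluded

bad = scores.copy()
bad[mask] = np.finfo(np.float16).min   # masked slot ties with real index 0
good = scores.copy()
good[mask] = -np.inf                   # masked slot can never be selected

top2_bad = np.argsort(bad, kind="stable")[::-1][:2]
top2_good = np.argsort(good, kind="stable")[::-1][:2]
# top2_bad includes the masked index 2; top2_good contains only real elements.
```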

In `@flashinfer/topk.py`:
- Around line 176-182: Update the docstring for the top_k function to explicitly
state the tie-breaking rule used when deterministic=True: when values are equal
the element with the lower index is chosen ("lower element index wins"). Mention
this behavior near the deterministic parameter description in the top_k
docstring so callers know how ties are resolved and that ordering may differ
from non-deterministic behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae026e77-9523-469a-8f61-393293ddf092

📥 Commits

Reviewing files that changed from the base of the PR and between fe9e258 and da7ad5e.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@coderabbitai
Contributor

coderabbitai Bot commented Mar 30, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation across all modified files (include/flashinfer/topk.cuh, csrc/topk.cu, csrc/flashinfer_topk_binding.cu, flashinfer/topk.py, benchmarks/bench_topk.py, and tests/utils/test_topk.py), including the CUB sort integration and overflow fix cherry-picked from @Linda-Stadter's fork.


🧠 Learnings used
Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters.cu:135-141
Timestamp: 2026-03-20T17:55:09.735Z
Learning: In `csrc/fast_topk_clusters.cu`, the `fast_topk_cuda_v4` device function intentionally drops elements that fall in the threshold bin when the shared cache is full (`cached_offset >= num_cached`). This is a deliberate performance trade-off in the "fast" (inexact) TopK variant. The exact/correctness-preserving alternative is `fast_topk_clusters_exact` in `csrc/fast_topk_clusters_exact.cu`. Do not flag this overflow-drop behavior as a bug.

Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters_exact.cu:406-410
Timestamp: 2026-03-20T17:56:48.864Z
Learning: In `csrc/fast_topk_clusters.cu` and `csrc/fast_topk_clusters_exact.cu`, the launcher functions (`launch_fast_topk_clusters` and `launch_fast_topk_clusters_exact`) intentionally and silently fall back to `num_clusters=1, pdl_enabled=false` when an unsupported `num_clusters` value is passed. No warning log is emitted. Do not flag this silent fallback as a bug or suggest adding a warning.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2709
File: include/flashinfer/mamba/seq_chunk_cumsum.cuh:0-0
Timestamp: 2026-03-06T20:52:57.849Z
Learning: In `include/flashinfer/mamba/seq_chunk_cumsum.cuh` and `csrc/seq_chunk_cumsum.cu`, the maintainer explicitly does not want runtime validation of metadata (chunk_indices, chunk_offsets, seq_idx bounds, monotonicity) in the kernel launcher or device code because this is a high-throughput kernel. Do not suggest adding such checks. Debug-mode assertions may be acceptable but should not be pushed.
✅ Actions performed

Full review triggered.

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
include/flashinfer/topk.cuh (1)

3259-3267: ⚠️ Potential issue | 🟠 Major

Radix deterministic path omits index canonicalization; tie order differs from filtered path.

The filtered path calls SortTopKByIndex when deterministic=true (line 3255-3257), but the radix path (line 3260-3262) does not. This means:

  • Filtered + deterministic: ties ordered by ascending index
  • Radix + deterministic: ties ordered by thread-strided emission

When sorted_output=true, StableSortTopKByValue preserves prior ordering for ties, so the final tie-breaking strategy depends on which algorithm was selected—breaking the contract that deterministic mode produces consistent, index-ordered ties.

🔧 Suggested fix
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3259 - 3267, The radix
deterministic path omits the canonical index tie-break used by the filtered
path; after RadixTopKMultiCTA completes and when deterministic==true, call
SortTopKByIndex with the same arguments used by the filtered path (e.g.,
output_indices, output_values, num_rows, top_k_val, max_len, stream) before any
StableSortTopKByValue call so ties are canonicalized by ascending index; ensure
this conditional mirrors the filtered path's deterministic branch around
SortTopKByIndex so both algorithms produce identical tie order.
🧹 Nitpick comments (2)
flashinfer/topk.py (1)

63-73: Ruff flags input shadowing Python builtin.

The static analysis tool flags line 65 for shadowing Python's built-in input. However, this pattern is consistent with the existing codebase conventions for tensor parameter naming in this file. Given the "Chill" review mode and that this is a widespread pattern, this can be addressed in a separate cleanup if desired.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 63 - 73, The parameter name input in the
_fake_radix_topk function shadows Python's builtin and should be renamed to
avoid the Ruff warning; update the function signature of _fake_radix_topk
(registered as "flashinfer::radix_topk") to use a non-builtins name (e.g.,
tensor, inp, or src_tensor) and replace all uses inside the function (input.size
and input.device) accordingly so behavior is unchanged.
include/flashinfer/topk.cuh (1)

3174-3198: Consider documenting heuristic rationale.

The deterministic-mode algorithm selection heuristics (lines 3174-3184) differ significantly from non-deterministic heuristics (lines 3186-3197). Consider adding a brief comment explaining the trade-off (e.g., filtered deterministic overhead vs. radix multi-CTA coordination cost).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3174 - 3198, Add a short
explanatory comment immediately above the block that branches on deterministic
and DType size (referencing variables/conditions: deterministic, sizeof(DType),
max_len, num_rows, and batch_threshold) that explains why deterministic-mode
thresholds differ from non-deterministic ones — e.g., deterministic
implementation favors simpler per-row filtered scans to avoid non-deterministic
cross-CTA radix coordination (hence lower thresholds like 16384 and the special
256 divisor), while non-deterministic heuristics accept radix/multi-CTA
strategies for larger max_len (notice thresholds 16384/32768 and the use of
max_len/4096 or /16384 to compute batch_threshold); keep the comment concise
(1–3 lines) describing the trade-off and pointing to the key constants so future
readers understand the rationale.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3259-3267: The radix deterministic path omits the canonical index
tie-break used by the filtered path; after RadixTopKMultiCTA completes and when
deterministic==true, call SortTopKByIndex with the same arguments used by the
filtered path (e.g., output_indices, output_values, num_rows, top_k_val,
max_len, stream) before any StableSortTopKByValue call so ties are canonicalized
by ascending index; ensure this conditional mirrors the filtered path's
deterministic branch around SortTopKByIndex so both algorithms produce identical
tie order.

---

Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 63-73: The parameter name input in the _fake_radix_topk function
shadows Python's builtin and should be renamed to avoid the Ruff warning; update
the function signature of _fake_radix_topk (registered as
"flashinfer::radix_topk") to use a non-builtins name (e.g., tensor, inp, or
src_tensor) and replace all uses inside the function (input.size and
input.device) accordingly so behavior is unchanged.

In `@include/flashinfer/topk.cuh`:
- Around line 3174-3198: Add a short explanatory comment immediately above the
block that branches on deterministic and DType size (referencing
variables/conditions: deterministic, sizeof(DType), max_len, num_rows, and
batch_threshold) that explains why deterministic-mode thresholds differ from
non-deterministic ones — e.g., deterministic implementation favors simpler
per-row filtered scans to avoid non-deterministic cross-CTA radix coordination
(hence lower thresholds like 16384 and the special 256 divisor), while
non-deterministic heuristics accept radix/multi-CTA strategies for larger
max_len (notice thresholds 16384/32768 and the use of max_len/4096 or /16384 to
compute batch_threshold); keep the comment concise (1–3 lines) describing the
trade-off and pointing to the key constants so future readers understand the
rationale.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4b20947c-0a7d-4e41-ace0-d46414a25216

📥 Commits

Reviewing files that changed from the base of the PR and between 779c24d and e8eaf33.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo
Contributor Author

@yzh119 @Linda-Stadter This PR is ready for review now.

uint32_t cta_local_eq_count = 0;
OrderedType ordered_pivot =
RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DETERMINISTIC, DType>(
input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars,
Contributor

This doesn't contain my overflow fix by casting to size_t. I will create another commit on top of this :)

@Linda-Stadter
Contributor

Can you take a look at this and cherry pick? Linda-Stadter@674161b

@jiangyinzuo
Contributor Author

Can you take a look at this and cherry pick? Linda-Stadter@674161b

Looks like it could be submitted as a standalone bug-fix PR, as it seems orthogonal to the deterministic top-k implementation. I am not sure if the FlashInfer maintainers would accept merging two commits with different objectives into a single PR.

@Linda-Stadter
Contributor

Looks like it could be submitted as a standalone bug-fix PR, as it seems orthogonal to the deterministic top-k implementation. I am not sure if the FlashInfer maintainers would accept merging two commits with different objectives into a single PR.

Yes, I agree, it is an additional bug fix. But due to time constraints and because it is only a small change, I wanted to merge it with this PR. Let me know if you are strictly against it.

@jiangyinzuo
Contributor Author

jiangyinzuo commented Mar 30, 2026

Looks like it could be submitted as a standalone bug-fix PR, as it seems orthogonal to the deterministic top-k implementation. I am not sure if the FlashInfer maintainers would accept merging two commits with different objectives into a single PR.

Yes, I agree, it is an additional bug fix. But due to time constraints and because it is only a small change, I wanted to merge it with this PR. Let me know if you are strictly against it.

I don't mind cherry-picking this small commit; up to @yzh119

also add cub stable radix sort and overflow handling

Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
@Linda-Stadter Linda-Stadter mentioned this pull request Apr 1, 2026
5 tasks
@Linda-Stadter
Contributor

I have now put the overflow fix in a separate PR

Collaborator

@aleozlx aleozlx left a comment


looks good

@aleozlx
Collaborator

aleozlx commented Apr 1, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !486 has been created, and the CI pipeline #47450958 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
@aleozlx
Collaborator

aleozlx commented Apr 1, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #47450958 has been cancelled.

@aleozlx
Collaborator

aleozlx commented Apr 1, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !486 has been updated with latest changes, and the CI pipeline #47452629 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #47452629: canceled

Comment thread tests/utils/test_topk.py
assert torch.equal(out, ref)


def test_top_k_uint32_pointer_overflow():
Contributor Author

Should we add more parameter combinations to this test case using pytest decorators, such as

  • deterministic/non-deterministic
  • plain/ragged/page-table

so that we can ensure the overflow issue is covered across all kinds of modes?
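One way to express that combination coverage (a hypothetical sketch; the `layout` parameter and mode strings are illustrative, not the real test fixtures):

```python
import pytest

# Stacked parametrize decorators produce the cross-product of modes,
# so the overflow repro runs under every combination.
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("layout", ["plain", "ragged", "page_table"])
def test_top_k_uint32_pointer_overflow(deterministic, layout):
    ...  # run the overflow repro under each combination
```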

@@ -1154,7 +1154,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
if (chunk_start + i < k) {
row_output[chunk_start + i] = static_cast<IdType>(chunk_start + i);
output_values[row_idx * top_k_val + chunk_start + i] =
Contributor Author

Could there be more overflow issues elsewhere in the current topk.cuh file? For example, in code like output_values[row_idx * top_k_val + chunk_start + i]? We may need to review topk.cuh more thoroughly, or strengthen the test cases to catch such issues.
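The concern is concrete: with 32-bit index arithmetic, row_idx * top_k_val wraps for large batches. A quick Python check of the modular arithmetic (values chosen purely for illustration):

```python
# Illustrative values: 3M rows with k = 2048 pushes the flat offset past 2**32.
row_idx, top_k_val, chunk_start, i = 3_000_000, 2048, 0, 0

offset_64 = row_idx * top_k_val + chunk_start + i  # size_t-style (64-bit) math
offset_32 = offset_64 % 2**32                      # what uint32 math would produce

assert offset_64 == 6_144_000_000
assert offset_32 == 1_849_032_704  # wrapped: would index the wrong element
```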


Successfully merging this pull request may close these issues.

Feature Request: Deterministic top-k kernels for sparse attention

4 participants