feat: implement configurable tie_break for filtered topk #3095

kahyunnam merged 10 commits into flashinfer-ai:main
📝 Walkthrough

Adds a `TopKTieBreak` enum and threads a tie-break parameter through Python APIs, benchmarks/CLI, FFI bindings, C++ dispatch, and CUDA kernels; selecting a tie-break forces deterministic execution. Tests and benchmarks were extended to exercise SMALL/LARGE tie-break modes and reporting.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI
    participant Benchmarks
    participant PythonAPI
    participant FFI
    participant CppDispatch
    participant CUDA
    rect rgba(200,200,255,0.5)
        User->>CLI: run benchmark with --tie-break
        CLI->>Benchmarks: enable compare_tie_break / force deterministic
    end
    Benchmarks->>PythonAPI: call top_k / transforms (tie_break)
    PythonAPI->>FFI: call radix_topk*(..., deterministic, tie_break)
    FFI->>CppDispatch: invoke radix_topk entry (tie_break)
    CppDispatch->>CUDA: select kernel/collection path (FilteredTopK + tie_break)
    CUDA-->>CppDispatch: return results
    CppDispatch-->>FFI: return outputs
    FFI-->>PythonAPI: return tensors/timings
    PythonAPI-->>Benchmarks: return metrics
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Hi @aleozlx, could you help review? Thank you!
Code Review
This pull request introduces tie-breaking support for Top-K operations in FlashInfer, enabling deterministic selection of smaller or larger indices during value ties. The changes include a new TopKTieBreak enum, a DeterministicContiguousCollect CUDA helper, and updates to the Python API, C++ bindings, and benchmarking tools. Reviewer feedback highlights opportunities to reduce code duplication in the benchmark script and dispatch logic, and suggests restructuring macros for improved readability.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
benchmarks/bench_topk.py (2)
685-699: Minor: "tie" column shows integer code instead of mode name.

`tie_break` is a `TopKTieBreak` (IntEnum). With the current `:>7` format spec, Python formats it via the mixed-in `int.__format__`, so table rows render `0`/`1`/`2`, even though the banner lines (`','.join(map(str, tie_break_modes))`) render `none`, `small`, `large` via the overridden `__str__`. This inconsistency makes it harder to eyeball results in a multi-mode sweep. Consider formatting the column as `str(tie_break)` (or `tie_break.name.lower()`) for readability. Applies to all four print loops (top_k, dsa_topk, page_table, ragged).

♻️ Example tweak

```diff
-    f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} {tie_break:>7} | "
+    f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} {str(tie_break):>7} | "
```

Also applies to: 731-741
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 685 - 699, The table's "tie" column is showing the IntEnum value (0/1/2) because TopKTieBreak (an IntEnum) is formatted with int.__format__; update the row formatting where tie_break is printed (in the top_k, dsa_topk, page_table, and ragged print loops) to render the human-readable name instead — e.g., use str(tie_break) or tie_break.name.lower() when building the row string so the column shows "none"/"small"/"large" instead of integers, keeping the same width/alignment specifier.
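The fix above can be illustrated with a minimal, self-contained sketch. The enum below is a hypothetical stand-in mirroring the names quoted in this review, not the actual flashinfer class:

```python
from enum import IntEnum


class TopKTieBreak(IntEnum):  # hypothetical stand-in for the flashinfer enum
    NONE = 0
    SMALL = 1
    LARGE = 2

    def __str__(self) -> str:
        return self.name.lower()


tb = TopKTieBreak.SMALL
# Applying str() before the width spec guarantees the human-readable name,
# regardless of how the enum's __format__ resolves on a given Python version.
row = f"{str(tb):>7}"
print(repr(row))  # '  small'
```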
518-529: Argparse `choices` and CLI input format are inconsistent and may confuse users.

`TopKTieBreak` explicitly overrides `__str__()` to return the lowercase name ("none"/"small"/"large"), and `__format__()` delegates to it. When argparse displays `choices=list(TopKTieBreak)` in help output, it invokes `str()` on each choice, rendering them as `{none,small,large}`; but `parse_tie_break_mode` only accepts integer strings `0`/`1`/`2` (anything else raises `ArgumentTypeError`). Users following the help text verbatim (`--tie-break-modes none`) will hit an error. Two options:

- Drop `choices=list(TopKTieBreak)` (validation already lives in `parse_tie_break_mode`, whose error message is explicit), or
- Accept both integer and name inputs in `parse_tie_break_mode` so the help text and accepted inputs agree.

♻️ Proposed fix (accept names too, drop redundant `choices`)

```diff
 def parse_tie_break_mode(mode_str: str) -> TopKTieBreak:
     """Parse CLI tie-break mode into TopKTieBreak enum."""
     try:
-        return TopKTieBreak(int(mode_str))
-    except ValueError as e:
+        return TopKTieBreak(int(mode_str))
+    except ValueError:
+        pass
+    try:
+        return TopKTieBreak[mode_str.upper()]
+    except KeyError as e:
         raise argparse.ArgumentTypeError(
-            "tie_break mode must be one of: 0 (none), 1 (small), 2 (large)"
+            "tie_break mode must be one of: 0/none, 1/small, 2/large"
         ) from e
@@
     parser.add_argument(
         "--tie-break-modes",
         type=parse_tie_break_mode,
         nargs="+",
-        choices=list(TopKTieBreak),
         default=[TopKTieBreak.NONE],
         help=(
             "Tie-break modes to benchmark. "
-            "Use one or more of: 0 (none), 1 (prefer small indices), 2 (prefer large indices). "
+            "Use one or more of: 0/none, 1/small (prefer small indices), 2/large (prefer large indices). "
             "Default: 0"
         ),
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 518 - 529, Help text and accepted inputs disagree: argparse currently lists choices as list(TopKTieBreak) (which stringifies to "none/small/large") but parse_tie_break_mode only accepts "0/1/2". Update parse_tie_break_mode to also accept enum names (case-insensitive, e.g. "none","small","large") in addition to numeric strings and return the corresponding TopKTieBreak, and remove the redundant choices=list(TopKTieBreak) from the parser.add_argument call for "--tie-break-modes" so help and parsing behavior match.
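A runnable sketch of the proposed dual-format parser; the enum and argument name mirror those quoted in the review and are assumptions, not the actual benchmark code:

```python
import argparse
from enum import IntEnum


class TopKTieBreak(IntEnum):  # hypothetical stand-in for the real enum
    NONE = 0
    SMALL = 1
    LARGE = 2


def parse_tie_break_mode(mode_str: str) -> TopKTieBreak:
    """Accept integer codes ('0'/'1'/'2') or names ('none'/'small'/'large')."""
    try:
        return TopKTieBreak(int(mode_str))
    except ValueError:
        pass
    try:
        return TopKTieBreak[mode_str.upper()]
    except KeyError as e:
        raise argparse.ArgumentTypeError(
            "tie_break mode must be one of: 0/none, 1/small, 2/large"
        ) from e


parser = argparse.ArgumentParser()
parser.add_argument("--tie-break-modes", type=parse_tie_break_mode, nargs="+",
                    default=[TopKTieBreak.NONE])

# Both spellings parse to enum members, so help text and behavior agree.
args = parser.parse_args(["--tie-break-modes", "small", "2"])
print(args.tie_break_modes)
```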
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3333-3342: The code forces any non-None TopKTieBreak onto the
filtered-topk path and returns cudaErrorNotSupported when FILTERED_TOPK_MAX_K is
exceeded or CanImplementFilteredTopK() is false, causing valid existing
top_k(...) calls to fail; update the logic around tie_break,
deterministic_effective, and use_filtered so we do not unconditionally force
filtered mode: (1) if the radix path can support the requested tie_break, allow
the radix path by not setting use_filtered=true; (2) if radix cannot support
tie_break and filtered topk is unavailable or k>FILTERED_TOPK_MAX_K, return a
clear API error (not a silent cudaErrorNotSupported) indicating tie_break is
unsupported for this shape/GPU and suggesting removing tie_break or using a
supported device/shape; and (3) centralize this check where use_filtered is
computed (the block that calls ShouldUseFilteredTopK<DType>, references
TopKTieBreak, CanImplementFilteredTopK(), FILTERED_TOPK_MAX_K, and
deterministic_effective) so the decision honors radix capability and emits a
descriptive error message instead of breaking existing callers.
In `@tests/utils/test_topk.py`:
- Around line 106-114: The direct-run "__main__" smoke invocation calls
test_top_k with an outdated signature and raises TypeError because test_top_k
now requires the tie_break parameter; update the "__main__" block so it calls
test_top_k with a valid tie_break (for example iterate over
flashinfer.TopKTieBreak options or pass a default like
flashinfer.TopKTieBreak.NONE) and mirror this change for the other direct-run
calls that invoke test_top_k (the same pattern at the other two locations),
ensuring the call sites match the new test_top_k(batch_size, vocab_size, k,
dtype, tie_break) signature.
---
Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Around line 685-699: The table's "tie" column is showing the IntEnum value
(0/1/2) because TopKTieBreak (an IntEnum) is formatted with int.__format__;
update the row formatting where tie_break is printed (in the top_k, dsa_topk,
page_table, and ragged print loops) to render the human-readable name instead —
e.g., use str(tie_break) or tie_break.name.lower() when building the row string
so the column shows "none"/"small"/"large" instead of integers, keeping the same
width/alignment specifier.
- Around line 518-529: Help text and accepted inputs disagree: argparse
currently lists choices as list(TopKTieBreak) (which stringifies to
"none/small/large") but parse_tie_break_mode only accepts "0/1/2". Update
parse_tie_break_mode to also accept enum names (case-insensitive, e.g.
"none","small","large") in addition to numeric strings and return the
corresponding TopKTieBreak, and remove the redundant choices=list(TopKTieBreak)
from the parser.add_argument call for "--tie-break-modes" so help and parsing
behavior match.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6af13619-8a53-408e-bda4-818cceb77d46
📒 Files selected for processing (7)
- benchmarks/bench_topk.py
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
- flashinfer/__init__.py
- flashinfer/topk.py
- include/flashinfer/topk.cuh
- tests/utils/test_topk.py
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)
1960-1961: Use the `TopKTieBreak` enum instead of raw integers for consistency.

The existing parametrized tests (Lines 106-113, 152-159, 197-204, 229-236) all pass `flashinfer.TopKTieBreak.{NONE,SMALL,LARGE}`. The two new tests (here and at Lines 2022-2033) pass bare integers `1` and `2`, which is less readable and inconsistent with the style used throughout the rest of the file.

♻️ Proposed change

```diff
-    values_small, indices_small = flashinfer.top_k(logits, k, tie_break=1)
-    values_large, indices_large = flashinfer.top_k(logits, k, tie_break=2)
+    values_small, indices_small = flashinfer.top_k(
+        logits, k, tie_break=flashinfer.TopKTieBreak.SMALL
+    )
+    values_large, indices_large = flashinfer.top_k(
+        logits, k, tie_break=flashinfer.TopKTieBreak.LARGE
+    )
```

Apply the equivalent change at Lines 2022-2033.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1960 - 1961, Replace raw integer tie_break arguments with the TopKTieBreak enum for consistency: change the calls to flashinfer.top_k(logits, k, tie_break=1) and tie_break=2 to use flashinfer.TopKTieBreak.SMALL and flashinfer.TopKTieBreak.LARGE respectively (the variables values_small/indices_small and values_large/indices_large remain the same); apply the same replacement in the other new test mentioned (the block around the second occurrence) so all tests use flashinfer.TopKTieBreak instead of integers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/utils/test_topk.py`:
- Around line 1952-1978: The test currently assumes a stable output ordering
from flashinfer.top_k(logits, k, tie_break=1/2) but ordering is only guaranteed
when sorted=True; either pass sorted=True into both top_k calls or (preferred)
keep top_k as-is and replace the strict torch.equal index checks with the
unordered comparator _assert_unordered_indices_match so the selection is
validated without requiring a specific permutation; keep the existing values
assertions (torch.testing.assert_close on values_small/values_large) and change
the two index assertions to call _assert_unordered_indices_match(indices_small,
expected_small) and _assert_unordered_indices_match(indices_large,
expected_large).
---
Nitpick comments:
In `@tests/utils/test_topk.py`:
- Around line 1960-1961: Replace raw integer tie_break arguments with the
TopKTieBreak enum for consistency: change the calls to flashinfer.top_k(logits,
k, tie_break=1) and tie_break=2 to use flashinfer.TopKTieBreak.SMALL and
flashinfer.TopKTieBreak.LARGE respectively (the variables
values_small/indices_small and values_large/indices_large remain the same);
apply the same replacement in the other new test mentioned (the block around the
second occurrence) so all tests use flashinfer.TopKTieBreak instead of integers.
Changed the title: tie_break for topk → tie_break for filtered topk
@coderabbitai resume

✅ Actions performed: Reviews resumed.
/bot run
@kahyunnam thanks for letting me know! I will rebase. Implementation code looks fine but
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_topk.py (1)
690-721: ⚠️ Potential issue | 🟡 Minor

Check `--tie-break` before the deterministic compare-algorithms rejection.

Line 694 makes `args.deterministic = True`, so the Line 719 `args.tie_break` branch is unreachable for `--compare-algorithms --tie-break`; users get the deterministic error instead of the tie-break-specific error.

Proposed fix

```diff
     # Algorithm comparison mode
     if args.compare_algorithms:
+        if args.tie_break:
+            print("ERROR: --compare-algorithms does not support --tie-break")
+            return
         if args.deterministic:
             print(
                 "ERROR: --compare-algorithms is only meaningful with non-deterministic mode"
             )
             return
-        if args.tie_break:
-            print("ERROR: --compare-algorithms does not support --tie-break")
-            return
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 690 - 721, The tie-break auto-enable logic (the if args.tie_break block that sets args.deterministic = True) must not run before the compare-algorithms validation because it masks the intended tie-break incompatibility error; change the order or the condition so that the compare-algorithms check (the if args.compare_algorithms block that errors when args.deterministic is True and when args.tie_break is True) is evaluated first, or make the tie-break auto-enable skip if args.compare_algorithms is set, ensuring the args.tie_break path triggers its specific error rather than being hidden by the deterministic check.
🧹 Nitpick comments (1)
benchmarks/bench_topk.py (1)
640-647: Consider registering the new tie-break benchmark mode in the unified benchmark framework.

This adds another custom CLI benchmark path in `bench_topk.py`; the benchmark cases would be easier to run consistently if wired through `benchmarks/flashinfer_benchmark.py`.

As per coding guidelines, `benchmarks/**/*.py`: Use the unified benchmarking framework in `benchmarks/flashinfer_benchmark.py` for kernel benchmarking with CUPTI timing support.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 640 - 647, You added a new CLI flag "--tie-break" via parser.add_argument in bench_topk.py but did not wire this mode into the unified benchmark framework; update the benchmark registration so the tie-break variants are exposed through the central runner in flashinfer_benchmark.py instead of a custom path. Specifically, detect the tie-break option where existing modes are mapped (the code that registers/dispatches benchmark cases) and add entries that create the deterministic tie-break benchmark cases (tie-small/tie-large) with the same CUPTI timing and slowdown alignment logic used by existing FlashInfer variants; ensure the dispatch uses the same helper/factory functions used for other modes so CUPTI setup and reporting are consistent with the rest of the benchmarks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 485-502: The tie-break variant runs use a different top-k
algorithm than the baseline because set_topk_algo("auto") is active; fix by
running bench_tie_break_variants under the same algorithm used when measuring
fi_ms/fi_nondeterministic_ms (call set_topk_algo("default") or the saved
baseline algorithm before invoking bench_tie_break_variants and restore the
prior algorithm afterwards). Update both places where bench_tie_break_variants
is called (the block using compare_tie_break and the similar block at the other
location) so the top-k algorithm matches the baseline measurement when
constructing the lambda that calls flashinfer.top_k_page_table_transform.
- Around line 89-103: The lambda passed to bench_median_ms inside
bench_tie_break_variants closes over the loop variable tie_break, which can lead
to late-binding bugs; change the lambda to bind the current tie_break value as a
default parameter (e.g. use a parameter like tb=tie_break) and call
run_flashinfer_with_tie_break(tb) so each iteration captures its own tie_break;
update the call site where bench_median_ms is invoked to use this bound lambda
and keep all other metric logic the same (references: bench_tie_break_variants,
run_flashinfer_with_tie_break, TIE_BREAK_VARIANTS, bench_median_ms).
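The late-binding pitfall described above is a general Python closure behavior, shown here in isolation (the names below are illustrative, not from the benchmark):

```python
modes = ("none", "small", "large")

# Pitfall: every lambda closes over the loop variable itself, so all of them
# observe its final value once the loop has finished.
late = [lambda: tie_break for tie_break in modes]
print([f() for f in late])    # ['large', 'large', 'large']

# Fix: a default argument is evaluated at definition time, binding each
# iteration's value to its own lambda.
bound = [lambda tb=tie_break: tb for tie_break in modes]
print([f() for f in bound])   # ['none', 'small', 'large']
```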
---
Outside diff comments:
In `@benchmarks/bench_topk.py`:
- Around line 690-721: The tie-break auto-enable logic (the if args.tie_break
block that sets args.deterministic = True) must not run before the
compare-algorithms validation because it masks the intended tie-break
incompatibility error; change the order or the condition so that the
compare-algorithms check (the if args.compare_algorithms block that errors when
args.deterministic is True and when args.tie_break is True) is evaluated first,
or make the tie-break auto-enable skip if args.compare_algorithms is set,
ensuring the args.tie_break path triggers its specific error rather than being
hidden by the deterministic check.
---
Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Around line 640-647: You added a new CLI flag "--tie-break" via
parser.add_argument in bench_topk.py but did not wire this mode into the unified
benchmark framework; update the benchmark registration so the tie-break variants
are exposed through the central runner in flashinfer_benchmark.py instead of a
custom path. Specifically, detect the tie-break option where existing modes are
mapped (the code that registers/dispatches benchmark cases) and add entries that
create the deterministic tie-break benchmark cases (tie-small/tie-large) with
the same CUPTI timing and slowdown alignment logic used by existing FlashInfer
variants; ensure the dispatch uses the same helper/factory functions used for
other modes so CUPTI setup and reporting are consistent with the rest of the
benchmarks.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/topk.cuh (1)
3063-3112: ⚠️ Potential issue | 🟡 Minor

`tie_break` is silently dropped when `deterministic == false` in the lower-level launch APIs.

In `DISPATCH_VEC_SIZE`, the `!deterministic` branch hard-codes `TopKTieBreak::None` regardless of the incoming `tie_break` argument:

```cpp
if (!deterministic) {
  LAUNCH_FILTERED_KERNEL(VS, false, TopKTieBreak::None);  // tie_break ignored
} else {
  ...
}
```

The three public wrappers (`FilteredTopK`, `FilteredTopKPageTableTransform`, `FilteredTopKRaggedTransform`) also accept `tie_break` and forward it verbatim without forcing `deterministic = true`. Only the top-level `TopKDispatch`/`TopKPageTableTransformDispatch`/`TopKRaggedTransformDispatch` flip `deterministic` when `tie_break != None`.

That means any future/internal caller that reaches the lower layer with `(deterministic=false, tie_break=Small|Large)` will silently get non-deterministic, non-tie-broken results, violating the documented contract that "selecting a tie-break forces deterministic execution."

Either force `deterministic = true` here (and in the three wrappers) when `tie_break != None`, or at minimum return `cudaErrorInvalidValue` to make the misuse loud.

🛠️ Possible fix at the launch site

```diff
   const int vec_size = ComputeFilteredTopKVecSize<DType>(max_len);
+  if (tie_break != TopKTieBreak::None && !deterministic) {
+    // Tie-break modes require deterministic emission; the non-deterministic
+    // kernel path ignores TIE_BREAK at compile time.
+    deterministic = true;
+  }
```

Or, stricter:

```diff
+  if (tie_break != TopKTieBreak::None && !deterministic) {
+    return cudaErrorInvalidValue;
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3063 - 3112, The DISPATCH_VEC_SIZE path currently ignores the incoming tie_break when deterministic==false; update LaunchFilteredTopKUnified so that if tie_break != TopKTieBreak::None you force deterministic=true (or return cudaErrorInvalidValue) before dispatching the kernel; specifically, check the tie_break parameter at the start of LaunchFilteredTopKUnified (or inside DISPATCH_VEC_SIZE) and either set deterministic=true when tie_break is Small or Large, or return an error, and mirror this same enforcement in the public wrappers FilteredTopK, FilteredTopKPageTableTransform, and FilteredTopKRaggedTransform so callers cannot reach the lower layer with (deterministic=false, tie_break!=None).
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)
222-301: Add a brief justification for the per-thread-consecutive item layout.

`DeterministicContiguousCollect` deliberately gives thread `t` the `ITEMS_PER_THREAD` consecutive indices `[t*IPT, t*IPT+IPT)` within each chunk, which makes intra-warp loads strided by `ITEMS_PER_THREAD` rather than unit-stride. This is the right choice here (a CUB `BlockScan` over the per-thread counts naturally yields contiguous/REVERSE-contiguous emission order, which is exactly what row-global tie-breaking needs), but it is a non-obvious perf trade-off vs. the `tx, tx+BLOCK_THREADS, ...` layout used by `DeterministicThreadStridedCollect`.

Per the coding guideline on performance-critical hot paths, it would help future readers to note this rationale (and mention the strided-layout alternative) next to the existing `// TODO: maybe tune ITEMS_PER_THREAD and vectorize` comment. No code change required.

As per coding guidelines: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 222 - 301, Add a short comment in DeterministicContiguousCollect (near the ITEMS_PER_THREAD/TODO comment) explaining that the per-thread-consecutive layout (thread t owning indices [t*ITEMS_PER_THREAD, ...)) intentionally trades unit-stride loads for intra-warp strided loads because cub::BlockScan over per-thread selected counts yields contiguous (or REVERSE-contiguous) global emission order required for row-global tie-breaking; also mention the alternative strided layout (tx, tx+BLOCK_THREADS, ...) used by DeterministicThreadStridedCollect and why it was rejected here (to preserve deterministic contiguous output ordering), so future readers understand the performance/correctness rationale.
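The two layouts being contrasted can be visualized with a small index calculation. This is pure Python for illustration only; the real code assigns these index sets to CUDA threads:

```python
BLOCK_THREADS = 4
ITEMS_PER_THREAD = 3

# Per-thread-consecutive (DeterministicContiguousCollect-style): thread t owns
# indices [t*IPT, t*IPT + IPT), so concatenating threads 0..T-1 in order yields
# one contiguous range, which a block-wide scan can emit deterministically.
consecutive = [[t * ITEMS_PER_THREAD + i for i in range(ITEMS_PER_THREAD)]
               for t in range(BLOCK_THREADS)]

# Thread-strided (DeterministicThreadStridedCollect-style): thread t owns
# {t, t + BLOCK_THREADS, ...}, giving unit-stride warp loads but an
# interleaved global emission order.
strided = [[t + i * BLOCK_THREADS for i in range(ITEMS_PER_THREAD)]
           for t in range(BLOCK_THREADS)]

print(consecutive[1])  # [3, 4, 5]
print(strided[1])      # [1, 5, 9]
```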
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3063-3112: The DISPATCH_VEC_SIZE path currently ignores the
incoming tie_break when deterministic==false; update LaunchFilteredTopKUnified
so that if tie_break != TopKTieBreak::None you force deterministic=true (or
return cudaErrorInvalidValue) before dispatching the kernel; specifically, check
the tie_break parameter at the start of LaunchFilteredTopKUnified (or inside
DISPATCH_VEC_SIZE) and either set deterministic=true when tie_break is Small or
Large, or return an error, and mirror this same enforcement in the public
wrappers FilteredTopK, FilteredTopKPageTableTransform, and
FilteredTopKRaggedTransform so callers cannot reach the lower layer with
(deterministic=false, tie_break!=None).
---
Nitpick comments:
In `@include/flashinfer/topk.cuh`:
- Around line 222-301: Add a short comment in DeterministicContiguousCollect
(near the ITEMS_PER_THREAD/TODO comment) explaining that the
per-thread-consecutive layout (thread t owning indices [t*ITEMS_PER_THREAD,
...)) intentionally trades unit-stride loads for intra-warp strided loads
because cub::BlockScan over per-thread selected counts yields contiguous (or
REVERSE-contiguous) global emission order required for row-global tie-breaking;
also mention the alternative strided layout (tx, tx+BLOCK_THREADS, ...) used by
DeterministicThreadStridedCollect and why it was rejected here (to preserve
deterministic contiguous output ordering), so future readers understand the
performance/correctness rationale.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3702cdcc-e9b9-4b5c-9e74-9b0fa7b5009b
📒 Files selected for processing (2)
- benchmarks/bench_topk.py
- include/flashinfer/topk.cuh
@coderabbitai stop
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/topk.cuh (1)
3080-3111: ⚠️ Potential issue | 🟡 Minor

`tie_break` is silently dropped on the `!deterministic` branch.

`DISPATCH_VEC_SIZE` always launches with `TopKTieBreak::None` when `deterministic` is false, regardless of the `tie_break` argument. The top-level `TopKDispatch`/`TopKPageTableTransformDispatch`/`TopKRaggedTransformDispatch` force `deterministic=true` whenever `tie_break != None`, so end users are safe. But `LaunchFilteredTopKUnified` and the mid-level wrappers `FilteredTopK`, `FilteredTopKPageTableTransform`, `FilteredTopKRaggedTransform` all expose `tie_break` alongside a defaulted `deterministic=false`. A direct caller passing `tie_break=Small/Large` without flipping `deterministic` will get ordinary non-deterministic output with no diagnostic. Prefer failing loudly (or auto-promoting to deterministic) at this layer so the invariant `tie_break != None ⇒ deterministic` is enforced at every entry point.

🛡️ Proposed guard

```diff
 cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_output,
                                       const IdType* aux_input, int64_t aux_stride,
                                       const IdType* row_to_batch, const IdType* lengths,
                                       uint32_t num_rows, uint32_t top_k_val, uint32_t max_len,
                                       bool deterministic = false,
                                       TopKTieBreak tie_break = TopKTieBreak::None,
                                       cudaStream_t stream = 0) {
+  if (tie_break != TopKTieBreak::None && !deterministic) {
+    // Tie-break selection requires the deterministic collect path.
+    return cudaErrorInvalidValue;
+  }
   constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3080 - 3111, The DISPATCH_VEC_SIZE macro silently drops the provided tie_break when deterministic==false; update LaunchFilteredTopKUnified (and callers FilteredTopK / FilteredTopKPageTableTransform / FilteredTopKRaggedTransform) to enforce the invariant "tie_break != TopKTieBreak::None ⇒ deterministic==true": add an explicit guard before dispatching (e.g., if (!deterministic && tie_break != TopKTieBreak::None) return cudaErrorInvalidValue;) so callers who pass tie_break=Small/Large get a clear error instead of losing the tie-break behavior; adjust the branch logic in DISPATCH_VEC_SIZE / LAUNCH_FILTERED_KERNEL accordingly so the deterministic path is used only when deterministic==true.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3080-3111: The DISPATCH_VEC_SIZE macro silently drops the provided
tie_break when deterministic==false; update LaunchFilteredTopKUnified (and
callers FilteredTopK / FilteredTopKPageTableTransform /
FilteredTopKRaggedTransform) to enforce the invariant "tie_break !=
TopKTieBreak::None ⇒ deterministic==true": add an explicit guard before
dispatching (e.g., if (!deterministic && tie_break != TopKTieBreak::None) return
cudaErrorInvalidValue;) so callers who pass tie_break=Small/Large get a clear
error instead of losing the tie-break behavior; adjust the branch logic in
DISPATCH_VEC_SIZE / LAUNCH_FILTERED_KERNEL accordingly so the deterministic path
is used only when deterministic==true.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1a727b01-c095-4a81-9994-2ba40b30ac61
📒 Files selected for processing (7)
- benchmarks/bench_topk.py
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
- flashinfer/__init__.py
- flashinfer/topk.py
- include/flashinfer/topk.cuh
- tests/utils/test_topk.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/__init__.py
🚧 Files skipped from review as they are similar to previous changes (4)
- csrc/topk.cu
- flashinfer/topk.py
- csrc/flashinfer_topk_binding.cu
- tests/utils/test_topk.py
✅ Actions performed: Reviews paused.
/bot run
kahyunnam left a comment:
@zianglih some CICD unit tests are failing on RTX pro 6000 (sm120). Please take a look; if "operation not supported" is expected here, then a change in the test script may be needed to handle this case / catch the error without falling through.
Test log:
==========================================
Running: pytest --continue-on-collection-errors --junitxml=/tmp/junit/tests_utils_test_topk.py.xml "tests/utils/test_topk.py"
==========================================
============================= test session starts ==============================
platform linux -- Python 3.12.13, pytest-9.0.3, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 1334 items
tests/utils/test_topk.py ............................................... [ 3%]
..................................FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 8%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 14%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF................FFFF [ 19%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFF..FFFF........FFFFFFFFFFFFFFFF.............. [ 25%]
........................................................................ [ 30%]
........................................................................ [ 35%]
........................................................................ [ 41%]
........................................................................ [ 46%]
......................................................s..s..s..s..s..s.. [ 52%]
s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s.. [ 57%]
s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..s..........sss. [ 62%]
.ss..ss..ssssssss....s...s...s...sssssss..s....ssss...s..s..s..s.s.s.... [ 68%]
...sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 79%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 84%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 89%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [100%]
=================================== FAILURES ===================================
E RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
/workspace/flashinfer/csrc/topk.cu:79: RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
E RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
/workspace/flashinfer/csrc/topk.cu:79: RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
E RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
/workspace/flashinfer/csrc/topk.cu:79: RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
E RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
/workspace/flashinfer/csrc/topk.cu:79: RuntimeError: Check failed: (status == cudaSuccess) is false: TopK failed with error code operation not supported
... etc...
=========================== short test summary info ============================
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-32000-1] - Runti...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-32000-16] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-32000-64] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-65536-1] - Runti...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-65536-16] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-65536-64] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-128512-1] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-128512-16] - Run...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-256-128512-64] - Run...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-512-32000-1] - Runti...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-512-32000-16] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-512-32000-64] - Runt...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-512-65536-1] - Runti...
FAILED tests/utils/test_topk.py::test_top_k[small-dtype0-512-65536-16] - Runt...
... etc ...
=========== 214 failed, 614 passed, 506 skipped in 104.92s (0:01:44) ===========
❌ FAILED: tests/utils/test_topk.py
==========================================
Hi @kahyunnam, this is fixed by 53eb16f. Waiting on CI. Thanks!
/bot run
Previously failing A10G and T4 CI passed.
📌 Description
Implement configurable `tie_break` for filtered topk. This automatically achieves determinism.
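As a rough reference for the intended semantics, here is a pure-Python sketch (not the CUDA implementation): among equal values at the selection boundary, a "small" tie-break prefers lower indices and "large" prefers higher ones.

```python
def topk_with_tie_break(values, k, tie_break="small"):
    """Reference top-k: sort by value descending, break ties by index per mode."""
    order = sorted(
        range(len(values)),
        # Secondary key: ascending index for "small", descending for "large".
        key=lambda i: (-values[i], i if tie_break == "small" else -i),
    )
    chosen = order[:k]
    return [values[i] for i in chosen], chosen


vals = [5.0, 7.0, 5.0, 7.0, 1.0]
print(topk_with_tie_break(vals, 3, "small"))  # ([7.0, 7.0, 5.0], [1, 3, 0])
print(topk_with_tie_break(vals, 3, "large"))  # ([7.0, 7.0, 5.0], [3, 1, 2])
```

Because the tie-break rule fully orders the candidates, the selection is the same on every run, which is why choosing a tie-break implies determinism.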
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
`python3 -m pytest tests/utils/test_topk.py -v` passed on B200.

Summary by CodeRabbit
New Features
CLI
Benchmarks
Tests