fix: topK uint32 overflow#2937
Conversation
also add cub stable radix sort and overflow handling Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces deterministic mode for FlashInfer's top-k operations, including basic top-k selection and fused transforms (page table and ragged). The implementation includes new CUDA kernels for deterministic index collection and post-processing stable sorts to ensure bitwise-reproducible results across different runs and environments. The Python API has been updated to expose a deterministic flag, and the benchmarking suite is expanded to include DeepSeek DSA-like workloads and various input patterns to stress-test tie-handling logic. Extensive tests have been added to verify repeatability and correctness. I have no feedback to provide as there were no review comments to assess.
|
cherry picked into #2661 |
📌 Description
This PR is based on top of #2661. Merge afterwards.
Fixes uint32 overflow in top_k when batch_size * vocab_size > 2^32
The row offset
row_idx * stridewas computed in uint32 arithmetic, silentlywrapping to zero for large inputs. Cast to size_t before the multiplication.
Add a regression test with batch=32769, vocab=131072 (fp16) that crosses the
overflow boundary.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes