[fix] fix blackwell gdn accuracy issue #3156
Conversation
📝 Walkthrough

Two targeted refinements to the codebase: the kernel computation is simplified by eliminating an intermediate variable and directly indexing a cumulative product array, while test assertions are tightened by reducing absolute tolerance thresholds for numerical comparisons.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request simplifies the calculation of cumprod_total in the Blackwell gated delta net kernel and tightens the absolute tolerance for output comparisons in the prefill delta rule tests. Feedback was provided to use self.b_t - 1 instead of sCumprod.shape[0] - 1 to maintain consistency with the rest of the kernel implementation.
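For illustration only, a plain-Python sketch of the two equivalent indexings the feedback contrasts (the real code is a CuTe-DSL kernel; `b_t`, the decay values, and the list stand-in for `sCumprod` are assumptions):

```python
# Hypothetical stand-in for the kernel's shared cumulative-product buffer.
# In the real kernel this is a CuTe tensor; here it is a plain list.
b_t = 8  # assumed tile length, mirroring self.b_t in the kernel

decay = [0.9] * b_t
s_cumprod = []
acc = 1.0
for d in decay:
    acc *= d
    s_cumprod.append(acc)

# Reviewer's point: both expressions address the same (last) element,
# but indexing with b_t - 1 stays consistent with how the rest of the
# kernel indexes, versus deriving the index from the buffer's shape.
last_by_shape = s_cumprod[len(s_cumprod) - 1]
last_by_tile = s_cumprod[b_t - 1]
assert last_by_shape == last_by_tile
```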
/bot run
kahyunnam left a comment
LGTM, seems like a simple fix. Will help merge after the /bot run tests pass.
Could we add test(s) that check it?
@Observer007, since you are proposing an accuracy fix, could you share some information about the testing done, to prove that the issue was resolved?
We already have tests; this change just restores the stricter tolerance as a guard.
Updated in description. |
Could you please specify the name of the test in the description section? Also, please attach the source code of the test to the description, or give a link to it if the test is part of FlashInfer's codebase.
Are you referring to this change?
Right, in the original MR #3001 the tolerance was loosened from `1e-3` to `2e-3`.
More related test links have been added to the description.
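For context, a pure-Python sketch (hypothetical values and helper, not FlashInfer's actual test) of why restoring the stricter tolerance matters as a guard:

```python
def allclose(a, b, atol):
    # Elementwise absolute-tolerance comparison, mimicking the atol part
    # of a torch.testing.assert_close-style check.
    return all(abs(x - y) <= atol for x, y in zip(a, b))

ref = [0.1000, 0.2000, 0.3000]   # hypothetical reference output
out = [0.1015, 0.2000, 0.3000]   # 1.5e-3 error on the first element

# The loosened 2e-3 tolerance would mask this error...
assert allclose(out, ref, atol=2e-3)
# ...while the restored 1e-3 tolerance catches it.
assert not allclose(out, ref, atol=1e-3)
```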
jiahanc left a comment
LGTM, thanks for the fix
Thanks for the fix!
## 📌 Description

Addresses the two remaining CodeRabbit findings on [#3001](#3001) that weren't applied before merge:

* **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend dispatch so the same call gives matching numerics on SM90 and SM100. The SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken attention.
* **Don't eagerly allocate `output_state`** on the SM100 path when `output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway, so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch per call. SM90 still allocates unconditionally because its C++ kernel always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches (no inline `output_state if output_final_state else None`), so SM90 and SM100 read identically.

## 🔍 Related Issues

* [[feat] Add blackwell GDN prefill kernel](#3001)
* [fix(gdn): use physical SM count for SM100 persistent prefill kernel #3155](#3155)
* [[fix] fix blackwell gdn accuracy issue #3156](#3156)

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed scale parameter handling to correctly interpret explicit values and apply default scaling behavior.
  * Improved memory efficiency by avoiding unnecessary state allocations in certain configurations.
* **Improvements**
  * Enhanced consistency in kernel invocation logic across different hardware architectures.
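A minimal sketch of the two behaviors described above, with hypothetical helper names (the real code dispatches between the SM90 and SM100 kernel paths):

```python
import math

def normalize_scale(scale, head_dim):
    # Sentinel handling: the SM90 C++ kernel treats 0.0 as "use default",
    # so normalize to 1/sqrt(d_k) before dispatch instead of forwarding
    # the literal 0.0 (which would zero out QK on the SM100 path).
    if scale is None or scale == 0.0:
        return 1.0 / math.sqrt(head_dim)
    return scale

def maybe_alloc_output_state(output_final_state, num_seqs, num_heads):
    # Skip the [num_seqs, H, 128, 128] scratch when the caller won't read
    # the final state back; the SM100 kernel drops the buffer anyway.
    # (Plain nested lists stand in for a float32 tensor here.)
    if not output_final_state:
        return None
    return [[[[0.0] * 128 for _ in range(128)]
             for _ in range(num_heads)] for _ in range(num_seqs)]
```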
📌 Description
Fixed the accuracy issue in the Blackwell GDN kernel found by @bestzsq.
The root cause is that the legacy `max_coord` is not the actual last coordinate of `sCumprod`; we now use the last coordinate instead. It's a deeply hidden bug that we hadn't discovered previously. Thanks to @bestzsq.

Reproducer test link from @bestzsq: #3001 (comment)
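A plain-Python illustration of this bug class (hypothetical names and values; the real kernel indexes a CuTe shared-memory tensor): reading a stale coordinate instead of the buffer's actual last coordinate yields the wrong running product.

```python
# Plain-Python illustration of the bug class described above; cumprod and
# the decay values are stand-ins, not the kernel's actual code.
def cumprod(xs):
    out, acc = [], 1.0
    for x in xs:
        acc *= x
        out.append(acc)
    return out

decay = [0.5, 0.5, 0.5, 0.5]                 # per-step decay factors
s_cumprod = cumprod(decay)                   # [0.5, 0.25, 0.125, 0.0625]

max_coord = 2                                # stale legacy constant: NOT the last coord
buggy_total = s_cumprod[max_coord]           # 0.125 -- wrong running product
fixed_total = s_cumprod[len(s_cumprod) - 1]  # 0.0625 -- actual total decay
```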
Reproducer test output before this PR:

Reproducer test output after this PR:
The previous local test tolerance was loosened from `1e-3` to `2e-3` in #3001: https://github.com/flashinfer-ai/flashinfer/pull/3001/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6L132

This PR tightens the tolerance from `2e-3` back to `1e-3`: https://github.com/flashinfer-ai/flashinfer/pull/3156/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6R148

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes