
[fix] fix blackwell gdn accuracy issue#3156

Merged
kahyunnam merged 1 commit into flashinfer-ai:main from Observer007:fix/gdn_accuracy
Apr 24, 2026

Conversation

@Observer007
Contributor

@Observer007 Observer007 commented Apr 23, 2026

📌 Description

Fixed the accuracy issue in the Blackwell GDN kernel found by @bestzsq.

The root cause is that the legacy max_coord is not the actual last coordinate of sCumprod; we now use the last coordinate instead. It was a deeply hidden bug that had gone undiscovered until now. Thanks to @bestzsq.
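To illustrate the bug class, here is a minimal NumPy sketch; the names s_cumprod, max_coord, and cumprod_total only loosely mirror the kernel and are assumptions, not the actual CuTe-DSL code:

```python
import numpy as np

# Minimal sketch of the bug class, assuming s_cumprod holds per-position
# cumulative decay products for one chunk. This is illustrative NumPy,
# not the actual CuTe-DSL kernel code.
b_t = 64                                    # chunk length (assumed)
decay = np.random.default_rng(0).uniform(0.9, 1.0, size=b_t)
s_cumprod = np.cumprod(decay)

# Buggy: read through a cached coordinate that may not point at the
# final position of the buffer.
max_coord = b_t - 2                         # stale index, for illustration
cumprod_total_buggy = s_cumprod[max_coord]

# Fixed: index the last element of s_cumprod's first dimension directly.
cumprod_total = s_cumprod[s_cumprod.shape[0] - 1]
```

Because the cumulative product keeps decaying with each position, reading it at a stale coordinate silently overestimates the total decay, which matches the small but systematic error seen in the reproducer.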

Reproducer test link from @bestzsq: #3001 (comment)

Reproducer test output before this PR:

# flash-linear-attention==0.4.2
fla vs cute64: mae: 2.82288e-03, ulp: 9040.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64: mae: 2.82288e-03, ulp: 9064.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0

Reproducer test output after this PR:

# flash-linear-attention==0.4.2
fla vs cute64: mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64: mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
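For reference, metrics of this shape can be computed as below; the exact definitions in @bestzsq's reproducer script are not shown in this thread, so treat these as an assumed reading (mean absolute error, and ULP distance via the float32 integer view):

```python
import numpy as np

# Assumed definitions of the reproducer's "mae" and "ulp" metrics;
# @bestzsq's script may differ in detail.
def mae(a: np.ndarray, b: np.ndarray) -> float:
    # Mean absolute error between two arrays, accumulated in float64.
    return float(np.abs(a.astype(np.float64) - b.astype(np.float64)).mean())

def ulp_distance(a: np.ndarray, b: np.ndarray) -> int:
    # Max distance in units-in-the-last-place, valid for same-sign,
    # finite float32 inputs: reinterpret the bits as integers and diff.
    ia = a.astype(np.float32).view(np.int32).astype(np.int64)
    ib = b.astype(np.float32).view(np.int32).astype(np.int64)
    return int(np.abs(ia - ib).max())
```

With definitions like these, the drop from ~9000 ULP to 74 ULP after the fix indicates the 64-wide path now agrees with the reference to roughly the same degree as the 128-wide path.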

The local test tolerance was previously loosened from 1e-3 to 2e-3 in #3001: https://github.com/flashinfer-ai/flashinfer/pull/3001/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6L132

This PR tightens the tolerance back from 2e-3 to 1e-3: https://github.com/flashinfer-ai/flashinfer/pull/3156/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6R148
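The restored guard amounts to a check of this shape; atol_o is the real test parameter from tests/gdn/test_prefill_delta_rule.py, while the arrays here are placeholder data, not actual kernel outputs:

```python
import numpy as np

# Sketch of the tolerance guard this PR restores. atol_o mirrors the real
# test parameter; ref/out are placeholder data, not kernel outputs.
atol_o = 1e-3                               # restored from the temporary 2e-3

ref = np.array([0.125, -0.5, 0.75], dtype=np.float32)   # reference output
out = ref + np.float32(5e-4)                # error within the tighter bound

# The tightened guard: absolute difference must stay under atol_o.
assert np.allclose(out, ref, rtol=0.0, atol=atol_o)
```

Under the old 2e-3 tolerance the pre-fix error of ~2.8e-3 was only barely caught; restoring 1e-3 gives headroom so a regression of this kind fails loudly.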

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Refactor

    • Improved kernel computation efficiency by consolidating internal calculation steps and removing redundant intermediate operations, reducing code complexity while preserving all existing functionality and performance characteristics.
  • Tests

    • Strengthened numerical validation by reducing tolerance thresholds in computational accuracy tests for greater precision, ensuring more stringent verification of output correctness and numerical consistency across test scenarios.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

📝 Walkthrough

Walkthrough

Two targeted refinements to the codebase: the kernel computation is simplified by eliminating an intermediate variable and directly indexing a cumulative product array, while test assertions are tightened by reducing absolute tolerance thresholds for numerical comparisons.

Changes

Cohort / File(s) Summary
Kernel Computation
flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
Simplifies cumprod_total derivation in compute_group_1 by removing intermediate max_coord computation and directly indexing sCumprod using its first dimension's last element.
Test Tolerance Adjustment
tests/gdn/test_prefill_delta_rule.py
Tightens numerical comparison tolerance by reducing absolute tolerance (atol_o) for output assertion in the non-bfloat16 code path.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Poem

🐰 A hop through the kernel, so sleek and so clean,
Less intermediate clutter, now crisp and lean.
The tests stand more rigid, precision refined,
A tighter embrace for each value aligned.
Small tweaks, mighty purpose—hooray, we have dined!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check ✅: The title directly relates to the main change: fixing a Blackwell GDN kernel accuracy issue by correcting how cumprod_total is derived.
  • Linked Issues check ✅: Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅: Skipped because no linked issues were found for this pull request.
  • Description check ✅: The pull request provides a detailed description of the accuracy issue, root cause, and fix, and includes reproducer test outputs before and after the fix.



Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request simplifies the calculation of cumprod_total in the Blackwell gated delta net kernel and tightens the absolute tolerance for output comparisons in the prefill delta rule tests. Feedback was provided to use self.b_t - 1 instead of sCumprod.shape[0] - 1 to maintain consistency with the rest of the kernel implementation.

Comment thread flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !591 has been created, and the CI pipeline #49304188 is currently running. I'll report back once the pipeline job completes.

Member

@kahyunnam kahyunnam left a comment


LGTM, seems like a simple fix. Will help merge after /bot run tests pass.

@vadiklyutiy
Contributor

Could we add test(s) that check it?

@arpera
Contributor

arpera commented Apr 23, 2026

@Observer007, since you are proposing an accuracy fix, could you share some information about the testing done to prove that the issue was resolved?

@Observer007
Contributor Author

Could we add test(s) that check it?

We already have tests; this PR just restores the stricter tolerance as a guard.

@Observer007
Contributor Author

@Observer007, since you are proposing an accuracy fix, could you share some information about the testing done to prove that the issue was resolved?

Updated in the description.

@arpera
Contributor

arpera commented Apr 23, 2026

Could you please specify the name of the test in the description section? Also, please attach the test's source code to the description, or link to it if the test is part of FlashInfer's codebase.

@nvpohanh
Contributor

cc @kaixih @YAMY1234

@vadiklyutiy
Contributor

Could we add test(s) that check it?

We already have tests; this PR just restores the stricter tolerance as a guard.

Are you referring to changing atol_o = 2e-3 to atol_o = 1e-3?

@Observer007
Contributor Author

Could we add test(s) that check it?

We already have tests; this PR just restores the stricter tolerance as a guard.

Are you referring to changing atol_o = 2e-3 to atol_o = 1e-3?

Right. In the original MR #3001, the tolerance was loosened from 1e-3 to 2e-3. This PR changes it back to 1e-3.

@Observer007
Contributor Author

Could you please specify the name of the test in the description section? Also, please attach the test's source code to the description, or link to it if the test is part of FlashInfer's codebase.

More related test links have been added to the description.

Collaborator

@jiahanc jiahanc left a comment


LGTM, thanks for the fix

@bestzsq

bestzsq commented Apr 24, 2026

Thanks for the fix!

@kahyunnam kahyunnam merged commit 3516d2b into flashinfer-ai:main Apr 24, 2026
46 of 59 checks passed
kahyunnam pushed a commit that referenced this pull request Apr 27, 2026
## 📌 Description

Addresses the two remaining CodeRabbit findings on
[#3001](#3001) that
weren't applied before merge:

* **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend
dispatch so the same call gives matching numerics on SM90 and SM100. The
SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the
SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken
attention.

* **Don't eagerly allocate `output_state`** on the SM100 path when
`output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway,
so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch
per call. SM90 still allocates unconditionally because its C++ kernel
always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches
(no inline `output_state if output_final_state else None`), so SM90 and
SM100 read identically.
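A minimal sketch of the scale normalization described in the first bullet; `normalize_scale` is a hypothetical helper name and signature, not the actual dispatcher API:

```python
import math

# Sketch of the scale-normalization fix, assuming the dispatcher receives
# scale=0.0 as a "use default" sentinel. normalize_scale is a hypothetical
# helper, not the real FlashInfer API.
def normalize_scale(scale: float, d_k: int) -> float:
    # SM90's C++ kernel treats 0.0 as "use 1/sqrt(d_k)"; normalizing
    # before backend dispatch makes the SM100 CuTe-DSL path see the same
    # effective value instead of a literal 0.0 that zeroes out QK.
    return 1.0 / math.sqrt(d_k) if scale == 0.0 else scale
```

Normalizing once at the dispatch boundary keeps the sentinel convention out of the individual kernels, so both backends receive an already-resolved scale.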


## 🔍 Related Issues

* [[feat] Add blackwell GDN prefill
kernel](#3001)
* [fix(gdn): use physical SM count for SM100 persistent prefill
kernel#3155](#3155)
* [[fix] fix blackwell gdn accuracy
issue#3156](#3156)

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes




## Summary by CodeRabbit

* **Bug Fixes**
* Fixed scale parameter handling to correctly interpret explicit values
and apply default scaling behavior.
* Improved memory efficiency by avoiding unnecessary state allocations
in certain configurations.

* **Improvements**
* Enhanced consistency in kernel invocation logic across different
hardware architectures.
