
fix(gdn): use physical SM count for SM100 persistent prefill kernel#3155

Merged
kahyunnam merged 4 commits into flashinfer-ai:main from arpera:fix-backwell-gdn
Apr 25, 2026

Conversation

@arpera
Contributor

@arpera arpera commented Apr 23, 2026

📌 Description

Fixes the num_sm issue that CodeRabbit flagged on #3001 but that was not applied before merge: #3001 (comment)

The raw HardwareInfo().get_max_active_clusters(1) call returns 0 or stale values in spawned subprocesses (e.g. vLLM's EngineCore workers) where the CUDA driver API context has not yet been made current. The persistent tile scheduler then leaves some CTAs without any work, and the kernel deadlocks on the first call. This PR switches to get_num_sm(q.device), matching the SM120 MoE dispatch.
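A minimal sketch of the failure mode and the fix. The helper names and the SM count here are hypothetical stand-ins; the real change lives in flashinfer/gdn_kernels/blackwell/gdn_prefill.py:

```python
# Hypothetical stand-in for the dispatch logic described above. The probe
# value models HardwareInfo().get_max_active_clusters(1), which can return 0
# in a spawned subprocess where no CUDA context has been made current yet.

def grid_size_before_fix(probed_clusters: int, physical_sm_count: int) -> int:
    # Trusting the probe: a stale 0 propagates into the persistent
    # scheduler's grid, leaving some CTAs with no work (deadlock).
    return min(probed_clusters, physical_sm_count) if probed_clusters else 0

def grid_size_after_fix(probed_clusters: int, physical_sm_count: int) -> int:
    # The fix: size the persistent grid from the physical SM count
    # (get_num_sm(q.device)) and ignore the context-dependent probe.
    return physical_sm_count

# In a healthy parent process both paths agree; in a spawned worker only
# the fixed path produces a usable grid.
assert grid_size_before_fix(0, 148) == 0      # broken: zero-sized grid
assert grid_size_after_fix(0, 148) == 148     # fixed: full physical SM count
```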

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Kernel compilation now derives device-specific SM and cluster counts at runtime, improving GPU resource allocation and leading to more consistent performance across different CUDA devices.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a8731a58-4907-459e-a8b5-4c8df701c516

📥 Commits

Reviewing files that changed from the base of the PR and between 1f337a3 and c7f1020.

📒 Files selected for processing (1)
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py

📝 Walkthrough

Walkthrough

Kernel compilation for the Blackwell GDN prefill now derives num_sm from get_num_sm(q.device) and sets max_active_clusters equal to that during the initial compile-once path; compile/cache-and-replay control flow and execution logic are unchanged.

Changes

Cohort / File(s) | Summary
  • Kernel compilation update (flashinfer/gdn_kernels/blackwell/gdn_prefill.py): Replace use of cutlass.utils.HardwareInfo for SM-related values: derive num_sm with get_num_sm(q.device) and set max_active_clusters to the same value for the initial compile-once inputs to GatedDeltaNetChunkedKernel. Control flow and execution remain unchanged.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • bkryu

Poem

🐰 I nudged the kernel, counted SMs with care,
No hardware gossip, just numbers to spare.
Clusters matched neatly, compile-once in sight,
Cached like a burrow, ready for flight. ✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name | Status | Explanation
  • Title check: ✅ Passed. The title clearly and specifically describes the main change: fixing SM count selection for the SM100 persistent prefill kernel by using the physical SM count instead of HardwareInfo.
  • Description check: ✅ Passed. The description includes a detailed explanation of the fix, references the related issue, and confirms completion of pre-commit checks and tests, meeting the template requirements.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the hardware information retrieval in the Blackwell GDN prefill kernel by replacing cutlass_utils.HardwareInfo with specialized utility functions. A potential issue was identified where get_max_active_clusters could return zero in certain environments (e.g., spawned subprocesses), which would lead to kernel launch failures. A suggestion was made to provide a fallback to the total number of SMs in such cases.

Comment thread flashinfer/gdn_kernels/blackwell/gdn_prefill.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Line 38: Remove the unused import get_max_active_clusters and change the
persistent scheduler grid cap logic to use num_sm directly (instead of
min(get_max_active_clusters(1), num_sm)) so the grid shape never becomes
(0,1,1); update the code locations referencing get_max_active_clusters and the
persistent scheduler grid shape (search for get_max_active_clusters and the
variable num_sm and where the persistent scheduler grid is constructed) to
compute cap = num_sm and use that cap when forming the scheduler grid to avoid a
zero dimension.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b94f14f-a888-4860-b268-b68dbacc313e

📥 Commits

Reviewing files that changed from the base of the PR and between 9f7adfb and 66e738c.

📒 Files selected for processing (1)
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py

Comment thread flashinfer/gdn_kernels/blackwell/gdn_prefill.py Outdated
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
flashinfer/gdn_kernels/blackwell/gdn_prefill.py (1)

38-38: ⚠️ Potential issue | 🔴 Critical

Avoid the stale active-cluster probe entirely.

Line 162 still calls get_max_active_clusters(1). The or num_sm fallback handles 0/None, but a stale positive value still survives through min(...) and can under-cap the persistent scheduler. For this SM100 path, use the physical SM count directly and drop the import.

Proposed fix:

```diff
-from flashinfer.cute_dsl.utils import get_max_active_clusters, get_num_sm
+from flashinfer.cute_dsl.utils import get_num_sm
@@
         # --- First call: compile the kernel ---
         num_sm = get_num_sm(q.device)
-        max_active_clusters = min(get_max_active_clusters(1) or num_sm, num_sm)
+        max_active_clusters = num_sm
```

Also applies to: 161-162

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py` at line 38, Remove the stale
active-cluster probe by deleting the import of get_max_active_clusters (and
get_num_sm if unused) and change the logic that currently calls
get_max_active_clusters(1) (and then uses min(...) with num_sm) to use the
physical SM count directly (the num_sm variable / physical SM-count provider)
when computing the persistent scheduler capacity; ensure any min(...) uses only
the real SM count and adjust variable names accordingly so the persistent
scheduler is never capped by a stale probe value.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Line 38: Remove the stale active-cluster probe by deleting the import of
get_max_active_clusters (and get_num_sm if unused) and change the logic that
currently calls get_max_active_clusters(1) (and then uses min(...) with num_sm)
to use the physical SM count directly (the num_sm variable / physical SM-count
provider) when computing the persistent scheduler capacity; ensure any min(...)
uses only the real SM count and adjust variable names accordingly so the
persistent scheduler is never capped by a stale probe value.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 241377df-4d17-435a-857e-c6f54aca9b61

📥 Commits

Reviewing files that changed from the base of the PR and between 66e738c and 1f337a3.

📒 Files selected for processing (1)
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py

@jiahanc jiahanc added the run-ci label Apr 23, 2026
@jiahanc
Collaborator

jiahanc commented Apr 23, 2026

/bot run

Collaborator

@jiahanc jiahanc left a comment


Thanks for the fix!

@flashinfer-bot
Collaborator

GitLab MR !589 has been created, and the CI pipeline #49300979 is currently running. I'll report back once the pipeline job completes.

@vadiklyutiy
Contributor

@jiahanc any idea why we didn't catch it with unit tests?

@jiahanc
Collaborator

jiahanc commented Apr 24, 2026

@jiahanc any idea why we didn't catch it with unit tests?

could be because the unit test doesn't put as much pressure on the kernel as the framework side does, so the bug was not exposed

Member

@kahyunnam kahyunnam left a comment


lgtm, thanks!

@kahyunnam kahyunnam merged commit 5e1318c into flashinfer-ai:main Apr 25, 2026
28 of 34 checks passed
kahyunnam pushed a commit that referenced this pull request Apr 27, 2026
## 📌 Description

Addresses the two remaining CodeRabbit findings on
[#3001](#3001) that
weren't applied before merge:

* **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend
dispatch so the same call gives matching numerics on SM90 and SM100. The
SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the
SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken
attention.

* **Don't eagerly allocate `output_state`** on the SM100 path when
`output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway,
so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch
per call. SM90 still allocates unconditionally because its C++ kernel
always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches
(no inline `output_state if output_final_state else None`), so SM90 and
SM100 read identically.
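The two fixes above can be sketched as follows. This is a hedged stand-alone illustration; `normalize_scale` and `sm100_scratch_bytes` are hypothetical helper names, not the actual functions in the dispatcher:

```python
import math

def normalize_scale(scale: float, d_k: int) -> float:
    # The SM90 C++ kernel treats 0.0 as a sentinel for "use the default
    # 1/sqrt(d_k)"; normalizing before backend dispatch keeps the SM100
    # CuTe-DSL path from forwarding a literal 0.0, which would zero QK.
    return 1.0 / math.sqrt(d_k) if scale == 0.0 else scale

def sm100_scratch_bytes(num_seqs: int, num_heads: int,
                        output_final_state: bool) -> int:
    # The [num_seqs, H, 128, 128] float32 output_state scratch is only
    # needed when the caller asks for the final state; skipping it
    # otherwise avoids a wasted per-call allocation on the SM100 path.
    return num_seqs * num_heads * 128 * 128 * 4 if output_final_state else 0

assert normalize_scale(0.0, 64) == 1.0 / math.sqrt(64)  # sentinel -> default
assert normalize_scale(0.5, 64) == 0.5                  # explicit value kept
assert sm100_scratch_bytes(4, 8, False) == 0            # no eager scratch
assert sm100_scratch_bytes(4, 8, True) == 4 * 8 * 128 * 128 * 4
```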


## 🔍 Related Issues

* [[feat] Add blackwell GDN prefill
kernel](#3001)
* [fix(gdn): use physical SM count for SM100 persistent prefill kernel #3155](#3155)
* [[fix] fix blackwell gdn accuracy issue #3156](#3156)

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes




## Summary by CodeRabbit

* **Bug Fixes**
* Fixed scale parameter handling to correctly interpret explicit values
and apply default scaling behavior.
* Improved memory efficiency by avoiding unnecessary state allocations
in certain configurations.

* **Improvements**
* Enhanced consistency in kernel invocation logic across different
hardware architectures.
