fix(gdn): use physical SM count for SM100 persistent prefill kernel #3155

kahyunnam merged 4 commits into flashinfer-ai:main
Conversation
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough: Kernel compilation for the Blackwell GDN prefill now derives the persistent grid size from the physical SM count (`get_num_sm`) instead of the active-cluster probe.
Code Review
This pull request updates the hardware information retrieval in the Blackwell GDN prefill kernel by replacing cutlass_utils.HardwareInfo with specialized utility functions. A potential issue was identified where get_max_active_clusters could return zero in certain environments (e.g., spawned subprocesses), which would lead to kernel launch failures. A suggestion was made to provide a fallback to the total number of SMs in such cases.
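The fallback suggested in this review can be sketched as a pure function — a hypothetical helper for illustration (the merged fix instead uses the physical SM count unconditionally via `get_num_sm(q.device)`):

```python
def resolve_grid_cap(probed_clusters: int, num_sm: int) -> int:
    """Cap the persistent grid by the physical SM count, falling back to it
    entirely when the active-cluster probe returns 0 (e.g. in a spawned
    subprocess where no CUDA context is current yet)."""
    if probed_clusters <= 0:
        return num_sm  # probe unusable: use all physical SMs
    return min(probed_clusters, num_sm)
```

With the fallback, a zero probe result can no longer produce a `(0, 1, 1)` grid, though a stale positive probe value could still under-cap the grid — which is why the final fix drops the probe entirely.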
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Line 38: Remove the unused import `get_max_active_clusters` and change the persistent scheduler grid-cap logic to use `num_sm` directly (instead of `min(get_max_active_clusters(1), num_sm)`) so the grid shape never becomes `(0, 1, 1)`. Update every location that references `get_max_active_clusters` or constructs the persistent scheduler grid to compute `cap = num_sm` and use that cap, avoiding a zero dimension.
ℹ️ Review info — Configuration: defaults · Review profile: CHILL · Plan: Pro · Run ID: 2b94f14f-a888-4860-b268-b68dbacc313e
📒 Files selected for processing (1): flashinfer/gdn_kernels/blackwell/gdn_prefill.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
♻️ Duplicate comments (1)
flashinfer/gdn_kernels/blackwell/gdn_prefill.py (1)
38-38: ⚠️ Potential issue | 🔴 Critical — Avoid the stale active-cluster probe entirely.

Line 162 still calls `get_max_active_clusters(1)`. The `or num_sm` fallback handles `0`/`None`, but a stale positive value still survives through `min(...)` and can under-cap the persistent scheduler. For this SM100 path, use the physical SM count directly and drop the import.

Proposed fix:

```diff
-from flashinfer.cute_dsl.utils import get_max_active_clusters, get_num_sm
+from flashinfer.cute_dsl.utils import get_num_sm
@@
     # --- First call: compile the kernel ---
     num_sm = get_num_sm(q.device)
-    max_active_clusters = min(get_max_active_clusters(1) or num_sm, num_sm)
+    max_active_clusters = num_sm
```

Also applies to: 161-162
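To see why an under-capped (or zero-sized) grid is harmful, here is a toy model of a persistent scheduler's round-robin tile assignment — a hypothetical sketch, not the kernel's actual scheduler code:

```python
def assign_tiles(num_tiles: int, grid_x: int) -> list[list[int]]:
    """Persistent-scheduler-style assignment: CTA i processes tiles
    i, i + grid_x, i + 2*grid_x, ...  A grid_x of 0 means no CTAs are
    launched at all, so no tile is ever processed."""
    if grid_x <= 0:
        return []  # a (0, 1, 1) grid launches nothing
    return [list(range(cta, num_tiles, grid_x)) for cta in range(grid_x)]
```

With `grid_x = num_sm` every SM receives a strided slice of the tiles; a stale probe value smaller than `num_sm` leaves SMs idle, and `0` drops all work — the failure mode the `or num_sm` fallback only partially guarded against.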
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/gdn_kernels/blackwell/gdn_prefill.py` at line 38: remove the stale active-cluster probe by deleting the import of `get_max_active_clusters` (and `get_num_sm` if unused), and change the logic that currently calls `get_max_active_clusters(1)` and takes `min(...)` with `num_sm` to use the physical SM count directly when computing the persistent scheduler capacity, so the scheduler is never capped by a stale probe value.
ℹ️ Review info — Configuration: defaults · Review profile: CHILL · Plan: Pro · Run ID: 241377df-4d17-435a-857e-c6f54aca9b61
📒 Files selected for processing (1): flashinfer/gdn_kernels/blackwell/gdn_prefill.py
/bot run
@jiahanc any idea why we didn't catch it with unit tests?
Could be because the unit tests don't apply as much pressure as the framework side, so the bug is not exposed.
## 📌 Description

Addresses the two remaining CodeRabbit findings on #3001 that weren't applied before merge:

* **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend dispatch so the same call gives matching numerics on SM90 and SM100. The SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken attention.
* **Don't eagerly allocate `output_state`** on the SM100 path when `output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway, so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch per call. SM90 still allocates unconditionally because its C++ kernel always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches (no inline `output_state if output_final_state else None`), so SM90 and SM100 read identically.

## 🔍 Related Issues

* [feat] Add blackwell GDN prefill kernel (#3001)
* fix(gdn): use physical SM count for SM100 persistent prefill kernel (#3155)
* [fix] fix blackwell gdn accuracy issue (#3156)

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed scale parameter handling to correctly interpret explicit values and apply default scaling behavior.
  * Improved memory efficiency by avoiding unnecessary state allocations in certain configurations.
* **Improvements**
  * Enhanced consistency in kernel invocation logic across different hardware architectures.
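The scale normalization described in that follow-up can be sketched as a dispatcher-level helper — the function name is hypothetical; the PR's point is only that the sentinel is resolved once, before backend dispatch, so SM90 and SM100 receive the same effective value:

```python
import math

def normalize_scale(scale: float, head_dim: int) -> float:
    """Map the SM90 sentinel scale == 0.0 to the default 1/sqrt(d_k).
    Without this, a backend that uses the value literally would
    multiply QK by 0.0 and break attention."""
    if scale == 0.0:
        return 1.0 / math.sqrt(head_dim)
    return scale
```

An explicit nonzero scale passes through untouched, so callers who set their own scaling are unaffected.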
## 📌 Description

Fixes the `num_sm` issue CodeRabbit flagged on #3001 but which was not applied before merge: #3001 (comment)

The raw `HardwareInfo().get_max_active_clusters(1)` call returns 0 / stale values in spawned subprocesses (e.g. vLLM's EngineCore workers) where the CUDA driver API context has not been made current yet. The persistent tile scheduler then leaves some CTAs without any work and the kernel deadlocks at first call. Switch to `get_num_sm(q.device)`, matching the SM120 MoE dispatch.

## 🔍 Related Issues
## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes