Commit 5e1318c

fix(gdn): use physical SM count for SM100 persistent prefill kernel (#3155)
## 📌 Description

Fixes the `num_sm` issue CodeRabbit flagged on #3001 that was not applied before merge: #3001 (comment).

The raw `HardwareInfo().get_max_active_clusters(1)` call returns 0 or stale values in spawned subprocesses (e.g. vLLM's EngineCore workers), where the CUDA driver API context has not yet been made current. The persistent tile scheduler then leaves some CTAs without any work, and the kernel deadlocks on its first call. Switch to `get_num_sm(q.device)`, matching the SM120 MoE dispatch.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Refactor**
  * Kernel compilation now derives device-specific SM and cluster counts at runtime, improving GPU resource allocation and leading to more consistent performance across different CUDA devices.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
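The failure mode described above can be illustrated with a toy model (hypothetical code, not the actual FlashInfer scheduler): a persistent tile scheduler launches `num_sm` long-lived CTAs that each stride over the tile grid, so a zero or stale SM count leaves tiles unclaimed and the kernel waits on them forever.

```python
def schedule_tiles(num_tiles: int, num_sm: int) -> list[set[int]]:
    """Toy persistent scheduler: CTA i processes tiles i, i+num_sm, ..."""
    if num_sm <= 0:
        # A 0/stale count from get_max_active_clusters(1) in a fresh
        # subprocess means no persistent CTA ever claims a tile.
        return []
    return [set(range(cta, num_tiles, num_sm)) for cta in range(num_sm)]

# With the physical SM count, every tile is covered exactly once.
assignments = schedule_tiles(64, 16)
covered = set().union(*assignments)
assert covered == set(range(64))
assert sum(len(s) for s in assignments) == 64

# With a zero count, no tile is ever processed: a deadlock at first call.
assert schedule_tiles(64, 0) == []
```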
1 parent 24c4aee · commit 5e1318c

1 file changed: flashinfer/gdn_kernels/blackwell/gdn_prefill.py (4 additions, 4 deletions)
```diff
@@ -33,9 +33,10 @@
 import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
-import cutlass.utils as cutlass_utils
 from cutlass.cute.runtime import from_dlpack

+from flashinfer.cute_dsl.utils import get_num_sm
+
 from .gated_delta_net_chunked import GatedDeltaNetChunkedKernel


@@ -157,9 +158,8 @@ def chunk_gated_delta_rule_sm100(

     if "compiled" not in cache:
         # --- First call: compile the kernel ---
-        hardware_info = cutlass_utils.HardwareInfo()
-        num_sm = hardware_info.get_max_active_clusters(1)
-        max_active_clusters = hardware_info.get_max_active_clusters(1)
+        num_sm = get_num_sm(q.device)
+        max_active_clusters = num_sm

         gdn = GatedDeltaNetChunkedKernel(
             io_dtype=io_dtype,
```
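For reference, a helper in the spirit of `get_num_sm` can read the physical multiprocessor count from the device properties rather than relying on a driver-context-dependent occupancy query, so it stays correct in freshly spawned subprocesses. This is a hedged sketch: `get_num_sm_sketch` and its `props_fn` hook are illustrative, not FlashInfer's actual implementation, and it assumes a PyTorch-backed device query.

```python
def get_num_sm_sketch(device, props_fn=None):
    """Return the physical SM count of `device`.

    props_fn defaults to torch.cuda.get_device_properties; it is
    injectable here purely so the sketch can be exercised without a GPU.
    """
    if props_fn is None:
        # Assumption: PyTorch is available and `device` is a CUDA device.
        import torch

        props_fn = torch.cuda.get_device_properties
    # multi_processor_count is a plain device property, so it does not
    # depend on the CUDA driver context being current in this process.
    return props_fn(device).multi_processor_count


class _FakeProps:
    # Stand-in for a device-properties object (e.g. a B200-class count).
    multi_processor_count = 148


assert get_num_sm_sketch("cuda:0", props_fn=lambda d: _FakeProps()) == 148
```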
