[feat] Add Blackwell GDN prefill kernel #3001
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
📝 Walkthrough

Adds Blackwell (SM100/SM100A) chunked GDN prefill support: a new Blackwell package and tile scheduler, an SM100-specific compile-once kernel adapter with workspace management, top-level export handling updates, and runtime branching in the prefill dispatch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Prefill as GDN Prefill
    participant DeviceCheck as Device/Version Check
    participant KernelCache as Compiled Kernel Cache
    participant CuTe as CuTe/CUTLASS (compile/exec)
    participant JIT as SM90 JIT Kernel
    User->>Prefill: call prefill(q,k,v,...)
    Prefill->>DeviceCheck: query _has_blackwell_prefill, is_sm100a_supported, CUDA version
    alt SM100 path (Blackwell + CUDA >= 13)
        DeviceCheck-->>Prefill: choose SM100 route
        Prefill->>KernelCache: lookup compiled kernel (dtype,HQ,HV,is_GQA,...)
        alt cache hit
            KernelCache-->>Prefill: return compiled callable + workspace info
        else cache miss
            Prefill->>CuTe: convert tensors, compile kernel, allocate workspace
            CuTe-->>KernelCache: store compiled callable & device info
            KernelCache-->>Prefill: return compiled callable
        end
        Prefill->>CuTe: execute compiled kernel with workspace & CUDA stream
        CuTe-->>Prefill: outputs + optional final state
    else SM90 fallback
        DeviceCheck-->>Prefill: choose SM90 route
        Prefill->>JIT: allocate workspace buffer, call SM90 JIT kernel
        JIT-->>Prefill: outputs
    end
    Prefill-->>User: return results
```
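To make the caching step concrete, here is a minimal sketch of a compile-once cache in the spirit of the diagram. It is not the adapter's actual code: `_compile_kernel` is a hypothetical stand-in for the CuTe/CUTLASS compile step, and any key fields beyond those named in the diagram are assumptions.

```python
import torch

# (device_index, dtype, HQ, HV, is_GQA) -> compiled callable
_kernel_cache: dict = {}

def _compile_kernel(key):
    # Stand-in for the CuTe/CUTLASS compile step; the real adapter
    # compiles an SM100 kernel and sets up its workspace here.
    def run(*tensors):
        raise NotImplementedError("placeholder for the compiled SM100 kernel")
    return run

def get_or_compile(q: torch.Tensor, HQ: int, HV: int):
    # Keying on the device index keeps each GPU's compiled callable and
    # workspace separate (see the device-safety review note below).
    key = (q.device.index, q.dtype, HQ, HV, HQ >= HV)
    if key not in _kernel_cache:
        _kernel_cache[key] = _compile_kernel(key)
    return _kernel_cache[key]
```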
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
---
Code Review
This pull request introduces support for Gated Delta Net (GDN) chunked prefill kernels on Blackwell (SM100) GPUs. Key changes include the addition of a new tile scheduler and a Blackwell-specific adapter for the GDN kernel, along with updates to benchmarks and tests to incorporate the new SM100 path. Feedback highlights documentation inconsistencies regarding state tensor layouts in the Blackwell adapter, a redundant calculation in the tile scheduler, and an unused variable in the prefill logic.
---
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)
32-42: Reuse the shared arch predicates in this skip helper. Please lean on `is_sm90a_supported()` / `is_sm100a_supported()` for the architecture half of this check and keep the CUDA-major gate layered on top if SM100 still needs it. That keeps the tests aligned with the runtime support policy in one place. As per coding guidelines, `tests/**/*.py`: Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 32 - 42, Replace the manual compute-capability checks in _skip_if_unsupported() with the shared predicates: call is_sm90a_supported() and is_sm100a_supported() to decide support, and only if is_sm100a_supported() is True still enforce the CUDA-major gate by parsing torch.version.cuda (as currently done) to require CUDA 13+; remove direct get_compute_capability() checks for SM90/SM100 and use those utility functions so the test skip logic aligns with the runtime support policy.
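A minimal sketch of what the reworked helper could look like, assuming the predicate signatures from `flashinfer.utils` quoted in the guideline above; the skip messages are illustrative:

```python
import pytest
import torch
from flashinfer.utils import is_sm90a_supported, is_sm100a_supported

def _skip_if_unsupported(device: torch.device) -> None:
    # Architecture half of the check via the shared predicates.
    if not (is_sm90a_supported(device) or is_sm100a_supported(device)):
        pytest.skip("GDN prefill requires SM90a or SM100a")
    # SM100 additionally requires CUDA 13+, layered on top.
    if is_sm100a_supported(device):
        cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
        if cuda_major < 13:
            pytest.skip("SM100 GDN prefill requires CUDA 13+")
```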
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Around line 49-57: The cached mutable CUDA workspace is currently shared
across devices in _get_compiled_cache (and the other cached spots around lines
121-123 and 224-228), causing cross-device reuse; update the cache to be
device-safe by scoping the cache entry to the current CUDA device: when building
the cache key for _get_compiled_cache (and the other cache-holding functions),
include the current device id (torch.cuda.current_device() or torch.device) or
change the cached value to a dict keyed by device id so each GPU gets its own
workspace tensor; ensure the workspace tensor is created on the correct device
before storing and returned only for that device.
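A sketch of the device-scoped caching the prompt describes; the names here (`_workspaces`, `get_workspace`) are illustrative, not the adapter's actual identifiers:

```python
import torch

# device index -> reusable workspace tensor for that GPU
_workspaces: dict = {}

def get_workspace(device: torch.device, nbytes: int) -> torch.Tensor:
    # Keying on device.index prevents handing GPU 1 a buffer that was
    # allocated on GPU 0 by an earlier call.
    ws = _workspaces.get(device.index)
    if ws is None or ws.numel() < nbytes:
        ws = torch.empty(nbytes, dtype=torch.uint8, device=device)
        _workspaces[device.index] = ws
    return ws
```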
In `@flashinfer/gdn_prefill.py`:
- Around line 201-233: The code currently allocates the full float32 scratch
state (output_state) before choosing the backend, causing an unnecessary large
allocation when output_final_state is False; change the logic so output_state is
only allocated when output_final_state is True and the selected backend requires
it (i.e., before calling chunk_gated_delta_rule_sm100), or pass an
already-conditional None otherwise. Concretely, move or guard the allocation of
output_state behind the backend selection branch that calls
chunk_gated_delta_rule_sm100 and only create the [num_seqs, H, 128, 128] tensor
when output_final_state is True; ensure the call to chunk_gated_delta_rule_sm100
still receives output_state when needed and None when not.
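A sketch of the guarded allocation, using the `[num_seqs, H, 128, 128]` float32 shape quoted above; the helper name and `use_sm100` flag are assumptions (the SM90 kernel always writes into `output_state`, so it keeps allocating unconditionally):

```python
import torch

def make_output_state(num_seqs, H, device, output_final_state, use_sm100):
    # Only the SM100 path can skip the scratch state entirely; the SM90
    # C++ kernel always writes into output_state.
    if use_sm100 and not output_final_state:
        return None
    # Full float32 state, [num_seqs, H, 128, 128] per the shapes above.
    return torch.zeros(num_seqs, H, 128, 128, dtype=torch.float32, device=device)
```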
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ef3072ef-612f-4d41-a948-df4a7c9633be
📒 Files selected for processing (9)
- benchmarks/bench_gdn_prefill.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/blackwell/__init__.py
- flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
- flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
- flashinfer/gdn_kernels/blackwell/gdn_prefill.py
- flashinfer/gdn_prefill.py
- pyproject.toml
- tests/gdn/test_prefill_delta_rule.py
```python
_scale = scale if scale is not None else 1.0 / math.sqrt(head_size)

_cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
if _has_blackwell_prefill and is_sm100a_supported(device) and _cuda_major >= 13:
```
---
Make scale=0.0 backend-independent.
The SM90 launcher still treats 0.0 as the “use default 1 / sqrt(d)” sentinel, but the SM100 path forwards 0.0 literally. The same API call can therefore produce different numerics on Hopper vs. Blackwell. Please reject scale == 0.0 at the Python boundary or resolve it to one concrete value before both dispatches.
Also applies to: 222-233, 241-252
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_prefill.py` around lines 198 - 201, The code currently treats
scale==0.0 inconsistently between SM90 and SM100 paths; fix this by normalizing
the incoming scale before any backend dispatch: in the prefill function compute
a concrete _scale value (e.g. if scale is None use 1.0/math.sqrt(head_size), and
if scale == 0.0 also set _scale = 1.0/math.sqrt(head_size) or alternatively
raise ValueError) so that the same _scale is used for both the SM100 branch
(is_sm100a_supported(device) / _has_blackwell_prefill) and the other path;
update the logic around _scale, scale, head_size, is_sm100a_supported and the
dispatch blocks so they read this single resolved _scale.
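A minimal sketch of the requested normalization: one resolved value for both sentinels, computed before either dispatch branch (the helper name is illustrative):

```python
import math

def resolve_scale(scale, head_size: int) -> float:
    # Both None and the 0.0 sentinel resolve to the default, so the
    # SM90 and SM100 dispatches read one concrete value and agree.
    if scale is None or scale == 0.0:
        return 1.0 / math.sqrt(head_size)
    return float(scale)
```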
---
Actionable comments posted: 3
♻️ Duplicate comments (2)
flashinfer/gdn_prefill.py (2)
182-195: ⚠️ Potential issue | 🟠 Major: Avoid the eager `output_state` allocation on the SM100 no-final-state path. When `output_final_state=False`, Lines 189-195 still allocate the full float32 state buffer up front, but Line 231 drops it for the SM100 launch. That buffer is unused on the Blackwell path and can be very large for bigger `num_seqs`/head counts. Also applies to: 231-233
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 182 - 195, The code eagerly allocates output_state even when output_final_state is False (in gdn_prefill.py where output_state is set), which wastes memory for SM100/Blackwell because that path drops the buffer later; change the logic so output_state is only allocated when output_final_state is True or when the backend/device requires a CPU/GPU buffer (e.g., detect device/backend used for the non-SM100 launch) — move or guard the allocation away from the current elif block and defer it until after the device/launch-path decision (see symbols output_final_state, output_state and the SM100/Blackwell path around the later kernel launch where the buffer is discarded at lines ~231-233) so no float32 buffer is created unnecessarily for the SM100 no-final-state path.
197-201: ⚠️ Potential issue | 🟠 Major: Normalize `scale == 0.0` before backend selection. Lines 198-201 resolve `None` only. A caller that passes `0.0` still gets backend-dependent behavior: the SM100 path forwards literal zero at Line 232, while the SM90 path keeps using `0.0` as the auto-scale sentinel at Line 251. The same API call can therefore change numerics depending on the backend.

🧮 Proposed fix

```diff
-    _scale = scale if scale is not None else 1.0 / math.sqrt(head_size)
+    default_scale = 1.0 / math.sqrt(head_size)
+    _scale = default_scale if scale is None or scale == 0.0 else scale
@@
-        scale if scale is not None else 0.0,
+        _scale,
```

Also applies to: 222-233, 241-252
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 197 - 201, Normalize the sentinel value for scale before choosing the backend: treat scale==0.0 the same as scale is None and compute a single resolved value (e.g., resolved_scale = scale if scale is not None and scale != 0.0 else 1.0/math.sqrt(head_size)) once before the SM100/SM90 conditional so both branches use the same numeric _scale; update references to _scale (and any downstream use at the SM100 path that currently forwards literal zero) to use resolved_scale to avoid backend-dependent behavior. Ensure you change the same pattern around the other occurrences (the blocks around the current 222-233 and 241-252 regions) so all branches read the same normalized value.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/blackwell/__init__.py`:
- Around line 8-14: The optional-import guard in
flashinfer.gdn_kernels.blackwell currently only catches ImportError; update the
except clause in __init__.py to also catch RuntimeError (e.g., except
(ImportError, RuntimeError)) so that failures from the SM100 adapter are treated
as "backend unavailable", and ensure _has_blackwell_prefill is set to False and
chunk_gated_delta_rule_sm100 remains None (with the existing type ignore) in
that branch.
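A sketch of the widened guard; the import body is an assumption based on the exported names, but the `except (ImportError, RuntimeError)` shape is exactly what the comment asks for:

```python
# Sketch of flashinfer/gdn_kernels/blackwell/__init__.py.
try:
    from .gdn_prefill import chunk_gated_delta_rule_sm100
    _has_blackwell_prefill = True
except (ImportError, RuntimeError):
    # SM100 adapter failures (including RuntimeError raised by the
    # CuTe DSL stack) are treated as "backend unavailable".
    chunk_gated_delta_rule_sm100 = None  # type: ignore[assignment]
    _has_blackwell_prefill = False
```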
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Line 195: The code creates a CUDA stream with
cuda.CUstream(torch.cuda.current_stream().cuda_stream) which uses the current
device rather than the tensor's device; update the calls that construct the
stream (the usage of torch.cuda.current_stream()) to pass the tensor's device
explicitly (use torch.cuda.current_stream(device=q.device)) so cuda.CUstream is
created from the correct device's stream; locate occurrences around the stream
variable creation and replace the plain torch.cuda.current_stream() calls with
torch.cuda.current_stream(device=q.device) to ensure kernels run on q.device's
stream.
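A sketch of the device-explicit stream lookup, using the `cuda-python` binding already in use at the flagged call site; the wrapper name is illustrative:

```python
import torch
from cuda import cuda  # cuda-python driver bindings, as at the call site

def stream_for(q: torch.Tensor) -> "cuda.CUstream":
    # Ask for the stream of the tensor's device explicitly rather than
    # whichever device happens to be current on this thread.
    torch_stream = torch.cuda.current_stream(device=q.device)
    return cuda.CUstream(torch_stream.cuda_stream)
```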
In `@flashinfer/gdn_prefill.py`:
- Around line 159-162: Add the `@backend_requirement` decorator to the API whose
docstring mentions SM90/SM100 (the function or class that contains these
docstring lines) and import backend_requirement; implement the required
introspection methods is_compute_capability_supported(cc) and
is_backend_supported() on that API so they mirror the runtime gating (check
SM90/SM100 compute-capability logic, the SM100 head_size==128 constraint, and
any backend dependency check used at runtime), and ensure these methods return
booleans so callers can query support before dispatch.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 48b22d3c-48b7-4ab8-a5fa-6ff81cf76a32
📒 Files selected for processing (9)
- benchmarks/bench_gdn_prefill.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/blackwell/__init__.py
- flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
- flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
- flashinfer/gdn_kernels/blackwell/gdn_prefill.py
- flashinfer/gdn_prefill.py
- pyproject.toml
- tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (2)
- pyproject.toml
- benchmarks/bench_gdn_prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gdn/test_prefill_delta_rule.py
/bot run
---
Actionable comments posted: 1
🧹 Nitpick comments (2)
flashinfer/gdn_kernels/blackwell/__init__.py (1)
16-19: Consider sorting `__all__` alphabetically. Static analysis flags that `__all__` is not sorted. While functionally correct, sorting improves consistency.

♻️ Proposed fix

```diff
 __all__ = [
-    "chunk_gated_delta_rule_sm100",
     "_has_blackwell_prefill",
+    "chunk_gated_delta_rule_sm100",
 ]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell/__init__.py` around lines 16 - 19, The `__all__` list is not alphabetically sorted; update the list containing "chunk_gated_delta_rule_sm100" and "_has_blackwell_prefill" so its entries are in alphabetical order (i.e., place "_has_blackwell_prefill" before "chunk_gated_delta_rule_sm100") to satisfy static-analysis sorting rules.

flashinfer/gdn_kernels/blackwell/gdn_prefill.py (1)
106-108: Add a clarifying comment for the `is_GQA` condition. The condition `is_GQA = HQ >= HV` treats equal head counts (standard MHA) the same as cases where query heads exceed value heads. While the kernel logic handles this correctly (the output head count and repetition factor calculations both produce correct results when HQ == HV), the naming conflates GQA (grouped query attention, where HQ > HV) with MHA (where HQ == HV). Consider adding a brief comment explaining that this condition captures all cases where `HQ >= HV` for the kernel's internal dispatch logic, even though standard MHA (HQ == HV) is semantically distinct from true GQA.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py` around lines 106 - 108, Add a short clarifying comment above the is_GQA assignment explaining that the boolean is used for the kernel's internal dispatch logic and intentionally groups HQ == HV (standard MHA) with HQ > HV (true GQA) so the kernel treats both as the same path; reference the symbols HQ, HV, and is_GQA and mention that this is a pragmatic choice for output head count and repetition factor calculations, not a semantic conflation of MHA and GQA.
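The requested comment might read roughly like this (the HQ/HV values are illustrative):

```python
HQ, HV = 16, 4  # illustrative head counts

# is_GQA drives the kernel's internal dispatch. It intentionally groups
# HQ == HV (standard MHA) with HQ > HV (true GQA): the output head count
# and repetition-factor calculations are correct in both cases, so one
# path covers both. Pragmatic grouping, not a semantic claim about MHA.
is_GQA = HQ >= HV
```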
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d9da8b4e-f9b5-4c4e-b287-5df251ce68df
📒 Files selected for processing (4)
- docker/install/install_python_packages.sh
- flashinfer/gdn_kernels/blackwell/__init__.py
- flashinfer/gdn_kernels/blackwell/gdn_prefill.py
- tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
- docker/install/install_python_packages.sh
```python
hardware_info = cutlass_utils.HardwareInfo()
num_sm = hardware_info.get_max_active_clusters(1)
max_active_clusters = hardware_info.get_max_active_clusters(1)
```
---
🧩 Analysis chain (all scripts run against repository flashinfer-ai/flashinfer)

🏁 Script executed:

```bash
#!/bin/bash
# Check how num_sm is typically obtained in similar kernels
rg -n "num_sm\s*=" --type=py flashinfer/gdn_kernels/ -A2 -B2
```

Length of output: 1964

🏁 Script executed:

```bash
#!/bin/bash
# Check HardwareInfo API for SM count method
ast-grep --pattern $'class HardwareInfo {
  $$$
}'
```

Length of output: 50

🏁 Script executed:

```bash
# Locate the HardwareInfo class definition
fd -type f -name "*.py" | xargs rg -l "class HardwareInfo" | head -5
```

Length of output: 238

🏁 Script executed:

```bash
# Check cutlass_utils module location
fd -type f -name "cutlass_utils.py"
```

Length of output: 238

🏁 Script executed:

```bash
# Also search for HardwareInfo in pyx/pxd files (Cython)
fd -type f \( -name "*.pyx" -o -name "*.pxd" \) | xargs rg -l "HardwareInfo" 2>/dev/null | head -5
```

Length of output: 559

🏁 Script executed:

```bash
# Find files containing HardwareInfo class definition
rg "class HardwareInfo" --type py
```

Length of output: 50

🏁 Script executed:

```bash
# Check imports in gdn_prefill.py to locate cutlass_utils
head -30 flashinfer/gdn_kernels/blackwell/gdn_prefill.py
```

Length of output: 1080

🏁 Script executed:

```bash
# Continue reading gdn_prefill.py imports
head -60 flashinfer/gdn_kernels/blackwell/gdn_prefill.py | tail -30
```

Length of output: 891

🏁 Script executed:

```bash
# Search for get_max_active_clusters method
rg "get_max_active_clusters" -A 5
```

Length of output: 8677

🏁 Script executed:

```bash
# Search for get_num_sm function definition and usage
rg "def get_num_sm|get_num_sm" -B 2 -A 3
```

Length of output: 9982

🏁 Script executed:

```bash
# Check mla_decode imports more carefully
rg "from.*import|import" flashinfer/mla/cute_dsl/mla_decode.py | head -20
```

Length of output: 446

🏁 Script executed:

```bash
# Check what device is available in gdn_prefill.py context at lines 127-129
sed -n '110,150p' flashinfer/gdn_kernels/blackwell/gdn_prefill.py
```

Length of output: 1615

🏁 Script executed:

```bash
# Verify GatedDeltaNetChunkedKernel use of num_sm
sed -n '215,235p' flashinfer/gdn_kernels/blackwell/gdn_prefill.py
```

Length of output: 681
num_sm should be computed using device SM count, not get_max_active_clusters(1).
Line 128 incorrectly assigns num_sm = hardware_info.get_max_active_clusters(1), which returns the max active clusters for a cluster size of 1—not the actual SM count. This value is then passed to GatedDeltaNetChunkedKernel (line 139) and used to calculate workspace size (line 221). The correct approach, used throughout the codebase (e.g., mxfp4_quantize.py, nvfp4_quantize.py, mla_decode.py), is to call get_num_sm(q.device), which returns torch.cuda.get_device_properties(q.device).multi_processor_count.
Import get_num_sm from flashinfer.cute_dsl.utils and replace line 128 with num_sm = get_num_sm(q.device).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py` around lines 127 - 129,
Replace the incorrect SM count retrieval: instead of calling
HardwareInfo.get_max_active_clusters(1) for num_sm, import and call
get_num_sm(q.device) from flashinfer.cute_dsl.utils and use that value for
num_sm passed into GatedDeltaNetChunkedKernel and workspace size calculations;
remove or leave HardwareInfo usage only for max_active_clusters as needed and
ensure get_max_active_clusters(1) is not used to compute num_sm.
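For reference, a sketch of what the suggested helper does, per the device-properties detail cited in the finding above (the real function lives in `flashinfer.cute_dsl.utils`):

```python
import torch

def get_num_sm(device: torch.device) -> int:
    # Physical SM count of `device`, matching the helper the other
    # kernels (mxfp4_quantize.py, nvfp4_quantize.py, ...) already use.
    return torch.cuda.get_device_properties(device).multi_processor_count

# At the flagged site: num_sm = get_num_sm(q.device)
```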
---
🧹 Nitpick comments (2)
benchmarks/bench_blackwell_gdn_prefill.py (2)
341-344: Add an explicit CUDA-availability guard for a clearer failure mode. If CUDA is unavailable, the current flow may fail with a less actionable message before the SM100 check.
🔧 Proposed fix

```diff
 def main():
@@
-    device = torch.device("cuda")
+    if not torch.cuda.is_available():
+        print("Error: CUDA is not available.")
+        sys.exit(1)
+    device = torch.device("cuda")
     if not is_sm100a_supported(device):
         print("Error: This benchmark requires SM100+ (Blackwell) GPU.")
         sys.exit(1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 341 - 344, Add an explicit CUDA availability check before selecting a CUDA device: call torch.cuda.is_available() and if it returns False print a clear error and sys.exit(1) before creating device or invoking is_sm100a_supported; ensure the guard is placed prior to the line that sets device = torch.device("cuda") so that subsequent calls like is_sm100a_supported(device) only run when CUDA is present.
294-295: Avoid blind `Exception` catches in sweep loops. Line 294 and Line 315 swallow all exceptions, including interruption/system-level signals, and make failures harder to triage.
🔧 Proposed fix

```diff
-    except Exception as e:
+    except KeyboardInterrupt:
+        raise
+    except RuntimeError as e:
         print(f"  FAILED: {e}")
@@
-    except Exception as e:
+    except KeyboardInterrupt:
+        raise
+    except RuntimeError as e:
         print(f"  FAILED: {e}")
```

Also applies to: 315-316

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 294 - 295, The broad "except Exception as e" handlers in the sweep loop should not swallow system-level interrupts; update the two except blocks that currently read "except Exception as e" so they re-raise KeyboardInterrupt and SystemExit (e.g., if isinstance(e, (KeyboardInterrupt, SystemExit)): raise) and only handle other exceptions by logging the error and traceback for triage, rather than silently printing; apply this change to both occurrences so interruption signals propagate and failures are logged with full context.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6e8a0c50-1328-4674-88ba-c521fef0fbb9
📒 Files selected for processing (1)
benchmarks/bench_blackwell_gdn_prefill.py
Pre-commit checks failed on files not changed in this PR. Wonder if we should fix those in a separate PR.
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
```diff
     rtol_kv = 1e-3
 else:
-    atol_o = 1e-3
+    atol_o = 2e-3
```
---
Hello, excellent work! I have a question that has been puzzling me: what causes the error to increase? I experimented with the #2742 branch with the TF32 inverse implementation, and for long input sequences the results aligned much more closely with the flash-linear-attention implementation. If we disregard potential FP16 numerical overflow during the inversion process, the computational precision of TF32 is theoretically equivalent to that of FP16. Could this discrepancy be attributed to the use of finer partitioning granularity (8x8 -> 16x16 -> 32x32 -> 64x64) during the inversion calculation?
Furthermore, QKV is obtained from qkv_factory, and the input range for most Q, K, and V tensors appears to fall within [-0.4, 0.4]. Is this numerical range sufficient? For K, after applying L2 normalization, the data range becomes [-1, 1]. In Qwen3.5, Q also undergoes L2 normalization; consequently, when these inputs are fed into the gated_delta_rule computation, their range effectively becomes [-1, 1].
Looking forward to your reply, thanks in advance.
---
@bestzsq The inversion unfortunately propagates error from the diagonal (8x8) toward the lower-left corner due to repeated truncation from the fp32 accumulator to fp16 operands. Previously, the Blackwell version had a 128(Q)x128(K/V) block size config; this might be the root cause.
The Blackwell kernel copied my Hopper inversion strategy. I have experimented with 3xFP16 and 3xTF32 inversion on Hopper. They are much more accurate, as you can store the truncation error in the second FP16 value. But due to the large kernel performance penalty (on Hopper), they are not upstreamed to FI. We may someday upstream it if proven to be needed.
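As a rough illustration of the compensated idea mentioned above (a hi/lo fp16 split of an fp32 value, the building block behind a 2x/3xFP16 scheme, not the actual kernel code):

```python
import torch

def split_fp16(x32: torch.Tensor):
    """Split fp32 into two fp16 values with hi + lo ~= x32.

    hi is the plain fp16 truncation; lo captures the truncation error,
    which a second fp16 operand can carry through the matmul.
    """
    hi = x32.to(torch.float16)
    lo = (x32 - hi.to(torch.float32)).to(torch.float16)
    return hi, lo

x = torch.randn(128, 128) * 3.0
hi, lo = split_fp16(x)
print((x - hi.float()).abs().max())                 # plain fp16 truncation error
print((x - (hi.float() + lo.float())).abs().max())  # much smaller residual
```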
---
@guangyunh-nv Thanks for your reply. If the inversion calculation were to begin with 16x16 diagonal blocks, followed by 32x32 blocks, and finally 64x64 blocks, could the calculation error be reduced?
I experimented with modifying the matrix multiplication within the inversion process of the Triton chunk implementation in flash-linear-attention, changing the inputs to FP16 and the outputs to FP32. Although this involved some truncation from FP32 to FP16, the results demonstrated excellent consistency when compared to the unmodified Triton chunk implementation. However, in comparison to the current CuTe DSL implementation, the Triton chunk implementation uses sixteen 16x16 matrix multiplications for the non-diagonal blocks.
---
The main problem with diagonal block processing (aka substitution) is that it requires O(n^2) FMAs, which implies O(n^2) instructions. So the smaller the better; 8x8 is a sweet spot, as we can start to use Ampere-style tensor cores immediately after that.
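A sketch of the blockwise strategy being discussed: invert small diagonal blocks by substitution, then assemble larger inverses with matmuls. This fp32 torch version only illustrates the structure, not the tensor-core kernel, and assumes the matrix size is a power of two times the base block:

```python
import torch

def block_tril_inverse(L: torch.Tensor, base: int = 8) -> torch.Tensor:
    """Invert a lower-triangular matrix by recursive 2x2 blocking.

    Mirrors the 8x8 -> 16x16 -> 32x32 -> 64x64 build-up: diagonal blocks
    of size `base` are inverted by substitution (O(n^2) FMAs each), and
    larger inverses are assembled with matmuls, which is where a tensor-
    core kernel would truncate fp32 accumulators to fp16 operands.
    """
    n = L.shape[-1]
    if n <= base:
        eye = torch.eye(n, dtype=L.dtype, device=L.device)
        return torch.linalg.solve_triangular(L, eye, upper=False)
    h = n // 2
    A_inv = block_tril_inverse(L[:h, :h], base)
    C_inv = block_tril_inverse(L[h:, h:], base)
    out = torch.zeros_like(L)
    out[:h, :h] = A_inv
    out[h:, h:] = C_inv
    # [[A, 0], [B, C]]^-1 has lower-left block -C^-1 @ B @ A^-1.
    out[h:, :h] = -C_inv @ L[h:, :h] @ A_inv
    return out

L = torch.randn(64, 64).tril() + 8.0 * torch.eye(64)  # well-conditioned
print((block_tril_inverse(L) @ L - torch.eye(64)).abs().max())
```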
---
On Blackwell, due to its asynchronous nature, it is worth exploring a little further. I tried 4x4 as the start point on Hopper, but saw no further improvement. If Blackwell can tolerate 16x16 as its start point with no obvious perf penalty, I think it should be made a configurable parameter.
---
@Observer007 Can the provided example be reproduced? I look forward to receiving any updates regarding this issue, thanks in advance!
---
Yes, I can reproduce it now.
---
BTW, do you think the precision of chunk size 128 is good enough or not?
---
@bestzsq It is really a deeply hidden bug; we have a fix in #3156. After the fix, the MAE of chunk size 128 and chunk size 64 are the same using your reproducer. Does the MAE look good to you? Anyway, thanks again for the thorough inspection! And thanks for the explanation from @guangyunh-nv.
---
@Observer007 Sorry for the late reply and thanks for the fix! I have tested it on some inputs, and the discrepancies between cute_dsl (chunk size 64), flash-linear-attention's chunk_gated_delta_rule, and fused_recurrent_gated_delta_rule are all within the same order of magnitude. I think the precision of chunk size 128/64 is good enough.
---
[fix] fix blackwell gdn accuracy issue (#3156)

📌 Description

Fixed the accuracy issue in the Blackwell GDN kernel found by @bestzsq. The root cause is that the legacy `max_coord` is not the actual last coord of `sCumprod`; we change to the last coord instead. It's a deeply hidden bug that we hadn't discovered previously. Thanks to @bestzsq.

Reproducer test link from @bestzsq: #3001 (comment)

Reproducer test output before this PR:

```
# flash-linear-attention==0.4.2
fla vs cute64:  mae: 2.82288e-03, ulp: 9040.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64:  mae: 2.82288e-03, ulp: 9064.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
```

Reproducer test output after this PR:

```
# flash-linear-attention==0.4.2
fla vs cute64:  mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64:  mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
```

The previous local test tolerance was loosened from `1e-3` to `2e-3` in #3001: https://github.com/flashinfer-ai/flashinfer/pull/3001/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6L132

This PR tightens the tolerance from `2e-3` back to `1e-3`: https://github.com/flashinfer-ai/flashinfer/pull/3156/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6R148

🚀 Pull Request Checklist: pre-commit installed and run (`pre-commit run --all-files`); tests added/updated and passing.

Summary by CodeRabbit
- Refactor: Improved kernel computation efficiency by consolidating internal calculation steps and removing redundant intermediate operations, reducing code complexity while preserving existing functionality and performance characteristics.
- Tests: Strengthened numerical validation by reducing tolerance thresholds in computational accuracy tests, ensuring more stringent verification of output correctness and numerical consistency across test scenarios.
---
fix(gdn): use physical SM count for SM100 persistent prefill kernel (#3155)

📌 Description

Fixes the `num_sm` issue CodeRabbit flagged on #3001 but which was not applied before merge: #3001 (comment)

The raw `HardwareInfo().get_max_active_clusters(1)` call returns 0 / stale values in spawned subprocesses (e.g. vLLM's EngineCore workers) where the CUDA driver API context has not been made current yet. The persistent tile scheduler then leaves some CTAs without any work, and the kernel deadlocks at the first call. Switch to `get_num_sm(q.device)`, matching the SM120 MoE dispatch.

🚀 Pull Request Checklist: pre-commit installed and run; tests added/updated and passing.

Summary by CodeRabbit
- Refactor: Kernel compilation now derives device-specific SM and cluster counts at runtime, improving GPU resource allocation and leading to more consistent performance across different CUDA devices.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
📌 Description

Addresses the two remaining CodeRabbit findings on #3001 that weren't applied before merge:

- **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend dispatch so the same call gives matching numerics on SM90 and SM100. The SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken attention.
- **Don't eagerly allocate `output_state`** on the SM100 path when `output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway, so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch per call. SM90 still allocates unconditionally because its C++ kernel always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches (no inline `output_state if output_final_state else None`), so SM90 and SM100 read identically.

🔍 Related Issues
- [feat] Add blackwell GDN prefill kernel (#3001)
- fix(gdn): use physical SM count for SM100 persistent prefill kernel (#3155)
- [fix] fix blackwell gdn accuracy issue (#3156)

🚀 Pull Request Checklist: pre-commit installed and run; tests added/updated and passing.

Summary by CodeRabbit
- Bug Fixes: Fixed scale parameter handling to correctly interpret explicit values and apply default scaling behavior; improved memory efficiency by avoiding unnecessary state allocations in certain configurations.
- Improvements: Enhanced consistency in kernel invocation logic across different hardware architectures.
---
📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes

Summary by CodeRabbit
- New Features
- Chores
- Tests
- Benchmarks