Implement Gated Delta Rule for sm_120a (Blackwell RTX) #3088
guangyunh-nv wants to merge 7 commits into main from
Conversation
📝 Walkthrough
Adds Blackwell (SM120) support for the delta-rule prefill path: new SM120 Cutlass kernels and kernel-builder pieces, extern/explicit template instantiation headers and a Jinja instantiation template, runtime device dispatch and JIT/module selection for SM120, and test gating for SM120/CUDA versions.

Changes
Sequence Diagram(s)

sequenceDiagram
    participant Py as Python caller
    participant Loader as get_gdn_prefill_module(device)
    participant JIT as _gen_gdn_prefill_module
    participant Module as compiled JIT module
    participant Launcher as native gdn_prefill_launcher
    participant Kernel as SM90/SM120 kernel instantiation

    Py->>Loader: request module for device
    Loader->>JIT: select arch ("sm90" or "sm120")
    JIT->>Module: compile/load arch-specific module
    Loader->>Py: return Module
    Py->>Module: call gdn_prefill(...)
    Module->>Launcher: invoke native launcher with device_major
    alt device_major == 12
        Launcher->>Kernel: call SM120 instantiated kernel
    else device_major == 9
        Launcher->>Kernel: call SM90 instantiated kernel
    else
        Launcher->>Py: return error (unsupported arch)
    end
    Kernel-->>Launcher: execute Cutlass op / return status
    Launcher-->>Module: propagate result
    Module-->>Py: return
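The arch-selection step in the diagram above can be condensed into a small Python sketch. This is illustrative only: `select_gdn_arch` is a hypothetical helper, not flashinfer's actual API, and it covers only the two dispatch branches shown in the diagram (the real loader also handles other architectures such as SM100 and additionally gates the minor version).

```python
# Hypothetical sketch of the arch-selection branch in the sequence diagram;
# the function name and error behavior are assumptions, not flashinfer code.
def select_gdn_arch(device_major: int) -> str:
    """Map a CUDA compute-capability major version to a JIT module variant."""
    if device_major == 9:
        return "sm90"   # Hopper path
    if device_major == 12:
        return "sm120"  # Blackwell RTX path
    # mirrors the diagram's "return error (unsupported arch)" branch
    raise ValueError(f"unsupported arch major {device_major} for GDN prefill")

assert select_gdn_arch(9) == "sm90"
assert select_gdn_arch(12) == "sm120"
```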
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Code Review
This pull request introduces support for the SM120 architecture (Blackwell) in the GDN prefill delta rule kernel. The changes include new kernel implementations, JIT compilation support for SM120, and updates to the launcher to dispatch based on compute capability. Review feedback identified several critical issues: undefined variables (ptr_state_checkpoints and StateMmaRegisterRequirement), hardcoded architecture selection in the Python API that breaks SM90 compatibility, and incorrect pipeline synchronization logic in the collective mainloop where a state increment occurs before the consumer release.
// condition happens, why?
v_pipeline.consumer_release(v_smem_pipe_read);
The synchronization logic for v_pipeline is incorrect. Incrementing v_smem_pipe_read before calling consumer_release causes the pipeline to release the next stage instead of the one that was just consumed. This violates the CUTLASS pipeline protocol, leaves the current stage unreleased (potentially causing the producer to hang), and releases the next stage prematurely, which can lead to race conditions. The increment should happen after the release, consistent with the patterns used for q_pipeline and k_pipeline.
v_pipeline.consumer_release(v_smem_pipe_read);
++v_smem_pipe_read;
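To see concretely why release-then-increment matters, here is a toy Python trace of the consumer read index. It is a simulation of the index arithmetic only, not CUTLASS code, and the stage/iteration counts are arbitrary:

```python
# Toy trace of which pipeline stage gets released per iteration, comparing
# the correct order (release, then advance) against the buggy order
# (advance, then release) flagged in the review above.
STAGES = 4
NUM_ITERS = 3

# Correct order, as used for q_pipeline / k_pipeline:
read = 0
released_ok = []
for _ in range(NUM_ITERS):
    released_ok.append(read % STAGES)  # consumer_release(read)
    read += 1                          # ++read afterwards
assert released_ok == [0, 1, 2]        # every consumed stage is freed

# Buggy order: the increment happens first, so each release targets the
# stage *after* the one just consumed, and stage 0 is never freed.
read = 0
released_bug = []
for _ in range(NUM_ITERS):
    read += 1                          # ++read first (the bug)
    released_bug.append(read % STAGES) # consumer_release(read)
assert released_bug == [1, 2, 3]
assert 0 not in released_bug  # producer waiting on stage 0 would hang
```

The trace shows both failure modes from the review: the just-consumed stage stays unreleased (producer stall) while the next stage is released prematurely (race).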
@guangyunh-nv Is this only useful for sm_120, not sm_121 (Spark)?
@johnnynunez The kernel requires 100KB of smem and TMA for loading and storing. If those feature requirements are satisfied, then there is nothing preventing the kernel from running on Spark.
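The two requirements mentioned in the reply above can be sketched as a pure gating function. Everything here is illustrative: the helper name, the idea of keying TMA availability off compute capability 9.0+, and the sample smem values are assumptions for this sketch, not real device specs or flashinfer code:

```python
REQUIRED_SMEM = 100 * 1024  # the kernel needs ~100KB of shared memory

def can_run_gdn_prefill(smem_per_block_optin: int, major: int, minor: int) -> bool:
    """Hypothetical feature gate: >=100KB opt-in smem plus TMA support."""
    has_tma = (major, minor) >= (9, 0)  # TMA was introduced with Hopper (sm_90)
    return has_tma and smem_per_block_optin >= REQUIRED_SMEM

# Illustrative capability values, not actual device specs:
assert can_run_gdn_prefill(160 * 1024, 9, 0)      # enough smem + TMA
assert not can_run_gdn_prefill(96 * 1024, 12, 0)  # smem budget too small
assert not can_run_gdn_prefill(160 * 1024, 8, 9)  # pre-Hopper: no TMA
```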
Force-pushed from f41b0f2 to bbbba9f (Compare)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gdn/test_prefill_delta_rule.py (1)
36-52: ⚠️ Potential issue | 🟡 Minor: Stale docstring and skip message don't mention SM120.
The helper's docstring (Line 37) still says "Skip test if not SM90 or SM100 (with CUDA 13+) architecture." and the fallback skip message (Line 52) reads "GDN prefill requires SM90 or SM100", which no longer reflects the new SM120 path.
Also worth noting:
`is_sm120a_supported` already requires CUDA 12.8+ (see `flashinfer/utils.py:563-565`), so the additional `cuda_major < 13` gate inside the SM120 branch is stricter than the SM120 JIT path itself (`gen_gdn_prefill_sm120_module` only needs `-DFLAT_SM120A_ENABLED` + `sm120a_nvcc_flags`). Confirm whether CUDA 13 is really required for SM120 GDN prefill, or whether CUDA 12.8+ is sufficient (as for SM120 elsewhere in the repo).

💡 Proposed doc/message fix
 def _skip_if_unsupported():
-    """Skip test if not SM90 or SM100 (with CUDA 13+) architecture."""
+    """Skip test if not SM90, SM100 (CUDA 13+), or SM120 architecture."""
     device = torch.device("cuda")
     ...
     elif not is_sm90a_supported(device):
-        pytest.skip("GDN prefill requires SM90 or SM100")
+        pytest.skip("GDN prefill requires SM90, SM100, or SM120")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 36 - 52, Update the stale docstring and skip message in _skip_if_unsupported to include SM120 alongside SM90/SM100, and remove the extra cuda_major < 13 gate in the is_sm120a_supported branch so the SM120 path relies on is_sm120a_supported's existing CUDA 12.8+ check; specifically, edit the helper _skip_if_unsupported, adjust the branch that currently checks is_sm120a_supported(device) to not re-check CUDA major version (or change it only if you confirm CUDA 13 is required), and update the final pytest.skip message to mention SM120; refer to is_sm120a_supported and gen_gdn_prefill_sm120_module when reconciling required CUDA versions.

flashinfer/gdn_prefill.py (1)
191-198: ⚠️ Potential issue | 🟡 Minor: Docstring is stale — SM120 is now supported.
The "Note" section still advertises only SM90/SM100, while `get_gdn_prefill_module` now also loads the SM120 JIT module. Please extend the note so users of RTX Blackwell (SM120) discover the path.

💡 Proposed doc update
-    - Requires SM90 (Hopper) or SM100 (Blackwell) architecture.
+    - Requires SM90 (Hopper), SM100 (Blackwell), or SM120 (Blackwell RTX) architecture.
     - SM100 path requires head_size == 128.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 191 - 198, The docstring Note is outdated: update the Note block in the docstring near get_gdn_prefill_module to include SM120 as a supported architecture (in addition to SM90 and SM100), call out the SM120 JIT module loading behavior, and state any SM120-specific requirements (e.g., head_size constraints or package/version requirements) consistent with the runtime checks in get_gdn_prefill_module so users of RTX Blackwell can discover the correct path.
🧹 Nitpick comments (10)
include/flashinfer/flat/hopper/collective/flat_common.hpp (1)
167-170: Rename appears to be cosmetic — consider whether deprecation is warranted.
`restage_smem_layout` has a body identical to `unstage_smem_layout` (`composition(layout, make_tuple(_, _, make_layout(stages)))`). If the rename is purely semantic clarity ("restage" vs "unstage"), introducing a deprecation churn across the flat/delta-rule code may not justify the cost, especially since this PR's stated scope is SM120 enablement. Consider either (a) completing the migration of all callers in this PR so the deprecated symbol can be removed promptly, or (b) dropping the deprecation attribute and keeping `unstage_smem_layout` as an alias until a dedicated cleanup PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/hopper/collective/flat_common.hpp` around lines 167 - 170, The two functions restage_smem_layout and unstage_smem_layout are identical; avoid unnecessary deprecation churn—either finish migrating all call sites to restage_smem_layout in this PR so the old symbol can be removed, or revert the deprecation and make restage_smem_layout a simple alias of unstage_smem_layout (or vice versa) until a dedicated cleanup PR can remove the duplicate; update callers consistently and remove the deprecation attribute if you choose the alias approach so no build/runtime break occurs.

include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (2)
82-85: Unused alias `MathWarpGroupOrderBarrier` + a FIXME — follow-up candidate.

`MathWarpGroupOrderBarrier` is declared with a `// FIXME: remove this after moving to HMMA` but does not appear to be used in `operator()` (the code uses `OrderedMathBarriers math_barriers` from the collective instead). Consider removing now, or tracking via an issue so it doesn't rot.

Want me to open a follow-up issue to remove this alias and the associated FIXME once HMMA migration lands?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp` around lines 82 - 85, The alias MathWarpGroupOrderBarrier (and its accompanying FIXME comment) is unused — locate the typedef using MathWarpGroupOrderBarrier in this file and either remove the typedef and the FIXME line or replace the FIXME with a TODO that references a new issue number; ensure you do not change usage sites (operator() uses OrderedMathBarriers), so after removal build should still compile and no symbols named MathWarpGroupOrderBarrier should remain; if you choose to keep a placeholder, add a clear TODO with an issue ID instead of the FIXME.
33-56: Namespace-scope helpers with generic names in a public header.
`round_down` and `get_register_requirements` are defined at `flat::kernel` namespace scope with quite generic names. Since this header is included by the SM120 builder and transitively by the prefill launcher, these symbols will appear in every TU that pulls it in. Consider making them `static constexpr` inside `FlatKernelTmaWarpSpecializedDeltaRule`, or moving them into a detail namespace, to avoid accidental ODR/name collisions with similarly-named helpers elsewhere in the repo.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp` around lines 33 - 56, The helpers round_down and get_register_requirements are defined at namespace scope with generic names; move them out of the global namespace by either making them static constexpr private (or protected) member functions inside FlatKernelTmaWarpSpecializedDeltaRule or by placing them into a dedicated detail/internal namespace (e.g., flat::kernel::detail) to avoid ODR/name collisions; update all references inside FlatKernelTmaWarpSpecializedDeltaRule to call the new members or detail-qualified names and ensure their signatures remain unchanged (round_down<T1,T2> and get_register_requirements(uint32_t,uint32_t,uint32_t)).

csrc/gdn_prefill_launcher.cu (2)
60-63: Nit: indentation of the `#else` branch.

The `FLASHINFER_ERROR`/`return false` inside the `#else` blocks (lines 61-62 and 76-77) are indented 4 spaces while the surrounding lambda body uses 6. Also, since `FLASHINFER_ERROR` typically throws, the trailing `return false;` is unreachable — harmless, but consistent with the dead-code pattern you already have in the outer `else` at line 83.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/gdn_prefill_launcher.cu` around lines 60 - 63, Adjust the indentation and dead-code in the `#else` branch: align the two lines containing FLASHINFER_ERROR and the following statement to match the surrounding lambda body indentation (use the same 6-space indent as the surrounding block) and remove the redundant "return false;" after FLASHINFER_ERROR (since FLASHINFER_ERROR typically throws and the return is unreachable); locate the branches by searching for the preprocessor "#else" blocks surrounding the FLASHINFER_ERROR calls in the lambda in gdn_prefill_launcher.cu.
64-78: The SM120 C++ dispatch does not include a minor-version gate; SM121 would be silently routed to the SM120 kernel if it reaches this launcher.

The Python layer currently prevents SM121 from reaching `gdn_prefill_launcher` because it gates on `is_sm120a_supported`, which explicitly requires `minor == 0`. However, the C++ dispatcher checks only `device_major == 12`, accepting both SM12.0 and SM12.1. If SM121 is dispatched here in the future (e.g., via direct C++ calls or relaxed Python gates), it will silently use the SM120 kernel instantiation.

Consider documenting this major-only dispatch policy in a code comment, or add an explicit minor-version check alongside the existing major check for defensive clarity.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/gdn_prefill_launcher.cu` around lines 64 - 78, The dispatch in gdn_prefill_launcher currently checks only device_major == 12 and will route SM12.1 devices to the SM120 kernel; add an explicit minor-version check (e.g., ensure device_minor == 0) alongside the existing device_major == 12 guard or add a clear comment documenting the intentional major-only dispatch policy to prevent accidental routing of SM121 to flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm120,...>; update the branch that currently reads "} else if (device_major == 12) {" to either require both device_major == 12 && device_minor == 0 or to document why minors are allowed so future callers of gdn_prefill_launcher know the behavior.flashinfer/jit/gdn.py (1)
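The gap described in this comment can be shown with two tiny predicates contrasting the Python-side gate with the C++ launcher's major-only dispatch. The function names are illustrative, not the real flashinfer/launcher APIs:

```python
# Sketch of the SM120/SM121 gating mismatch flagged above; names are
# hypothetical stand-ins for is_sm120a_supported and the C++ dispatch.
def python_gate_sm120a(major: int, minor: int) -> bool:
    # mirrors an is_sm120a_supported-style check: exactly sm_12.0
    return (major, minor) == (12, 0)

def cpp_launcher_accepts(major: int, minor: int) -> bool:
    # mirrors the launcher branch that only tests device_major == 12
    return major == 12

# sm_12.1 slips past the C++ check even though the Python gate rejects it;
# an explicit minor check in the launcher would close that gap.
assert cpp_launcher_accepts(12, 1) and not python_gate_sm120a(12, 1)
assert cpp_launcher_accepts(12, 0) and python_gate_sm120a(12, 0)
```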
flashinfer/jit/gdn.py (1)

39-46: Refactor suggestion: unify arch flag selection and use a mapping.

The `assert` + two-branch `if/elif` can be collapsed into a small lookup, which also cleans up the RUF005 hints flagged by Ruff on Lines 44, 46, 99.

♻️ Proposed refactor
-    assert arch in ["sm90", "sm120"], (
-        "GDN prefill kernel is only supported on sm_90a and sm_120a"
-    )
-
-    if arch == "sm90":
-        arch_specific_flags = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"]
-    elif arch == "sm120":
-        arch_specific_flags = sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"]
+    _ARCH_FLAGS = {
+        "sm90": [*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED"],
+        "sm120": [*sm120a_nvcc_flags, "-DFLAT_SM120A_ENABLED"],
+    }
+    if arch not in _ARCH_FLAGS:
+        raise ValueError(
+            f"GDN prefill kernel is only supported on sm_90a and sm_120a, got {arch!r}"
+        )
+    arch_specific_flags = _ARCH_FLAGS[arch]
 ...
-    extra_cuda_cflags=arch_specific_flags + ["-std=c++20"],
+    extra_cuda_cflags=[*arch_specific_flags, "-std=c++20"],

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/gdn.py` around lines 39 - 46, Replace the assert + if/elif selection for arch with a mapping lookup to remove branching and satisfy Ruff: create a dict mapping arch strings ("sm90", "sm120") to their flag lists (e.g., {"sm90": sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"], "sm120": sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"]}), then set arch_specific_flags = mapping[arch]; raise a clear ValueError (or use mapping.get with a fallback and raise) if arch is unsupported; reference the existing variables arch, sm90a_nvcc_flags, sm120a_nvcc_flags and arch_specific_flags when making the change.include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (3)
88-94: Use `bool` (or `constexpr bool`) for boolean-semantic options.

`NeedsAlpha`, `NeedsBeta`, `NeedsDecay` are declared as `int` but logically boolean; they feed `std::conditional_t<...>` and `if constexpr (NeedsAlpha)` throughout. Keeping them typed as `bool` matches `kIsPersistent`/`kInitStateFromInput`/`kEnableCheckpointing` directly above and avoids silent widening:

-  static constexpr int NeedsAlpha =
-      find_option_t<Tag::kNeedsAlpha, cute::true_type, Options>::value;
-  static constexpr int NeedsBeta = find_option_t<Tag::kNeedsBeta, cute::true_type, Options>::value;
-
-  static constexpr int NeedsDecay =
-      find_option_t<Tag::kNeedsDecay, cute::false_type, Options>::value;
+  static constexpr bool NeedsAlpha =
+      find_option_t<Tag::kNeedsAlpha, cute::true_type, Options>::value;
+  static constexpr bool NeedsBeta =
+      find_option_t<Tag::kNeedsBeta, cute::true_type, Options>::value;
+
+  static constexpr bool NeedsDecay =
+      find_option_t<Tag::kNeedsDecay, cute::false_type, Options>::value;

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp` around lines 88 - 94, The three option flags NeedsAlpha, NeedsBeta, and NeedsDecay are declared as int but are used as booleans in std::conditional_t and if constexpr; change their declarations to constexpr bool (e.g., static constexpr bool NeedsAlpha = find_option_t<Tag::kNeedsAlpha, cute::true_type, Options>::value;) for NeedsAlpha, NeedsBeta, and NeedsDecay, keeping the existing default find_option_t expressions and the static_assert(!NeedsDecay, ...) intact so all boolean checks and type-dependent conditionals behave correctly.
415-433: Nit: `return true &&` is redundant and integer division can mis-classify heads.

- The leading `true &&` on line 430 has no effect — likely a debugging leftover.
- `ratio = num_q_heads / num_v_heads` (or vice versa) uses integer division, so non-multiple ratios (e.g., 5 q vs 2 v) silently floor to 2, and the subsequent `num_q_heads == ratio * num_v_heads` check then fails, which is the correct outcome but relies on the fall-through check. Minor, but a `(num_q_heads % num_v_heads == 0) || (num_v_heads % num_q_heads == 0)` precondition would make the intent explicit.

No functional fix required; flagging the `true &&` cleanup:

-  return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
+  return ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
          (problem_size.head_size <= get<2>(TileShape{})) &&
          ((problem_size.head_size % Alignment) == 0);

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp` around lines 415 - 433, Remove the redundant "true &&" at the start of the return expression in can_implement and make the head-ratio logic explicit: compute ratio only when one head count divides the other and require (problem_size.num_q_heads % problem_size.num_v_heads == 0) || (problem_size.num_v_heads % problem_size.num_q_heads == 0) before using the integer ratio; keep the existing is_gqa_like and is_gva_like checks but only evaluate them after that divisibility precondition to avoid relying on implicit integer truncation in the ratio calculation.
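The floor-division pitfall behind this nitpick fits in a few lines of Python. The `heads_compatible` helper is a sketch of the suggested divisibility precondition, not code from the PR:

```python
# Integer division floors non-multiple head ratios, so the classification
# only survives because of the follow-up equality check.
num_q_heads, num_v_heads = 5, 2
ratio = num_q_heads // num_v_heads          # floors to 2
assert ratio == 2
assert num_q_heads != ratio * num_v_heads   # fall-through check rejects 5 vs 2

# A hypothetical explicit precondition states the intent directly:
def heads_compatible(q: int, v: int) -> bool:
    return q % v == 0 or v % q == 0

assert not heads_compatible(5, 2)
assert heads_compatible(8, 2)  # GQA-like: 8 q-heads share 2 v-heads
```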
146-147: Stray empty statement.

 using SmemLayoutV_SD = decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));
-;

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp` around lines 146 - 147, Remove the stray empty statement after the type alias: in the declaration of SmemLayoutV_SD (using SmemLayoutV_SD = decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));) delete the extra trailing semicolon so the line ends with a single semicolon; no other changes to restage_smem_layout, SmemLayoutK_SD, Int or StagesV usages are required.include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh (1)
137-151: Consider surfacing the CUTLASS status in the error messages.
`throw std::runtime_error("can_implement failed")` etc. discard the actual `cutlass::Status` code, which makes triage painful when these fire in the field (e.g., `kErrorWorkspaceNull` vs. `kErrorInvalidProblem` vs. `kErrorNotSupported` all surface as the same string). Easy win:

Proposed refactor

-  status = op.can_implement(arguments);
-  if (status != cutlass::Status::kSuccess) {
-    throw std::runtime_error("can_implement failed");
-  }
+  status = op.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error(std::string("can_implement failed: ") +
+                             cutlassGetStatusString(status));
+  }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh` around lines 137 - 151, Replace the three generic exceptions so they include the CUTLASS status value/string from the local variable status when any of op.can_implement(arguments), op.initialize(arguments, workspace_buffer, stream) or op.run(stream) return non-success; update the thrown std::runtime_error messages in the blocks referencing status and the call that failed (op.can_implement, op.initialize, op.run) to include the status code or a human-readable status string to aid triage.
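The pattern being suggested, surfacing a machine status code in the raised error instead of a bare string, can be sketched in Python. The `Status` enum here merely stands in for `cutlass::Status`; it is not a real flashinfer or CUTLASS binding:

```python
from enum import Enum

class Status(Enum):
    """Stand-in for cutlass::Status, for illustration only."""
    SUCCESS = 0
    ERROR_WORKSPACE_NULL = 1
    ERROR_INVALID_PROBLEM = 2
    ERROR_NOT_SUPPORTED = 3

def check(status: Status, what: str) -> None:
    if status is not Status.SUCCESS:
        # Surfacing the enum name distinguishes otherwise-identical failures,
        # mirroring cutlassGetStatusString(status) in the C++ suggestion.
        raise RuntimeError(f"{what} failed: {status.name}")

try:
    check(Status.ERROR_INVALID_PROBLEM, "can_implement")
except RuntimeError as e:
    assert "ERROR_INVALID_PROBLEM" in str(e)
```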
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/gdn_prefill_sm120_kernel_inst.jinja`:
- Around line 31-33: Comment incorrectly references the SM90 extern file; update
the comment above the explicit template instantiation for
launch_delta_rule_prefill_kernel_gbai to reference the SM120 extern file name
(flat_prefill_kernel_delta_rule_sm120_extern.inc) and/or adjust wording to match
cutlass::arch::Sm120 and the current template instantiation line (template void
launch_delta_rule_prefill_kernel_gbai<... , cutlass::arch::Sm120, ... >(...));
keep the rest of the comment intact so it clearly states that parameter types
must match the extern template declaration for SM120.
In `@flashinfer/jit/gdn.py`:
- Around line 33-38: Update the module docstring in flashinfer/jit/gdn.py to
reflect the correct count: change the phrase "32 separate kernel instantiation
files (2 dtypes × 16 boolean combinations)" to "64 separate kernel instantiation
files (2 dtypes × 32 boolean combinations)" (and keep the note about the
original launcher file) to match the itertools.product([False, True], repeat=5)
loop used to generate kernel variants and the inline comment near the loop.
In `@include/flashinfer/flat/hopper/collective/flat_common.hpp`:
- Around line 95-99: The deprecated function unstage_smem_layout is still called
from flat_collective_tma_warpspecialized_delta_rule (for SmemLayoutQ_SD,
SmemLayoutK_DS, SmemLayoutV_DS) which will trigger deprecation warnings/errors;
fix by either updating those call sites to use restage_smem_layout or make
unstage_smem_layout delegate to restage_smem_layout so callers remain valid. If
delegating, ensure restage_smem_layout is declared/defined before
unstage_smem_layout (or add a forward declaration) and implement
unstage_smem_layout to simply call and return restage_smem_layout(layout,
stages) to keep behavior identical.
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 438-440: Remove the unused local variable `t` in the
`to_underlying_arguments` function: delete the line `int64_t t =
problem_size.total_seqlen;` and ensure no remaining code references `t`; if the
intent was to represent a different extent than `s` (e.g., total keys vs
values), replace usages with the correct variable instead of adding `t`. This
eliminates the dead variable and avoids implying separate K/V extents while
keeping `s` (`int64_t s = problem_size.total_seqlen;`) and `d` unchanged.
- Around line 1158-1197: The local BlkSeqKV is wrongly shadowing the class-level
alias by using get<0>(TileShape{}) (should be get<1>), and ckpt_blk_interval can
become zero causing a div-by-zero when used as a modulus; fix by (1) changing
the local declaration to use get<1>(TileShape{}) or removing the local and
referencing the existing BlkSeqKV alias so compute_loop_body/kv_checkpoint_store
use the correct KV tile size, and (2) ensure checkpoint interval cannot be zero
by either adding a can_implement check that params.checkpoint_every_n_tokens > 0
&& params.checkpoint_every_n_tokens % BlkSeqKV == 0 or by guarding all modulus
uses with a runtime check (kEnableCheckpointing && ckpt_blk_interval > 0) before
doing (blk+1) % ckpt_blk_interval and num_blocks % ckpt_blk_interval.
- Around line 186-187: Add a brief inline comment immediately above the typedefs
InverseType and CollectiveInverse explaining that InverseType is intentionally
fixed to cutlass::half_t because CollectiveInverse requires fp16, and that
explicit conversions handle cases where Element (e.g., bfloat16_t) differs;
reference the conversion logic that handles Element != InverseType so readers
know the precision trade-off is deliberate and localized to this hot path.
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 403-406: The debug printf references an undefined symbol
StateMmaRegisterRequirement; change that to the correct MmaRegisterRequirement
used by this kernel. Edit the DPRINTF0_WG call in
flat_kernel_tma_warpspecialized_delta_rule.hpp (the branch handling
WarpGroupRole::Math0/Math1) to print MmaRegisterRequirement instead of
StateMmaRegisterRequirement so the identifier matches the template instantiation
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>() and avoids the
undeclared identifier error.
- Around line 108-110: Remove the unused ClusterBarrier field from the
SharedStorage struct: delete the declaration of cutlass::arch::ClusterBarrier
load_warp_barrier so SharedStorage no longer reserves the unused 16 bytes;
verify operator() and related code references (none expected) compile cleanly
after removal and mirror the same change in the Hopper variant of the file to
keep both versions consistent.
---
Outside diff comments:
In `@flashinfer/gdn_prefill.py`:
- Around line 191-198: The docstring Note is outdated: update the Note block in
the docstring near get_gdn_prefill_module to include SM120 as a supported
architecture (in addition to SM90 and SM100), call out the SM120 JIT module
loading behavior, and state any SM120-specific requirements (e.g., head_size
constraints or package/version requirements) consistent with the runtime checks
in get_gdn_prefill_module so users of RTX Blackwell can discover the correct
path.
In `@tests/gdn/test_prefill_delta_rule.py`:
- Around line 36-52: Update the stale docstring and skip message in
_skip_if_unsupported to include SM120 alongside SM90/SM100, and remove the extra
cuda_major < 13 gate in the is_sm120a_supported branch so the SM120 path relies
on is_sm120a_supported's existing CUDA 12.8+ check; specifically, edit the
helper _skip_if_unsupported, adjust the branch that currently checks
is_sm120a_supported(device) to not re-check CUDA major version (or change it
only if you confirm CUDA 13 is required), and update the final pytest.skip
message to mention SM120; refer to is_sm120a_supported and
gen_gdn_prefill_sm120_module when reconciling required CUDA versions.
---
Nitpick comments:
In `@csrc/gdn_prefill_launcher.cu`:
- Around line 60-63: Adjust the indentation and dead-code in the `#else` branch:
align the two lines containing FLASHINFER_ERROR and the following statement to
match the surrounding lambda body indentation (use the same 6-space indent as
the surrounding block) and remove the redundant "return false;" after
FLASHINFER_ERROR (since FLASHINFER_ERROR typically throws and the return is
unreachable); locate the branches by searching for the preprocessor "#else"
blocks surrounding the FLASHINFER_ERROR calls in the lambda in
gdn_prefill_launcher.cu.
- Around line 64-78: The dispatch in gdn_prefill_launcher currently checks only
device_major == 12 and will route SM12.1 devices to the SM120 kernel; add an
explicit minor-version check (e.g., ensure device_minor == 0) alongside the
existing device_major == 12 guard or add a clear comment documenting the
intentional major-only dispatch policy to prevent accidental routing of SM121 to
flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm120,...>; update the
branch that currently reads "} else if (device_major == 12) {" to either require
both device_major == 12 && device_minor == 0 or to document why minors are
allowed so future callers of gdn_prefill_launcher know the behavior.
In `@flashinfer/jit/gdn.py`:
- Around line 39-46: Replace the assert + if/elif selection for arch with a
mapping lookup to remove branching and satisfy Ruff: create a dict mapping arch
strings ("sm90", "sm120") to their flag lists (e.g., {"sm90": sm90a_nvcc_flags +
["-DFLAT_SM90A_ENABLED"], "sm120": sm120a_nvcc_flags +
["-DFLAT_SM120A_ENABLED"]}), then set arch_specific_flags = mapping[arch]; raise
a clear ValueError (or use mapping.get with a fallback and raise) if arch is
unsupported; reference the existing variables arch, sm90a_nvcc_flags,
sm120a_nvcc_flags and arch_specific_flags when making the change.
In `@include/flashinfer/flat/hopper/collective/flat_common.hpp`:
- Around line 167-170: The two functions restage_smem_layout and
unstage_smem_layout are identical; avoid unnecessary deprecation churn—either
finish migrating all call sites to restage_smem_layout in this PR so the old
symbol can be removed, or revert the deprecation and make restage_smem_layout a
simple alias of unstage_smem_layout (or vice versa) until a dedicated cleanup PR
can remove the duplicate; update callers consistently and remove the deprecation
attribute if you choose the alias approach so no build/runtime break occurs.
In `@include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh`:
- Around line 137-151: Replace the three generic exceptions so they include the
CUTLASS status value/string from the local variable status when any of
op.can_implement(arguments), op.initialize(arguments, workspace_buffer, stream)
or op.run(stream) return non-success; update the thrown std::runtime_error
messages in the blocks referencing status and the call that failed
(op.can_implement, op.initialize, op.run) to include the status code or a
human-readable status string to aid triage.
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 88-94: The three option flags NeedsAlpha, NeedsBeta, and
NeedsDecay are declared as int but are used as booleans in std::conditional_t
and if constexpr; change their declarations to constexpr bool (e.g., static
constexpr bool NeedsAlpha = find_option_t<Tag::kNeedsAlpha, cute::true_type,
Options>::value;) for NeedsAlpha, NeedsBeta, and NeedsDecay, keeping the
existing default find_option_t expressions and the static_assert(!NeedsDecay,
...) intact so all boolean checks and type-dependent conditionals behave
correctly.
- Around line 415-433: Remove the redundant "true &&" at the start of the return
expression in can_implement and make the head-ratio logic explicit: compute
ratio only when one head count divides the other and require
(problem_size.num_q_heads % problem_size.num_v_heads == 0) ||
(problem_size.num_v_heads % problem_size.num_q_heads == 0) before using the
integer ratio; keep the existing is_gqa_like and is_gva_like checks but only
evaluate them after that divisibility precondition to avoid relying on implicit
integer truncation in the ratio calculation.
- Around line 146-147: Remove the stray empty statement after the type alias: in
the declaration of SmemLayoutV_SD (using SmemLayoutV_SD =
decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));) delete
the extra trailing semicolon so the line ends with a single semicolon; no other
changes to restage_smem_layout, SmemLayoutK_SD, Int or StagesV usages are
required.
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 82-85: The alias MathWarpGroupOrderBarrier (and its accompanying
FIXME comment) is unused — locate the typedef using MathWarpGroupOrderBarrier in
this file and either remove the typedef and the FIXME line or replace the FIXME
with a TODO that references a new issue number; ensure you do not change usage
sites (operator() uses OrderedMathBarriers), so after removal build should still
compile and no symbols named MathWarpGroupOrderBarrier should remain; if you
choose to keep a placeholder, add a clear TODO with an issue ID instead of the
FIXME.
- Around line 33-56: The helpers round_down and get_register_requirements are
defined at namespace scope with generic names; move them out of the global
namespace by either making them static constexpr private (or protected) member
functions inside FlatKernelTmaWarpSpecializedDeltaRule or by placing them into a
dedicated detail/internal namespace (e.g., flat::kernel::detail) to avoid
ODR/name collisions; update all references inside
FlatKernelTmaWarpSpecializedDeltaRule to call the new members or
detail-qualified names and ensure their signatures remain unchanged
(round_down<T1,T2> and get_register_requirements(uint32_t,uint32_t,uint32_t)).
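The head-ratio precondition from the `can_implement` item above can be sketched in isolation. This is an illustrative stand-in, not the real struct: `num_q_heads`/`num_v_heads` mirror the names in the review comment, and the GQA/GVA predicates are simplified placeholders.

```python
def can_implement_head_ratio(num_q_heads: int, num_v_heads: int) -> bool:
    # Require that one head count divides the other before forming the ratio,
    # instead of relying on silent integer truncation.
    if num_q_heads % num_v_heads != 0 and num_v_heads % num_q_heads != 0:
        return False
    ratio = max(num_q_heads, num_v_heads) // min(num_q_heads, num_v_heads)
    is_gqa_like = num_q_heads >= num_v_heads  # GQA-like: at least as many Q heads
    is_gva_like = num_v_heads > num_q_heads   # GVA-like: more V heads than Q heads
    return ratio >= 1 and (is_gqa_like or is_gva_like)
```

With the divisibility check made explicit, a configuration like 6 Q heads and 4 V heads is rejected up front rather than producing a truncated ratio of 1.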
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 065d402b-afa1-4498-b401-f39437844f53
📒 Files selected for processing (14)
- csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc
- csrc/gdn_prefill_launcher.cu
- csrc/gdn_prefill_sm120_kernel_inst.jinja
- csrc/prefill_kernel_delta_rule_sm120.cu
- flashinfer/gdn_prefill.py
- flashinfer/jit/gdn.py
- include/flashinfer/flat/hopper/collective/flat_common.hpp
- include/flashinfer/flat/prefill/prefill_kernel.hpp
- include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh
- include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
- include/flashinfer/flat/sm120/device/device_universal.hpp
- include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp
- include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- tests/gdn/test_prefill_delta_rule.py
| """Generate JIT module for GDN prefill kernel with separate compilation. | ||
|
|
||
| This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations) | ||
| plus the original launcher file. The separate files enable parallel compilation by ninja, | ||
| significantly reducing build time on multi-core machines. | ||
| """ |
Docstring count is stale: 32 combinations, not 16.
The docstring says "32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)", but the loop at Line 69 uses itertools.product([False, True], repeat=5) → 32 combinations, and the in-line comment at Line 60 already says "2 dtypes × 32 boolean combinations". So the total is 64 files (2 × 32), which matches the Line 60 comment.
💡 Proposed fix

```diff
-    This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)
-    plus the original launcher file. The separate files enable parallel compilation by ninja,
+    This generates 64 separate kernel instantiation files (2 dtypes x 32 boolean combinations)
+    plus the original launcher file. The separate files enable parallel compilation by ninja,
     significantly reducing build time on multi-core machines.
```

🧰 Tools
🪛 Ruff (0.15.10)
[warning] 35-35: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/jit/gdn.py` around lines 33 - 38, Update the module docstring in
flashinfer/jit/gdn.py to reflect the correct count: change the phrase "32
separate kernel instantiation files (2 dtypes × 16 boolean combinations)" to "64
separate kernel instantiation files (2 dtypes × 32 boolean combinations)" (and
keep the note about the original launcher file) to match the
itertools.product([False, True], repeat=5) loop used to generate kernel variants
and the inline comment near the loop.
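The corrected count can be verified directly against the combinatorics of the generation loop. The dtype names below are illustrative stand-ins; only the `itertools.product([False, True], repeat=5)` expression mirrors the actual loop cited in the finding:

```python
import itertools

dtypes = ["bf16", "fp16"]  # assumed dtype names, for illustration only
bool_combos = list(itertools.product([False, True], repeat=5))
num_files = len(dtypes) * len(bool_combos)
print(len(bool_combos), num_files)  # 32 boolean combinations, 64 instantiation files
```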
This is copied from the SM90 version; we should open a separate PR to fix them all at once.
@guangyunh-nv, understood — deferring the docstring fix to a follow-up PR that cleans it up across both the SM90 and SM120 versions makes sense. Would you like me to open a GitHub issue to track this fix so it doesn't get lost?
```cpp
template <class Layout, class Stages = _1>
[[deprecated("use restage_smem_layout instead")]]
CUTE_DEVICE constexpr auto unstage_smem_layout(Layout const& layout, Stages stages = {}) {
  return composition(layout, make_tuple(_, _, make_layout(stages)));
}
```
Deprecation will fire on existing in-tree callers; migrate them or delegate.
unstage_smem_layout is marked [[deprecated]], but include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (lines 157, 158, 160) still invokes it three times to build SmemLayoutQ_SD, SmemLayoutK_DS, and SmemLayoutV_DS. Depending on warning configuration (-Werror=deprecated-declarations), this will either spam build warnings or break the build. Either migrate those call sites to restage_smem_layout in this PR, or have the deprecated function delegate to the new one to keep behavior in lockstep.
♻️ Suggested delegation to avoid body duplication

```diff
 template <class Layout, class Stages = _1>
 [[deprecated("use restage_smem_layout instead")]]
 CUTE_DEVICE constexpr auto unstage_smem_layout(Layout const& layout, Stages stages = {}) {
-  return composition(layout, make_tuple(_, _, make_layout(stages)));
+  return restage_smem_layout(layout, stages);
 }
```

Note this requires moving restage_smem_layout above unstage_smem_layout (or adding a forward declaration).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/flat/hopper/collective/flat_common.hpp` around lines 95 -
99, The deprecated function unstage_smem_layout is still called from
flat_collective_tma_warpspecialized_delta_rule (for SmemLayoutQ_SD,
SmemLayoutK_DS, SmemLayoutV_DS) which will trigger deprecation warnings/errors;
fix by either updating those call sites to use restage_smem_layout or make
unstage_smem_layout delegate to restage_smem_layout so callers remain valid. If
delegating, ensure restage_smem_layout is declared/defined before
unstage_smem_layout (or add a forward declaration) and implement
unstage_smem_layout to simply call and return restage_smem_layout(layout,
stages) to keep behavior identical.
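The delegate-from-the-deprecated-wrapper pattern suggested above can be sketched language-agnostically. This is a hypothetical Python analogue, not the CUTE/C++ code: `restage_smem_layout` here is a stand-in whose body just echoes its arguments, and the point is that the deprecated name forwards to the new one so behavior can never drift:

```python
import warnings

def restage_smem_layout(layout, stages=1):
    # Stand-in for the new helper; the real logic lives in the C++ header.
    return (layout, stages)

def unstage_smem_layout(layout, stages=1):
    # Deprecated wrapper: warn once per call site, then delegate so both
    # entry points stay in lockstep (no duplicated body to maintain).
    warnings.warn("use restage_smem_layout instead", DeprecationWarning, stacklevel=2)
    return restage_smem_layout(layout, stages)
```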
```cpp
  alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
};
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check for any use of load_warp_barrier across the repo.
rg -nP -C2 '\bload_warp_barrier\b'
```

Repository: flashinfer-ai/flashinfer
Length of output: 1217

🏁 Script executed:

```bash
#!/bin/bash
# Get the full context of the SharedStorage struct and operator() in the SM120 file
cat -n include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | sed -n '80,150p'
```

Repository: flashinfer-ai/flashinfer
Length of output: 3081

🏁 Script executed:

```bash
#!/bin/bash
# Check if load_warp_barrier appears in any operator() or method implementations
rg -A 50 'operator\(\)' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | head -100
```

Repository: flashinfer-ai/flashinfer
Length of output: 2069

🏁 Script executed:

```bash
#!/bin/bash
# Check if there are any references to storage through collective operations or other indirect access
rg 'storage\.' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | head -20
```

Repository: flashinfer-ai/flashinfer
Length of output: 1026

🏁 Script executed:

```bash
#!/bin/bash
# Search for any usage patterns involving load_warp_barrier
# Check for constructor initialization, or other patterns
rg '(load_warp_barrier|&.*load_warp|::arrive|->arrive)' include/flashinfer/flat/sm120/kernel/
```

Repository: flashinfer-ai/flashinfer
Length of output: 216

🏁 Script executed:

```bash
#!/bin/bash
# Check if there are other similar files (Hopper version) that might provide context
wc -l include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
# Show lines around load_warp_barrier in Hopper version too
sed -n '110,130p' include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
```

Repository: flashinfer-ai/flashinfer
Length of output: 918

🏁 Script executed:

```bash
#!/bin/bash
# Check if ClusterBarrier is used elsewhere in the SM120 file (besides declaration)
rg 'ClusterBarrier|arrive_and_wait' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
```

Repository: flashinfer-ai/flashinfer
Length of output: 132

🏁 Script executed:

```bash
#!/bin/bash
# Check the Hopper file's operator() to see if load_warp_barrier is used there
rg -A 150 'operator\(\)' include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | grep -E '(load_warp|storage\.)' | head -20
```

Repository: flashinfer-ai/flashinfer
Length of output: 794
Remove load_warp_barrier from SharedStorage — it's declared but never initialized or used.
The member cutlass::arch::ClusterBarrier load_warp_barrier (line 109) appears nowhere else in the codebase. All other storage members in SharedStorage are initialized and consumed in operator(), but this one is dead code. It wastes 16 bytes of aligned shared memory. The same unused member exists in the Hopper version of the file as well.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 108 - 110, Remove the unused ClusterBarrier field from the
SharedStorage struct: delete the declaration of cutlass::arch::ClusterBarrier
load_warp_barrier so SharedStorage no longer reserves the unused 16 bytes;
verify operator() and related code references (none expected) compile cleanly
after removal and mirror the same change in the Hopper variant of the file to
keep both versions consistent.
I think that they have the same SMEM. I'm tied up with other things right now.
@yzh119 @jiahanc This is basically downgraded from SM90. SM90 implemented a 3-staged pipeline where the LD/ST warps live in one WG, state-unrelated computation in another WG, and state-related computation in a third group of WGs. For SM120, we merged the computation WGs into one group (2 WGs) due to register pressure from the HMMA operands. Otherwise, they are largely the same. NOTE: this kernel can (and should) be further downgraded from TMA to LDGSTS with no warp specialization; then it could run on SM80 and SM89 Ampere and Ada devices. That is left as homework for the community :)
This may be out of scope for this PR, but I'll share it here. The current sm90 prefill path looks like:
```mermaid
flowchart LR
    q -->|l2norm| qfi
    k -->|l2norm| kfi
    g -->|"exp + .float32"| alpha
    beta -->|".float32"| beta_fi
    indices -->|"where(≥0, idx, last_slot)\n.int64()"| idx64
    idx64 --> gather["gather + .float32"]
    pool["ssm_states\n(pool)"] --> gather
    gather --> init
    qfi & kfi & v & alpha & beta_fi & init --> K["chunk_gated_delta_rule"]
    K --> out_state -->|"index_copy_(idx64)"| pool2["ssm_states\n(pool)"]
```
Ideally, we could at least fuse the pool indexing and exp(g) parts.
```python
if is_sm90a_supported(device):
    module = gen_gdn_prefill_sm90_module().build_and_load()
elif is_sm120a_supported(device):
    module = gen_gdn_prefill_sm120_module().build_and_load()
```
In the test, we mention that SM120 GDN prefill requires CUDA 13+, but here we only check is_sm120a_supported, which checks for 12.8 under the hood. These seem to contradict each other. What is the minimum recommended CUDA version for the sm120 prefill kernels?
I am actually not quite sure what is the recommended version. According to https://docs.nvidia.com/cuda/archive/12.8.1/blackwell-compatibility-guide/index.html
CUDA applications built using CUDA Toolkit 12.8 are compatible with Blackwell architecture as long as they are built to include kernels in native cubin (compute capability 10.0) or PTX form or both.
So I think we are safe to relax the requirement to 12.8.
Relaxed to 12.8 in test
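The dispatch policy being discussed reduces to a small decision on the device's compute capability. This is a minimal sketch under stated assumptions: `pick_gdn_prefill_arch` is a hypothetical helper, and the `(major, minor)` pairs stand in for what flashinfer's `is_sm90a_supported`/`is_sm120a_supported` predicates check against a real torch device:

```python
def pick_gdn_prefill_arch(compute_capability: tuple) -> str:
    major, minor = compute_capability
    if (major, minor) == (9, 0):    # Hopper, sm_90a
        return "sm90"
    if (major, minor) == (12, 0):   # Blackwell RTX, sm_120a
        return "sm120"
    raise RuntimeError(f"GDN prefill unsupported on sm_{major}{minor}")

print(pick_gdn_prefill_arch((12, 0)))  # -> sm120
```

Unsupported architectures fail loudly at module-selection time rather than at kernel launch.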
Yes, sm120/121 have the same 100KB SMEM (99 usable, I think?).
Actionable comments posted: 1
♻️ Duplicate comments (2)
include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (1)
109-110: ⚠️ Potential issue | 🟡 Minor — Unused `load_warp_barrier` wastes 16 B of aligned smem.
`cutlass::arch::ClusterBarrier load_warp_barrier` is declared here but never referenced in `operator()` or anywhere else in the codebase. Safe to delete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp` around lines 109-110: the struct contains an unused member `alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier`, which wastes 16 bytes of shared memory. Remove the declaration of `load_warp_barrier` from the struct and rebuild to confirm no references remain (search for `load_warp_barrier` and `operator()` to ensure it isn't used).

include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (1)

1158-1197: ⚠️ Potential issue | 🟡 Minor — `BlkSeqKV` shadow and `ckpt_blk_interval == 0` div-by-zero still present.
Pre-existing issue carried in from the SM90 path (and acknowledged to be fixed in a separate PR per the prior discussion): the local `BlkSeqKV` at Line 1158 is actually `get<0>(TileShape{})` in the SM90 variant, and the `% ckpt_blk_interval` uses at Lines 1181 and 1192 have no guard against a zero divisor when `kEnableCheckpointing` is true but `checkpoint_every_n_tokens < BlkSeqKV` (or simply not a multiple of it). The SM120 version here uses the correct class-level `BlkSeqKV`, but the divisor guard is still missing. Worth tracking alongside the SM90 cleanup so this kernel doesn't ship with latent UB if a caller ever passes a misconfigured checkpoint stride.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp` around lines 1158-1197: the code computes `ckpt_blk_interval` from `checkpoint_every_n_tokens / BlkSeqKV` and then uses it as a divisor in modulo checks, which can be zero and cause UB. When `kEnableCheckpointing` is true, either (a) compute `ckpt_blk_interval` with a safe floor of 1 (e.g., if `checkpoint_every_n_tokens < BlkSeqKV`, set `ckpt_blk_interval = 1`), or (b) guard every use with an explicit `ckpt_blk_interval > 0` check before performing `(blk + 1) % ckpt_blk_interval` or `num_blocks % ckpt_blk_interval`, so no modulo by zero can occur.
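The two mitigation options for the zero divisor can be sketched in a few lines. These are illustrative stand-ins: `checkpoint_every_n_tokens` and `blk_seq_kv` mirror the names in the finding, and the real kernel does this arithmetic in C++ at trace/launch time:

```python
def ckpt_blk_interval_floored(checkpoint_every_n_tokens: int, blk_seq_kv: int) -> int:
    # Option (a): floor the divisor at 1 so modulo checks can never divide by zero,
    # even when checkpoint_every_n_tokens < blk_seq_kv.
    return max(1, checkpoint_every_n_tokens // blk_seq_kv)

def should_checkpoint(blk: int, interval: int) -> bool:
    # Option (b): guard every use with an explicit interval > 0 check.
    return interval > 0 and (blk + 1) % interval == 0
```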
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 394-402: The StoreO branch (when ldst_warp_role ==
LdStWarpRole::StoreO) currently calls scheduler.get_next_work(...) once and
directly invokes CollectiveMainloop::store, which diverges from sibling roles
and fails to iterate or guard on work_desc.is_valid(...). Change this to mirror
the other roles: obtain the first work_desc via scheduler.get_next_work(...),
then drain the scheduler with a for loop like for (; work_desc.is_valid(...);
work_desc = scheduler.get_next_work(...)) and call
collective_mainloop.store(...) inside that loop (ensuring any o_pipeline
producer/consumer handshakes remain inside the loop) so StoreO processes every
valid task and skips invalid descriptors.
---
Duplicate comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 1158-1197: The code computes ckpt_blk_interval from
checkpoint_every_n_tokens / BlkSeqKV and then uses it as a divisor in modulo
checks (in the loop calling compute_loop_body and kv_checkpoint_store), which
can be zero and cause UB; change the logic around ckpt_blk_interval (and its use
sites) so that when kEnableCheckpointing is true you either (a) compute
ckpt_blk_interval with a safe floor of 1 (e.g., if checkpoint_every_n_tokens <
BlkSeqKV then set ckpt_blk_interval = 1) or (b) guard every use with an explicit
check that ckpt_blk_interval > 0 before performing (blk+1) % ckpt_blk_interval
or num_blocks % ckpt_blk_interval; update the initialization near BlkSeqKV and
the checkpointing branches that call kv_checkpoint_store to use the safe
non-zero value or the guard so no modulo by zero can occur.
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 109-110: The struct contains an unused member alignas(16)
cutlass::arch::ClusterBarrier load_warp_barrier which wastes 16 bytes of shared
memory; remove the declaration of load_warp_barrier from the struct (the member
declared alongside other smem members) and rebuild to confirm no references
remain (search for load_warp_barrier and operator() to ensure it isn't used).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 813d7454-be91-4dfe-9f52-689d0621f455
📒 Files selected for processing (5)
- csrc/gdn_prefill_sm120_kernel_inst.jinja
- flashinfer/gdn_prefill.py
- include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
- include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (2)
- tests/gdn/test_prefill_delta_rule.py
- csrc/gdn_prefill_sm120_kernel_inst.jinja
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/gdn_prefill.py
```cpp
} else if (ldst_warp_role == LdStWarpRole::StoreO) {
  auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
  DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
              work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
  auto tile_shape = typename CollectiveMainloop::TileShape{};
  collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
                            params.problem_size, tile_shape, work_desc, o_pipeline,
                            o_smem_pipe_read, storage.tensors.mainloop.smem_o);
}
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Confirm the scheduler-loop asymmetry and check the hopper counterpart for reference.
rg -nP -C2 'get_next_work|is_valid' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
echo "---- hopper equivalent ----"
fd -a 'flat_kernel_tma_warpspecialized_delta_rule.hpp' include/flashinfer/flat/hopper | xargs -I{} sh -c 'echo "=== {} ==="; rg -nP -C2 "get_next_work|is_valid" {}'
```

Repository: flashinfer-ai/flashinfer
Length of output: 6062
StoreO warp does not iterate the scheduler, unlike all sibling roles.
LoadQKV (Lines 352–365), LoadBeta (Lines 366–379), LoadAlpha (Lines 380–393) and the math warp groups (Lines 407–419) all drain the scheduler via `for (; work_desc.is_valid(...); work_desc = scheduler.get_next_work(...))`. The StoreO branch calls `scheduler.get_next_work` exactly once and then invokes `collective_mainloop.store(...)` without a loop.
With the current non-persistent scheduler this is observationally fine (one work descriptor per block), but:
- Behavior silently diverges the moment any persistent / multi-tile scheduler is wired in: LoadQKV/math produce outputs for every task while StoreO only drains the first one, so later tasks deadlock on `o_pipeline.producer_commit` (the consumer never arrives) and/or publish wrong data.
- It also means StoreO doesn't even check `work_desc.is_valid(...)` before consuming the first descriptor, so on an empty grid assignment it would operate on an invalid `work_desc`.

Mirror the other roles' scheduler loop here (and guard on `is_valid`).
🛠️ Suggested fix

```diff
 } else if (ldst_warp_role == LdStWarpRole::StoreO) {
-  auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
-  DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
-              work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
-  auto tile_shape = typename CollectiveMainloop::TileShape{};
-  collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
-                            params.problem_size, tile_shape, work_desc, o_pipeline,
-                            o_smem_pipe_read, storage.tensors.mainloop.smem_o);
+  auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
+  CUTE_NO_UNROLL
+  for (; work_desc.is_valid(params.scheduler);
+       work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) {
+    DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
+                work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
+    auto tile_shape = typename CollectiveMainloop::TileShape{};
+    collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
+                              params.problem_size, tile_shape, work_desc, o_pipeline,
+                              o_smem_pipe_read, storage.tensors.mainloop.smem_o);
+  }
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 394 - 402, The StoreO branch (when ldst_warp_role ==
LdStWarpRole::StoreO) currently calls scheduler.get_next_work(...) once and
directly invokes CollectiveMainloop::store, which diverges from sibling roles
and fails to iterate or guard on work_desc.is_valid(...). Change this to mirror
the other roles: obtain the first work_desc via scheduler.get_next_work(...),
then drain the scheduler with a for loop like for (; work_desc.is_valid(...);
work_desc = scheduler.get_next_work(...)) and call
collective_mainloop.store(...) inside that loop (ensuring any o_pipeline
producer/consumer handshakes remain inside the loop) so StoreO processes every
valid task and skips invalid descriptors.
@kaixih As for exp(g), I'd like to address it together with the hopper version. Otherwise, we will need to handle it differently in the python wrapper before calling into the kernel launch wrapper.
Actionable comments posted: 1
♻️ Duplicate comments (1)
include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (1)
Lines 416-432: ⚠️ Potential issue | 🟠 Major: **Reject invalid checkpoint intervals before device modulus.**

`checkpoint_every_n_tokens` is documented as a multiple of `BlkSeqKV`, but `can_implement` does not enforce it. Values like `1..63` make `ckpt_blk_interval == 0`, and lines 1180/1191 can execute `% ckpt_blk_interval`; non-multiples can also make the kernel write checkpoints at a different cadence than the caller likely allocated.
checkpoint_every_n_tokensis documented as a multiple ofBlkSeqKV, butcan_implementdoes not enforce it. Values like1..63makeckpt_blk_interval == 0, and Lines 1180/1191 can execute% ckpt_blk_interval; non-multiples can also make the kernel write checkpoints at a different cadence than the caller likely allocated.Proposed fix
```diff
   static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
     auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                      ? problem_size.num_q_heads / problem_size.num_v_heads
                      : problem_size.num_v_heads / problem_size.num_q_heads;
     constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;
+    bool checkpoint_interval_valid = true;
+    if constexpr (kEnableCheckpointing) {
+      checkpoint_interval_valid = args.checkpoint_every_n_tokens > 0 &&
+                                  args.checkpoint_every_n_tokens % int(BlkSeqKV) == 0;
+    }
     bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
                        (problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
                        (problem_size.num_q_heads == ratio * problem_size.num_v_heads);
@@
     return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
            (problem_size.head_size <= get<2>(TileShape{})) &&
-           ((problem_size.head_size % Alignment) == 0);
+           ((problem_size.head_size % Alignment) == 0) && checkpoint_interval_valid;
   }
@@
     compute_loop_body(blk, /*is_first_block_=*/cute::false_type{},
                       /*is_final_block_=*/cute::false_type{});
     if constexpr (kEnableCheckpointing) {
-      if ((blk + 1) % ckpt_blk_interval == 0) {
+      if (ckpt_blk_interval > 0 && (blk + 1) % ckpt_blk_interval == 0) {
         kv_checkpoint_store(ckpt_count++);
       }
     }
@@
     // is always available via output_state from kv_store() below.
     if constexpr (kEnableCheckpointing) {
-      if (num_blocks % ckpt_blk_interval == 0) {
+      if (ckpt_blk_interval > 0 && num_blocks % ckpt_blk_interval == 0) {
         kv_checkpoint_store(ckpt_count);
       }
     }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp` around lines 416 - 432, can_implement must validate that checkpoint_every_n_tokens is a positive multiple of BlkSeqKV before any device-modulus or kernel-related checks; add an early guard in can_implement that reads the checkpoint_every_n_tokens from Arguments (or the appropriate arg field), computes ckpt_blk_interval = checkpoint_every_n_tokens / BlkSeqKV and returns false if checkpoint_every_n_tokens == 0 or (checkpoint_every_n_tokens % BlkSeqKV) != 0 (or ckpt_blk_interval == 0), so the function rejects invalid intervals before performing the rest of the head/Alignment/device checks (refer to can_implement, Arguments, checkpoint_every_n_tokens, BlkSeqKV, and ckpt_blk_interval).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 924-980: The three state-transfer lambdas kv_load, kv_store, and
kv_checkpoint_store currently construct tiled_copy_kv with
Copy_Atom<AutoVectorizingCopy, Element>{}, but Element may be fp16/bf16 while
the source/destination global buffers (ptr_input_state, ptr_output_state,
ptr_state_checkpoints) and accumulator tKVrKV are float; change the copy atom to
use float (i.e. Copy_Atom<AutoVectorizingCopy, float>{}) in each occurrence so
the memory-to-accumulator transfers use the float copy path and avoid
truncation; apply the identical change in the hopper variant’s corresponding
lambdas as well.
---
Duplicate comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 416-432: can_implement must validate that
checkpoint_every_n_tokens is a positive multiple of BlkSeqKV before any
device-modulus or kernel-related checks; add an early guard in can_implement
that reads the checkpoint_every_n_tokens from Arguments (or the appropriate arg
field), computes ckpt_blk_interval = checkpoint_every_n_tokens / BlkSeqKV and
returns false if checkpoint_every_n_tokens == 0 or (checkpoint_every_n_tokens %
BlkSeqKV) != 0 (or ckpt_blk_interval == 0), so the function rejects invalid
intervals before performing the rest of the head/Alignment/device checks (refer
to can_implement, Arguments, checkpoint_every_n_tokens, BlkSeqKV, and
ckpt_blk_interval).
📒 Files selected for processing (1)
include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
```cpp
    auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA {
      DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx,
                  v_head_idx);
      int num_state_heads = problem_size.num_sab_heads;
      int state_head_idx = work_desc.o_head_idx();
      auto gKV = make_tensor(make_gmem_ptr(params.ptr_input_state),
                             make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{},
                                                    num_state_heads, problem_size.num_seqs)))(
          _, _, state_head_idx, seq_idx);  // (KDim, VDim), K-contiguous

      auto tiled_copy_kv =
          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
      auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

      auto tKVgKV = thr_copy_kv.partition_S(select_tensor<1, 0>(gKV));
      copy(tiled_copy_kv, tKVgKV, tKVrKV);
    };

    auto kv_store = [&]() INLINE_LAMBDA {  // tKVrKV is carried over whole mainloop
      DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx,
                  v_head_idx);
      int num_state_heads = problem_size.num_sab_heads;
      int state_head_idx = work_desc.o_head_idx();  // num_o_heads == num_sab_heads
      auto gKV = make_tensor(make_gmem_ptr(params.ptr_output_state),
                             make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{},
                                                    num_state_heads, problem_size.num_seqs)))(
          _, _, state_head_idx, seq_idx);  // (KDim, VDim), K-contiguous

      auto tiled_copy_kv =
          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
      auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

      auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV));
      copy(tiled_copy_kv, tKVrKV, tKVgKV);
    };

    auto kv_checkpoint_store = [&](int checkpoint_idx) INLINE_LAMBDA {
      if constexpr (kEnableCheckpointing) {
        DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> checkpoint[%d]\n", seq_idx, q_head_idx,
                    k_head_idx, v_head_idx, checkpoint_idx);
        int num_state_heads = problem_size.num_sab_heads;
        int state_head_idx = work_desc.o_head_idx();
        int64_t ckpt_offset = params.checkpoint_cu_starts[seq_idx] + checkpoint_idx;

        // Layout: [total_checkpoints, num_sab_heads, HeadSizeQK, HeadSizeV] LayoutLeft
        auto gKV =
            make_tensor(make_gmem_ptr(params.ptr_state_checkpoints +
                                      ckpt_offset * num_state_heads * HeadSizeQK * HeadSizeV +
                                      state_head_idx * HeadSizeQK * HeadSizeV),
                        make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{})));

        auto tiled_copy_kv =
            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
        auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

        auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV));
        copy(tiled_copy_kv, tKVrKV, tKVgKV);
```
🧩 Analysis chain

🏁 Script executed (representative; the remaining verification scripts are condensed below):

```bash
#!/bin/bash
# Description: Compare state pointer types and state copy atom types in delta-rule code.
# Expected: copies to/from float state tensors should use float/accumulator copy typing, not Element.
rg -nP -C4 'ptr_(input|output)_state|ptr_state_checkpoints|Copy_Atom<AutoVectorizingCopy,\s*(Element|float|ElementAccumulatorKV)' -g '*delta_rule*'
```

Follow-up scripts verified, in order:

- the `Element` template parameter and struct definition at the top of the SM120 collective header;
- existing `Copy_Atom<AutoVectorizingCopy, float>` / `ElementAccumulatorKV` copy-atom usages in the delta-rule files (none found);
- the declaration and type of `tKVrKV` (a float accumulator fragment from `partition_fragment_C`);
- the hopper variant's `kv_load` / `kv_store` / `kv_checkpoint_store` lambdas (same pattern);
- the `TiledMmaKV` definition and how `kv_tiled_mma` is paired with copy atoms in the same file;
- how `ptr_input_state` / `ptr_output_state` / `ptr_state_checkpoints` are declared and used (float pointers);
- the SM120 kernel-builder and instantiation paths under `include/flashinfer/flat/sm120/`;
- the dtypes accepted by the Python API (`flashinfer/gdn_decode.py` and `benchmarks/bench_gdn_*.py`: fp16/bf16);
- what `AutoVectorizingCopy` expects, per its other usages in `include/flashinfer/`.
Use float copy atom for float state tensors.
The state transfer operations (kv_load, kv_store, kv_checkpoint_store) copy between float global memory pointers (ptr_input_state, ptr_output_state, ptr_state_checkpoints) and the float accumulator fragment tKVrKV. Using Copy_Atom<AutoVectorizingCopy, Element>{} is incorrect because Element can be fp16/bf16 at compile time, which would select an incompatible copy type and cause data truncation when transferring float data.
Update all three occurrences (lines 935, 953, 976) to use float:
Proposed fix
```diff
         auto tiled_copy_kv =
-            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
@@
         auto tiled_copy_kv =
-            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
@@
         auto tiled_copy_kv =
-            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
```

The same issue exists in the hopper variant at lines 800, 818, 841 of include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 924 - 980, The three state-transfer lambdas kv_load, kv_store, and
kv_checkpoint_store currently construct tiled_copy_kv with
Copy_Atom<AutoVectorizingCopy, Element>{}, but Element may be fp16/bf16 while
the source/destination global buffers (ptr_input_state, ptr_output_state,
ptr_state_checkpoints) and accumulator tKVrKV are float; change the copy atom to
use float (i.e. Copy_Atom<AutoVectorizingCopy, float>{}) in each occurrence so
the memory-to-accumulator transfers use the float copy path and avoid
truncation; apply the identical change in the hopper variant’s corresponding
lambdas as well.
```cpp
  bool needs_beta = beta != nullptr;
  bool needs_alpha = alpha != nullptr;
  bool init_state = input_state != nullptr;
  bool enable_ckpt = checkpoint_every_n_tokens > 0;
```
As far as I can tell, this is only supported by the SM120 kernel, not the SM90 kernel; is that correct?
They are all supported.
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Behavior Changes
Deprecations
Tests
Documentation