Implement Gated Delta Rule for sm_120a (Blackwell RTX) #3088

Open
guangyunh-nv wants to merge 7 commits into main from guangyunh/sm_120a

Conversation


@guangyunh-nv guangyunh-nv commented Apr 16, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added NVIDIA SM120 (Blackwell) support for delta-rule prefill kernels and JIT module generation.
  • Behavior Changes

    • Runtime dispatch and JIT loading now select device-specific kernels (SM90 vs SM120) and emit clear errors when a requested kernel isn't available.
  • Deprecations

    • Marked unstage_smem_layout as deprecated; use restage_smem_layout.
  • Tests

    • Updated test gating to include SM120 capability and CUDA version checks.
  • Documentation

    • Public docstring updated to list SM120 as a supported architecture.


coderabbitai Bot commented Apr 16, 2026

📝 Walkthrough

Adds Blackwell (SM120) support for the delta-rule prefill path: new SM120 Cutlass kernels and kernel-builder pieces, extern/explicit template instantiation headers and Jinja instantiation template, runtime device dispatch and JIT/module selection for SM120, and test gating for SM120/CUDA versions.

Changes

Cohort / File(s) Summary
SM120 Kernel Implementation
include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh, csrc/prefill_kernel_delta_rule_sm120.cu
New Cutlass-backed SM120 delta-rule launcher; computes runtime booleans, selects compile-time instantiation across 5 flags, prepares Cutlass Arguments, and provides explicit outer-instantiations for half and nv_bfloat16.
SM120 Collective & Kernel Types
include/flashinfer/flat/sm120/collective/.../flat_collective_tma_warpspecialized_delta_rule.hpp, include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp, include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp
New SM120-specific collective mainloop, warp-specialized kernel implementation, and FlatBuilder specialization composing kernel and option types.
Device Re-export
include/flashinfer/flat/sm120/device/device_universal.hpp
Re-exports hopper device_universal implementation under SM120 include path for kernel instantiation.
Extern Templates & Instantiation Templates
csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc, csrc/gdn_prefill_sm120_kernel_inst.jinja
Added extern template declarations enumerating all 5-boolean combos for half and nv_bfloat16; Jinja template to emit per-arch explicit instantiation translation units.
Dispatcher / Launcher
csrc/gdn_prefill_launcher.cu
Adds device_major == 12 branch to dispatch SM120-instantiated kernel (guarded by FLAT_SM120A_ENABLED); localizes SM90 compile-time guard to SM90 branch; emits clear runtime error if requested arch not built.
JIT Generation & Python Loader
flashinfer/jit/gdn.py, flashinfer/gdn_prefill.py
Refactored JIT generator to _gen_gdn_prefill_module(arch) that selects arch-specific NVCC flags and templates; added gen_gdn_prefill_sm120_module(); get_gdn_prefill_module(device) now chooses SM90 or SM120 module at runtime and calls module APIs with device context.
Public Headers / Deprecation
include/flashinfer/flat/prefill/prefill_kernel.hpp, include/flashinfer/flat/hopper/collective/flat_common.hpp
Added forward declaration for cutlass::arch::Sm120; deprecated unstage_smem_layout and added restage_smem_layout.
Tests
tests/gdn/test_prefill_delta_rule.py
Extended test gating to detect SM120 support and require minimum CUDA version (CUDA ≥ 12.8) for SM120 tests.
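The runtime module selection described in the JIT loader row above can be sketched as a small dispatch on compute capability. This is an illustrative sketch only: the real logic lives in `flashinfer/gdn_prefill.py` (`get_gdn_prefill_module`), and the helper name below is hypothetical, not FlashInfer's API.

```python
# Hypothetical sketch of the arch selection described in the walkthrough:
# SM90 devices load the sm90 module, SM120 (minor == 0, per the Python-side
# is_sm120a_supported gate) loads the sm120 module, everything else errors.

def select_gdn_prefill_arch(device_major: int, device_minor: int) -> str:
    """Map a CUDA compute capability to the JIT module variant to load."""
    if device_major == 9:
        return "sm90"
    if device_major == 12 and device_minor == 0:
        return "sm120"
    raise RuntimeError(
        f"GDN prefill kernel not built for sm_{device_major}{device_minor}"
    )
```

Note this sketch covers only the SM90/SM120 split discussed here; the SM100 path mentioned elsewhere in the PR is omitted.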

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python caller
    participant Loader as get_gdn_prefill_module(device)
    participant JIT as _gen_gdn_prefill_module
    participant Module as compiled JIT module
    participant Launcher as native gdn_prefill_launcher
    participant Kernel as SM90/SM120 kernel instantiation

    Py->>Loader: request module for device
    Loader->>JIT: select arch ("sm90" or "sm120")
    JIT->>Module: compile/load arch-specific module
    Loader->>Py: return Module
    Py->>Module: call gdn_prefill(...)
    Module->>Launcher: invoke native launcher with device_major
    alt device_major == 12
        Launcher->>Kernel: call SM120 instantiated kernel
    else device_major == 9
        Launcher->>Kernel: call SM90 instantiated kernel
    else
        Launcher->>Py: return error (unsupported arch)
    end
    Kernel-->>Launcher: execute Cutlass op / return status
    Launcher-->>Module: propagate result
    Module-->>Py: return

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • yzh119
  • bkryu
  • aleozlx
  • cyx-6
  • jimmyzho
  • sricketts

"🐰 I hopped in with templates tight,
Blackwell kernels born tonight,
Dispatch picks arch, kernels compile,
Externs keep instantiation in style,
Checkpoints saved — hop, code takes flight!"

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check (⚠️ Warning): The PR description uses the provided template but leaves the critical '📌 Description' and '🔍 Related Issues' sections empty, providing no substantive explanation of the changes or their rationale. Resolution: fill in the 'Description' section with details about what the PR does and why, and the 'Related Issues' section with any relevant issue links.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 8.82%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Title check (✅ Passed): The title clearly and specifically describes the main change: implementing a Gated Delta Rule for sm_120a (Blackwell RTX).
  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for the SM120 architecture (Blackwell) in the GDN prefill delta rule kernel. The changes include new kernel implementations, JIT compilation support for SM120, and updates to the launcher to dispatch based on compute capability. Review feedback identified several critical issues: undefined variables (ptr_state_checkpoints and StateMmaRegisterRequirement), hardcoded architecture selection in the Python API that breaks SM90 compatibility, and incorrect pipeline synchronization logic in the collective mainloop where a state increment occurs before the consumer release.

Comment thread include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh Outdated
Comment thread flashinfer/gdn_prefill.py Outdated
Comment on lines +1083 to +1085
// condition happens, why?
v_pipeline.consumer_release(v_smem_pipe_read);

Severity: high

The synchronization logic for v_pipeline is incorrect. Incrementing v_smem_pipe_read before calling consumer_release causes the pipeline to release the next stage instead of the one that was just consumed. This violates the CUTLASS pipeline protocol, leaves the current stage unreleased (potentially causing the producer to hang), and releases the next stage prematurely, which can lead to race conditions. The increment should happen after the release, consistent with the patterns used for q_pipeline and k_pipeline.

      v_pipeline.consumer_release(v_smem_pipe_read);
      ++v_smem_pipe_read;

@johnnynunez
Contributor

@guangyunh-nv is this only useful for sm120, not sm121 (Spark)?

@guangyunh-nv
Collaborator Author

@johnnynunez The kernel requires 100KB of shared memory and TMA for loading and storing. If those feature requirements are satisfied, there is nothing preventing the kernel from running on Spark.
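The feature gate described in this reply can be expressed as a simple predicate. A minimal sketch, assuming the stated requirements (roughly 100KB of dynamic shared memory per block, plus TMA); the helper name and threshold constant are illustrative, not FlashInfer code:

```python
# Illustrative feature gate: the kernel needs ~100KB of smem and TMA.
# REQUIRED_SMEM_BYTES and kernel_can_run are hypothetical names.

REQUIRED_SMEM_BYTES = 100 * 1024  # ~100KB, per the comment above

def kernel_can_run(max_smem_per_block: int, has_tma: bool) -> bool:
    """True if a device meets the kernel's feature requirements."""
    return has_tma and max_smem_per_block >= REQUIRED_SMEM_BYTES
```

In practice the per-block shared-memory limit would come from a device query (e.g. a max-shared-memory-per-block attribute) rather than being hard-coded.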

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gdn/test_prefill_delta_rule.py (1)

36-52: ⚠️ Potential issue | 🟡 Minor

Stale docstring and skip message don't mention SM120.

The helper's docstring (Line 37) still says "Skip test if not SM90 or SM100 (with CUDA 13+) architecture." and the fallback skip message (Line 52) reads "GDN prefill requires SM90 or SM100", which no longer reflects the new SM120 path.

Also worth noting: is_sm120a_supported already requires CUDA 12.8+ (see flashinfer/utils.py:563-565), so the additional cuda_major < 13 gate inside the SM120 branch is stricter than the SM120 JIT path itself (gen_gdn_prefill_sm120_module only needs -DFLAT_SM120A_ENABLED + sm120a_nvcc_flags). Confirm whether CUDA 13 is really required for SM120 GDN prefill, or whether CUDA 12.8+ is sufficient (as for SM120 elsewhere in the repo).

💡 Proposed doc/message fix
 def _skip_if_unsupported():
-    """Skip test if not SM90 or SM100 (with CUDA 13+) architecture."""
+    """Skip test if not SM90, SM100 (CUDA 13+), or SM120 architecture."""
     device = torch.device("cuda")
     ...
     elif not is_sm90a_supported(device):
-        pytest.skip("GDN prefill requires SM90 or SM100")
+        pytest.skip("GDN prefill requires SM90, SM100, or SM120")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 36 - 52, Update the stale
docstring and skip message in _skip_if_unsupported to include SM120 alongside
SM90/SM100, and remove the extra cuda_major < 13 gate in the is_sm120a_supported
branch so the SM120 path relies on is_sm120a_supported's existing CUDA 12.8+
check; specifically, edit the helper _skip_if_unsupported, adjust the branch
that currently checks is_sm120a_supported(device) to not re-check CUDA major
version (or change it only if you confirm CUDA 13 is required), and update the
final pytest.skip message to mention SM120; refer to is_sm120a_supported and
gen_gdn_prefill_sm120_module when reconciling required CUDA versions.
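The CUDA-version gate questioned above (12.8+ vs 13) is easiest to reason about as a tuple comparison. A minimal sketch under that framing; `parse_cuda_version` and `meets_min_cuda` are hypothetical helpers, not FlashInfer's API:

```python
# Sketch of an explicit version gate: compare (major, minor) tuples instead
# of checking cuda_major alone, so a 12.8+ requirement is stated directly.

def parse_cuda_version(version: str) -> tuple:
    """Parse 'MAJOR.MINOR[.PATCH]' into a comparable (major, minor) tuple."""
    major, minor = version.split(".")[:2]
    return (int(major), int(minor))

def meets_min_cuda(version: str, minimum: tuple = (12, 8)) -> bool:
    return parse_cuda_version(version) >= minimum
```

With this shape, the SM120 test gate could state `(12, 8)` or `(13, 0)` explicitly, making the discrepancy flagged above visible at a glance.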
flashinfer/gdn_prefill.py (1)

191-198: ⚠️ Potential issue | 🟡 Minor

Docstring is stale — SM120 is now supported.

The "Note" section still advertises only SM90/SM100 while get_gdn_prefill_module now also loads the SM120 JIT module. Please extend the note so users of RTX Blackwell (SM120) discover the path.

💡 Proposed doc update
-        - Requires SM90 (Hopper) or SM100 (Blackwell) architecture.
+        - Requires SM90 (Hopper), SM100 (Blackwell), or SM120 (Blackwell RTX) architecture.
         - SM100 path requires head_size == 128.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 191 - 198, The docstring Note is
outdated: update the Note block in the docstring near get_gdn_prefill_module to
include SM120 as a supported architecture (in addition to SM90 and SM100), call
out the SM120 JIT module loading behavior, and state any SM120-specific
requirements (e.g., head_size constraints or package/version requirements)
consistent with the runtime checks in get_gdn_prefill_module so users of RTX
Blackwell can discover the correct path.
🧹 Nitpick comments (10)
include/flashinfer/flat/hopper/collective/flat_common.hpp (1)

167-170: Rename appears to be cosmetic — consider whether deprecation is warranted.

restage_smem_layout has a body identical to unstage_smem_layout (composition(layout, make_tuple(_, _, make_layout(stages)))). If the rename is purely semantic clarity ("restage" vs "unstage"), introducing a deprecation churn across the flat/delta-rule code may not justify the cost, especially since this PR's stated scope is SM120 enablement. Consider either (a) completing the migration of all callers in this PR so the deprecated symbol can be removed promptly, or (b) dropping the deprecation attribute and keeping unstage_smem_layout as an alias until a dedicated cleanup PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/flat/hopper/collective/flat_common.hpp` around lines 167 -
170, The two functions restage_smem_layout and unstage_smem_layout are
identical; avoid unnecessary deprecation churn—either finish migrating all call
sites to restage_smem_layout in this PR so the old symbol can be removed, or
revert the deprecation and make restage_smem_layout a simple alias of
unstage_smem_layout (or vice versa) until a dedicated cleanup PR can remove the
duplicate; update callers consistently and remove the deprecation attribute if
you choose the alias approach so no build/runtime break occurs.
include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (2)

82-85: Unused alias MathWarpGroupOrderBarrier + a FIXME — follow-up candidate.

MathWarpGroupOrderBarrier is declared with a // FIXME: remove this after moving to HMMA but does not appear to be used in operator() (the code uses OrderedMathBarriers math_barriers from the collective instead). Consider removing now, or tracking via an issue so it doesn't rot.

Want me to open a follow-up issue to remove this alias and the associated FIXME once HMMA migration lands?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 82 - 85, The alias MathWarpGroupOrderBarrier (and its accompanying
FIXME comment) is unused — locate the typedef using MathWarpGroupOrderBarrier in
this file and either remove the typedef and the FIXME line or replace the FIXME
with a TODO that references a new issue number; ensure you do not change usage
sites (operator() uses OrderedMathBarriers), so after removal build should still
compile and no symbols named MathWarpGroupOrderBarrier should remain; if you
choose to keep a placeholder, add a clear TODO with an issue ID instead of the
FIXME.

33-56: Namespace-scope helpers with generic names in a public header.

round_down and get_register_requirements are defined at flat::kernel namespace scope with quite generic names. Since this header is included by the SM120 builder and transitively by the prefill launcher, these symbols will appear in every TU that pulls it in. Consider making them static constexpr inside FlatKernelTmaWarpSpecializedDeltaRule, or moving into a detail namespace, to avoid accidental ODR/name collisions with similarly-named helpers elsewhere in the repo.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 33 - 56, The helpers round_down and get_register_requirements are
defined at namespace scope with generic names; move them out of the global
namespace by either making them static constexpr private (or protected) member
functions inside FlatKernelTmaWarpSpecializedDeltaRule or by placing them into a
dedicated detail/internal namespace (e.g., flat::kernel::detail) to avoid
ODR/name collisions; update all references inside
FlatKernelTmaWarpSpecializedDeltaRule to call the new members or
detail-qualified names and ensure their signatures remain unchanged
(round_down<T1,T2> and get_register_requirements(uint32_t,uint32_t,uint32_t)).
csrc/gdn_prefill_launcher.cu (2)

60-63: Nit: indentation of the #else branch.

The FLASHINFER_ERROR / return false inside the #else blocks (lines 61-62 and 76-77) are indented 4 spaces while the surrounding lambda body uses 6. Also, since FLASHINFER_ERROR typically throws, the trailing return false; is unreachable — harmless, but consistent with the dead-code pattern you already have in the outer else at line 83.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/gdn_prefill_launcher.cu` around lines 60 - 63, Adjust the indentation
and dead-code in the `#else` branch: align the two lines containing
FLASHINFER_ERROR and the following statement to match the surrounding lambda
body indentation (use the same 6-space indent as the surrounding block) and
remove the redundant "return false;" after FLASHINFER_ERROR (since
FLASHINFER_ERROR typically throws and the return is unreachable); locate the
branches by searching for the preprocessor "#else" blocks surrounding the
FLASHINFER_ERROR calls in the lambda in gdn_prefill_launcher.cu.

64-78: The SM120 C++ dispatch does not include a minor-version gate; SM121 would be silently routed to the SM120 kernel if it reaches this launcher.

The Python layer currently prevents SM121 from reaching gdn_prefill_launcher because it gates on is_sm120a_supported, which explicitly requires minor == 0. However, the C++ dispatcher checks only device_major == 12, accepting both SM12.0 and SM12.1. If SM121 is dispatched here in the future (e.g., via direct C++ calls or relaxed Python gates), it will silently use the SM120 kernel instantiation.

Consider documenting this major-only dispatch policy in a code comment, or add an explicit minor-version check alongside the existing major check for defensive clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/gdn_prefill_launcher.cu` around lines 64 - 78, The dispatch in
gdn_prefill_launcher currently checks only device_major == 12 and will route
SM12.1 devices to the SM120 kernel; add an explicit minor-version check (e.g.,
ensure device_minor == 0) alongside the existing device_major == 12 guard or add
a clear comment documenting the intentional major-only dispatch policy to
prevent accidental routing of SM121 to
flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm120,...>; update the
branch that currently reads "} else if (device_major == 12) {" to either require
both device_major == 12 && device_minor == 0 or to document why minors are
allowed so future callers of gdn_prefill_launcher know the behavior.
flashinfer/jit/gdn.py (1)

39-46: Refactor suggestion: unify arch flag selection and use a mapping.

The assert + two-branch if/elif can be collapsed into a small lookup, which also cleans up the RUF005 hints flagged by Ruff on Lines 44, 46, 99.

♻️ Proposed refactor
-    assert arch in ["sm90", "sm120"], (
-        "GDN prefill kernel is only supported on sm_90a and sm_120a"
-    )
-
-    if arch == "sm90":
-        arch_specific_flags = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"]
-    elif arch == "sm120":
-        arch_specific_flags = sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"]
+    _ARCH_FLAGS = {
+        "sm90": [*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED"],
+        "sm120": [*sm120a_nvcc_flags, "-DFLAT_SM120A_ENABLED"],
+    }
+    if arch not in _ARCH_FLAGS:
+        raise ValueError(
+            f"GDN prefill kernel is only supported on sm_90a and sm_120a, got {arch!r}"
+        )
+    arch_specific_flags = _ARCH_FLAGS[arch]
...
-        extra_cuda_cflags=arch_specific_flags + ["-std=c++20"],
+        extra_cuda_cflags=[*arch_specific_flags, "-std=c++20"],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gdn.py` around lines 39 - 46, Replace the assert + if/elif
selection for arch with a mapping lookup to remove branching and satisfy Ruff:
create a dict mapping arch strings ("sm90", "sm120") to their flag lists (e.g.,
{"sm90": sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"], "sm120": sm120a_nvcc_flags
+ ["-DFLAT_SM120A_ENABLED"]}), then set arch_specific_flags = mapping[arch];
raise a clear ValueError (or use mapping.get with a fallback and raise) if arch
is unsupported; reference the existing variables arch, sm90a_nvcc_flags,
sm120a_nvcc_flags and arch_specific_flags when making the change.
include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (3)

88-94: Use bool (or constexpr bool) for boolean-semantic options.

NeedsAlpha, NeedsBeta, NeedsDecay are declared as int but logically boolean; they feed std::conditional_t<...> and if constexpr (NeedsAlpha) throughout. Keeping them typed as bool matches kIsPersistent/kInitStateFromInput/kEnableCheckpointing directly above and avoids silent widening:

-  static constexpr int NeedsAlpha =
-      find_option_t<Tag::kNeedsAlpha, cute::true_type, Options>::value;
-  static constexpr int NeedsBeta = find_option_t<Tag::kNeedsBeta, cute::true_type, Options>::value;
-
-  static constexpr int NeedsDecay =
-      find_option_t<Tag::kNeedsDecay, cute::false_type, Options>::value;
+  static constexpr bool NeedsAlpha =
+      find_option_t<Tag::kNeedsAlpha, cute::true_type, Options>::value;
+  static constexpr bool NeedsBeta =
+      find_option_t<Tag::kNeedsBeta, cute::true_type, Options>::value;
+
+  static constexpr bool NeedsDecay =
+      find_option_t<Tag::kNeedsDecay, cute::false_type, Options>::value;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 88 - 94, The three option flags NeedsAlpha, NeedsBeta, and
NeedsDecay are declared as int but are used as booleans in std::conditional_t
and if constexpr; change their declarations to constexpr bool (e.g., static
constexpr bool NeedsAlpha = find_option_t<Tag::kNeedsAlpha, cute::true_type,
Options>::value;) for NeedsAlpha, NeedsBeta, and NeedsDecay, keeping the
existing default find_option_t expressions and the static_assert(!NeedsDecay,
...) intact so all boolean checks and type-dependent conditionals behave
correctly.

415-433: Nit: return true && is redundant and integer division can mis-classify heads.

  • The leading true && on line 430 has no effect — likely a debugging leftover.
  • ratio = num_q_heads / num_v_heads (or vice versa) uses integer division, so non-multiple ratios (e.g., 5 q vs 2 v) silently floor to 2 and the subsequent num_q_heads == ratio * num_v_heads check then fails, which is the correct outcome but relies on the fall-through check. Minor, but a (num_q_heads % num_v_heads == 0) || (num_v_heads % num_q_heads == 0) precondition would make the intent explicit.

No functional fix required; flagging the true && cleanup:

-    return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
+    return ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
            (problem_size.head_size <= get<2>(TileShape{})) &&
            ((problem_size.head_size % Alignment) == 0);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 415 - 433, Remove the redundant "true &&" at the start of the
return expression in can_implement and make the head-ratio logic explicit:
compute ratio only when one head count divides the other and require
(problem_size.num_q_heads % problem_size.num_v_heads == 0) ||
(problem_size.num_v_heads % problem_size.num_q_heads == 0) before using the
integer ratio; keep the existing is_gqa_like and is_gva_like checks but only
evaluate them after that divisibility precondition to avoid relying on implicit
integer truncation in the ratio calculation.
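The divisibility precondition suggested in this nitpick can be sketched as a standalone predicate. This is an illustrative sketch of the intent only (the real check lives in the C++ `can_implement`); the function name is hypothetical:

```python
# Explicit GQA/GVA head-ratio precondition: one head count must evenly
# divide the other, instead of relying on integer-division truncation.

def heads_compatible(num_q_heads: int, num_v_heads: int) -> bool:
    """True if one head count is an exact multiple of the other."""
    if num_q_heads <= 0 or num_v_heads <= 0:
        return False
    return (num_q_heads % num_v_heads == 0) or (num_v_heads % num_q_heads == 0)
```

For example, 8 q-heads with 2 v-heads passes, while the 5-vs-2 case from the comment fails up front rather than falling through the truncated-ratio check.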

146-147: Stray empty statement.

   using SmemLayoutV_SD = decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));
-  ;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 146 - 147, Remove the stray empty statement after the type alias:
in the declaration of SmemLayoutV_SD (using SmemLayoutV_SD =
decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));) delete
the extra trailing semicolon so the line ends with a single semicolon; no other
changes to restage_smem_layout, SmemLayoutK_SD, Int or StagesV usages are
required.
include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh (1)

137-151: Consider surfacing the CUTLASS status in the error messages.

throw std::runtime_error("can_implement failed") etc. discard the actual cutlass::Status code, which makes triage painful when these fire in the field (e.g., kErrorWorkspaceNull vs. kErrorInvalidProblem vs. kErrorNotSupported all surface as the same string). Easy win:

Proposed refactor
-    status = op.can_implement(arguments);
-    if (status != cutlass::Status::kSuccess) {
-      throw std::runtime_error("can_implement failed");
-    }
+    status = op.can_implement(arguments);
+    if (status != cutlass::Status::kSuccess) {
+      throw std::runtime_error(std::string("can_implement failed: ") +
+                               cutlassGetStatusString(status));
+    }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh` around
lines 137 - 151, Replace the three generic exceptions so they include the
CUTLASS status value/string from the local variable status when any of
op.can_implement(arguments), op.initialize(arguments, workspace_buffer, stream)
or op.run(stream) return non-success; update the thrown std::runtime_error
messages in the blocks referencing status and the call that failed
(op.can_implement, op.initialize, op.run) to include the status code or a
human-readable status string to aid triage.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/gdn_prefill_sm120_kernel_inst.jinja`:
- Around line 31-33: Comment incorrectly references the SM90 extern file; update
the comment above the explicit template instantiation for
launch_delta_rule_prefill_kernel_gbai to reference the SM120 extern file name
(flat_prefill_kernel_delta_rule_sm120_extern.inc) and/or adjust wording to match
cutlass::arch::Sm120 and the current template instantiation line (template void
launch_delta_rule_prefill_kernel_gbai<... , cutlass::arch::Sm120, ... >(...));
keep the rest of the comment intact so it clearly states that parameter types
must match the extern template declaration for SM120.

In `@flashinfer/jit/gdn.py`:
- Around line 33-38: Update the module docstring in flashinfer/jit/gdn.py to
reflect the correct count: change the phrase "32 separate kernel instantiation
files (2 dtypes × 16 boolean combinations)" to "64 separate kernel instantiation
files (2 dtypes × 32 boolean combinations)" (and keep the note about the
original launcher file) to match the itertools.product([False, True], repeat=5)
loop used to generate kernel variants and the inline comment near the loop.
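The corrected count above follows directly from the generator loop it describes: `itertools.product([False, True], repeat=5)` yields 2^5 = 32 flag combinations, and with 2 dtypes that is 64 instantiation files. A quick sketch (the file-naming scheme below is illustrative, not the repo's actual naming):

```python
# Enumerate the kernel-variant space: 5 boolean flags x 2 dtypes.
from itertools import product

flag_combos = list(product([False, True], repeat=5))  # 32 combinations
dtypes = ["half", "nv_bfloat16"]
inst_files = [
    f"gdn_prefill_{dtype}_{''.join('1' if f else '0' for f in flags)}.cu"
    for dtype in dtypes
    for flags in flag_combos
]  # 64 hypothetical instantiation translation units
```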

In `@include/flashinfer/flat/hopper/collective/flat_common.hpp`:
- Around line 95-99: The deprecated function unstage_smem_layout is still called
from flat_collective_tma_warpspecialized_delta_rule (for SmemLayoutQ_SD,
SmemLayoutK_DS, SmemLayoutV_DS) which will trigger deprecation warnings/errors;
fix by either updating those call sites to use restage_smem_layout or make
unstage_smem_layout delegate to restage_smem_layout so callers remain valid. If
delegating, ensure restage_smem_layout is declared/defined before
unstage_smem_layout (or add a forward declaration) and implement
unstage_smem_layout to simply call and return restage_smem_layout(layout,
stages) to keep behavior identical.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 438-440: Remove the unused local variable `t` in the
`to_underlying_arguments` function: delete the line `int64_t t =
problem_size.total_seqlen;` and ensure no remaining code references `t`; if the
intent was to represent a different extent than `s` (e.g., total keys vs
values), replace usages with the correct variable instead of adding `t`. This
eliminates the dead variable and avoids implying separate K/V extents while
keeping `s` (`int64_t s = problem_size.total_seqlen;`) and `d` unchanged.
- Around line 1158-1197: The local BlkSeqKV is wrongly shadowing the class-level
alias by using get<0>(TileShape{}) (should be get<1>), and ckpt_blk_interval can
become zero causing a div-by-zero when used as a modulus; fix by (1) changing
the local declaration to use get<1>(TileShape{}) or removing the local and
referencing the existing BlkSeqKV alias so compute_loop_body/kv_checkpoint_store
use the correct KV tile size, and (2) ensure checkpoint interval cannot be zero
by either adding a can_implement check that params.checkpoint_every_n_tokens > 0
&& params.checkpoint_every_n_tokens % BlkSeqKV == 0 or by guarding all modulus
uses with a runtime check (kEnableCheckpointing && ckpt_blk_interval > 0) before
doing (blk+1) % ckpt_blk_interval and num_blocks % ckpt_blk_interval.
- Around line 186-187: Add a brief inline comment immediately above the typedefs
InverseType and CollectiveInverse explaining that InverseType is intentionally
fixed to cutlass::half_t because CollectiveInverse requires fp16, and that
explicit conversions handle cases where Element (e.g., bfloat16_t) differs;
reference the conversion logic that handles Element != InverseType so readers
know the precision trade-off is deliberate and localized to this hot path.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 403-406: The debug printf references an undefined symbol
StateMmaRegisterRequirement; change that to the correct MmaRegisterRequirement
used by this kernel. Edit the DPRINTF0_WG call in
flat_kernel_tma_warpspecialized_delta_rule.hpp (the branch handling
WarpGroupRole::Math0/Math1) to print MmaRegisterRequirement instead of
StateMmaRegisterRequirement so the identifier matches the template instantiation
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>() and avoids the
undeclared identifier error.
- Around line 108-110: Remove the unused ClusterBarrier field from the
SharedStorage struct: delete the declaration of cutlass::arch::ClusterBarrier
load_warp_barrier so SharedStorage no longer reserves the unused 16 bytes;
verify operator() and related code references (none expected) compile cleanly
after removal and mirror the same change in the Hopper variant of the file to
keep both versions consistent.

---

Outside diff comments:
In `@flashinfer/gdn_prefill.py`:
- Around line 191-198: The docstring Note is outdated: update the Note block in
the docstring near get_gdn_prefill_module to include SM120 as a supported
architecture (in addition to SM90 and SM100), call out the SM120 JIT module
loading behavior, and state any SM120-specific requirements (e.g., head_size
constraints or package/version requirements) consistent with the runtime checks
in get_gdn_prefill_module so users of RTX Blackwell can discover the correct
path.

In `@tests/gdn/test_prefill_delta_rule.py`:
- Around line 36-52: Update the stale docstring and skip message in
_skip_if_unsupported to include SM120 alongside SM90/SM100, and remove the extra
cuda_major < 13 gate in the is_sm120a_supported branch so the SM120 path relies
on is_sm120a_supported's existing CUDA 12.8+ check; specifically, edit the
helper _skip_if_unsupported, adjust the branch that currently checks
is_sm120a_supported(device) to not re-check CUDA major version (or change it
only if you confirm CUDA 13 is required), and update the final pytest.skip
message to mention SM120; refer to is_sm120a_supported and
gen_gdn_prefill_sm120_module when reconciling required CUDA versions.
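The suggested gating might reduce to something like the following sketch, assuming `is_sm120a_supported` enforces a CUDA 12.8+ floor as discussed below; the helper name, the SM100 branch, and the thresholds are illustrative, not the test's actual code.

```python
def skip_reason(cc, cuda):
    """cc = (compute capability major, minor); cuda = (toolkit major, minor).
    SM120 relies solely on the CUDA 12.8+ floor, with no extra CUDA 13 gate."""
    supported = cc[0] in (9, 10) or (cc == (12, 0) and cuda >= (12, 8))
    return None if supported else "GDN prefill requires SM90, SM100, or SM120"
```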

---

Nitpick comments:
In `@csrc/gdn_prefill_launcher.cu`:
- Around line 60-63: Adjust the indentation and dead-code in the `#else` branch:
align the two lines containing FLASHINFER_ERROR and the following statement to
match the surrounding lambda body indentation (use the same 6-space indent as
the surrounding block) and remove the redundant "return false;" after
FLASHINFER_ERROR (since FLASHINFER_ERROR typically throws and the return is
unreachable); locate the branches by searching for the preprocessor "#else"
blocks surrounding the FLASHINFER_ERROR calls in the lambda in
gdn_prefill_launcher.cu.
- Around line 64-78: The dispatch in gdn_prefill_launcher currently checks only
device_major == 12 and will route SM12.1 devices to the SM120 kernel; add an
explicit minor-version check (e.g., ensure device_minor == 0) alongside the
existing device_major == 12 guard or add a clear comment documenting the
intentional major-only dispatch policy to prevent accidental routing of SM121 to
flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm120,...>; update the
branch that currently reads "} else if (device_major == 12) {" to either require
both device_major == 12 && device_minor == 0 or to document why minors are
allowed so future callers of gdn_prefill_launcher know the behavior.
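The explicit major/minor guard the dispatch comment asks for could look like this; a Python sketch of the control flow only, with hypothetical names standing in for the C++ launcher.

```python
def select_kernel_arch(device_major, device_minor):
    """Route to a kernel arch, refusing SM12.1 instead of silently
    sending it down the SM120 path."""
    if device_major == 9:
        return "sm90"
    if device_major == 12 and device_minor == 0:
        return "sm120"
    raise ValueError(
        "no delta-rule prefill kernel for sm_%d%d" % (device_major, device_minor)
    )
```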

In `@flashinfer/jit/gdn.py`:
- Around line 39-46: Replace the assert + if/elif selection for arch with a
mapping lookup to remove branching and satisfy Ruff: create a dict mapping arch
strings ("sm90", "sm120") to their flag lists (e.g., {"sm90": sm90a_nvcc_flags +
["-DFLAT_SM90A_ENABLED"], "sm120": sm120a_nvcc_flags +
["-DFLAT_SM120A_ENABLED"]}), then set arch_specific_flags = mapping[arch]; raise
a clear ValueError (or use mapping.get with a fallback and raise) if arch is
unsupported; reference the existing variables arch, sm90a_nvcc_flags,
sm120a_nvcc_flags and arch_specific_flags when making the change.
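A sketch of the suggested mapping-based selection; the flag values here are placeholders, not the real contents of `sm90a_nvcc_flags` / `sm120a_nvcc_flags`.

```python
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"]     # placeholder values
sm120a_nvcc_flags = ["-gencode=arch=compute_120a,code=sm_120a"]  # placeholder values

_ARCH_FLAGS = {
    "sm90": sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"],
    "sm120": sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"],
}

def arch_specific_flags(arch):
    """Mapping lookup instead of assert + if/elif; unknown archs raise clearly."""
    try:
        return _ARCH_FLAGS[arch]
    except KeyError:
        raise ValueError("unsupported arch for GDN prefill: %r" % arch) from None
```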

In `@include/flashinfer/flat/hopper/collective/flat_common.hpp`:
- Around line 167-170: The two functions restage_smem_layout and
unstage_smem_layout are identical; avoid unnecessary deprecation churn—either
finish migrating all call sites to restage_smem_layout in this PR so the old
symbol can be removed, or revert the deprecation and make restage_smem_layout a
simple alias of unstage_smem_layout (or vice versa) until a dedicated cleanup PR
can remove the duplicate; update callers consistently and remove the deprecation
attribute if you choose the alias approach so no build/runtime break occurs.

In `@include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh`:
- Around line 137-151: Replace the three generic exceptions so they include the
CUTLASS status value/string from the local variable status when any of
op.can_implement(arguments), op.initialize(arguments, workspace_buffer, stream)
or op.run(stream) return non-success; update the thrown std::runtime_error
messages in the blocks referencing status and the call that failed
(op.can_implement, op.initialize, op.run) to include the status code or a
human-readable status string to aid triage.
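The status-propagating error path can be sketched as follows; Python stands in for the C++ throw sites, and the status strings are illustrative rather than actual CUTLASS enumerator names.

```python
def check_status(status, call_name):
    """Raise with both the failing call and its status value, instead of a
    generic message, so failures in can_implement/initialize/run are triageable."""
    if status != "kSuccess":
        raise RuntimeError("%s failed with CUTLASS status %s" % (call_name, status))
```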

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 88-94: The three option flags NeedsAlpha, NeedsBeta, and
NeedsDecay are declared as int but are used as booleans in std::conditional_t
and if constexpr; change their declarations to constexpr bool (e.g., static
constexpr bool NeedsAlpha = find_option_t<Tag::kNeedsAlpha, cute::true_type,
Options>::value;) for NeedsAlpha, NeedsBeta, and NeedsDecay, keeping the
existing default find_option_t expressions and the static_assert(!NeedsDecay,
...) intact so all boolean checks and type-dependent conditionals behave
correctly.
- Around line 415-433: Remove the redundant "true &&" at the start of the return
expression in can_implement and make the head-ratio logic explicit: compute
ratio only when one head count divides the other and require
(problem_size.num_q_heads % problem_size.num_v_heads == 0) ||
(problem_size.num_v_heads % problem_size.num_q_heads == 0) before using the
integer ratio; keep the existing is_gqa_like and is_gva_like checks but only
evaluate them after that divisibility precondition to avoid relying on implicit
integer truncation in the ratio calculation.
- Around line 146-147: Remove the stray empty statement after the type alias: in
the declaration of SmemLayoutV_SD (using SmemLayoutV_SD =
decltype(restage_smem_layout(SmemLayoutK_SD{}, Int<StagesV::value>{}));) delete
the extra trailing semicolon so the line ends with a single semicolon; no other
changes to restage_smem_layout, SmemLayoutK_SD, Int or StagesV usages are
required.
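The explicit divisibility precondition from the `can_implement` item above can be sketched like this (Python for illustration; the real code reads `problem_size` fields):

```python
def head_ratio(num_q_heads, num_v_heads):
    """Return the GQA/GVA replication ratio, or None when neither head count
    divides the other -- rejecting instead of relying on integer truncation."""
    if num_q_heads % num_v_heads == 0:
        return num_q_heads // num_v_heads   # GQA-like: Q heads are a multiple of V heads
    if num_v_heads % num_q_heads == 0:
        return num_v_heads // num_q_heads   # GVA-like: V heads are a multiple of Q heads
    return None
```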

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 82-85: The alias MathWarpGroupOrderBarrier (and its accompanying
FIXME comment) is unused — locate the typedef using MathWarpGroupOrderBarrier in
this file and either remove the typedef and the FIXME line or replace the FIXME
with a TODO that references a new issue number; ensure you do not change usage
sites (operator() uses OrderedMathBarriers), so after removal build should still
compile and no symbols named MathWarpGroupOrderBarrier should remain; if you
choose to keep a placeholder, add a clear TODO with an issue ID instead of the
FIXME.
- Around line 33-56: The helpers round_down and get_register_requirements are
defined at namespace scope with generic names; move them out of the global
namespace by either making them static constexpr private (or protected) member
functions inside FlatKernelTmaWarpSpecializedDeltaRule or by placing them into a
dedicated detail/internal namespace (e.g., flat::kernel::detail) to avoid
ODR/name collisions; update all references inside
FlatKernelTmaWarpSpecializedDeltaRule to call the new members or
detail-qualified names and ensure their signatures remain unchanged
(round_down<T1,T2> and get_register_requirements(uint32_t,uint32_t,uint32_t)).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 065d402b-afa1-4498-b401-f39437844f53

📥 Commits

Reviewing files that changed from the base of the PR and between 0e18a1c and ba18e9f.

📒 Files selected for processing (14)
  • csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc
  • csrc/gdn_prefill_launcher.cu
  • csrc/gdn_prefill_sm120_kernel_inst.jinja
  • csrc/prefill_kernel_delta_rule_sm120.cu
  • flashinfer/gdn_prefill.py
  • flashinfer/jit/gdn.py
  • include/flashinfer/flat/hopper/collective/flat_common.hpp
  • include/flashinfer/flat/prefill/prefill_kernel.hpp
  • include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh
  • include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
  • include/flashinfer/flat/sm120/device/device_universal.hpp
  • include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp
  • include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
  • tests/gdn/test_prefill_delta_rule.py

Comment thread csrc/gdn_prefill_sm120_kernel_inst.jinja
Comment thread flashinfer/jit/gdn.py
Comment on lines 33 to 38
"""Generate JIT module for GDN prefill kernel with separate compilation.

This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)
plus the original launcher file. The separate files enable parallel compilation by ninja,
significantly reducing build time on multi-core machines.
"""
Contributor

@coderabbitai coderabbitai Bot Apr 17, 2026

⚠️ Potential issue | 🟡 Minor

Docstring count is stale: 32 combinations, not 16.

The docstring says "32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)", but the loop at Line 69 uses itertools.product([False, True], repeat=5) → 32 combinations, and the in-line comment at Line 60 already says "2 dtypes × 32 boolean combinations". So the total is 64 files (2 × 32), which matches the Line 60 comment.

💡 Proposed fix
-    This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)
-    plus the original launcher file. The separate files enable parallel compilation by ninja,
+    This generates 64 separate kernel instantiation files (2 dtypes x 32 boolean combinations)
+    plus the original launcher file. The separate files enable parallel compilation by ninja,
     significantly reducing build time on multi-core machines.
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 35-35: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gdn.py` around lines 33 - 38, Update the module docstring in
flashinfer/jit/gdn.py to reflect the correct count: change the phrase "32
separate kernel instantiation files (2 dtypes × 16 boolean combinations)" to "64
separate kernel instantiation files (2 dtypes × 32 boolean combinations)" (and
keep the note about the original launcher file) to match the
itertools.product([False, True], repeat=5) loop used to generate kernel variants
and the inline comment near the loop.

Collaborator Author

This is copied from the SM90 version; we should open a separate PR to fix them all at once.

Contributor

@guangyunh-nv, understood — deferring the docstring fix to a follow-up PR that cleans it up across both the SM90 and SM120 versions makes sense. Would you like me to open a GitHub issue to track this fix so it doesn't get lost?



Comment on lines 95 to 99
template <class Layout, class Stages = _1>
[[deprecated("use restage_smem_layout instead")]]
CUTE_DEVICE constexpr auto unstage_smem_layout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, make_layout(stages)));
}
Contributor

⚠️ Potential issue | 🟠 Major

Deprecation will fire on existing in-tree callers; migrate them or delegate.

unstage_smem_layout is marked [[deprecated]], but include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (lines 157, 158, 160) still invokes it three times to build SmemLayoutQ_SD, SmemLayoutK_DS, and SmemLayoutV_DS. Depending on warning configuration (-Werror=deprecated-declarations), this will either spam build warnings or break the build. Either migrate those call sites to restage_smem_layout in this PR, or have the deprecated function delegate to the new one to keep behavior in lockstep.

♻️ Suggested delegation to avoid body duplication
 template <class Layout, class Stages = _1>
 [[deprecated("use restage_smem_layout instead")]]
 CUTE_DEVICE constexpr auto unstage_smem_layout(Layout const& layout, Stages stages = {}) {
-  return composition(layout, make_tuple(_, _, make_layout(stages)));
+  return restage_smem_layout(layout, stages);
 }

Note this requires moving restage_smem_layout above unstage_smem_layout (or adding a forward declaration).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/flat/hopper/collective/flat_common.hpp` around lines 95 -
99, The deprecated function unstage_smem_layout is still called from
flat_collective_tma_warpspecialized_delta_rule (for SmemLayoutQ_SD,
SmemLayoutK_DS, SmemLayoutV_DS) which will trigger deprecation warnings/errors;
fix by either updating those call sites to use restage_smem_layout or make
unstage_smem_layout delegate to restage_smem_layout so callers remain valid. If
delegating, ensure restage_smem_layout is declared/defined before
unstage_smem_layout (or add a forward declaration) and implement
unstage_smem_layout to simply call and return restage_smem_layout(layout,
stages) to keep behavior identical.

Comment on lines +108 to +110

alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
};
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check for any use of load_warp_barrier across the repo.
rg -nP -C2 '\bload_warp_barrier\b'

Repository: flashinfer-ai/flashinfer

Length of output: 1217


🏁 Script executed:

#!/bin/bash
# Get the full context of the SharedStorage struct and operator() in the SM120 file
cat -n include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | sed -n '80,150p'

Repository: flashinfer-ai/flashinfer

Length of output: 3081


🏁 Script executed:

#!/bin/bash
# Check if load_warp_barrier appears in any operator() or method implementations
rg -A 50 'operator\(\)' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 2069


🏁 Script executed:

#!/bin/bash
# Check if there are any references to storage through collective operations or other indirect access
rg 'storage\.' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1026


🏁 Script executed:

#!/bin/bash
# Search for any usage patterns involving load_warp_barrier
# Check for constructor initialization, or other patterns
rg '(load_warp_barrier|&.*load_warp|::arrive|->arrive)' include/flashinfer/flat/sm120/kernel/

Repository: flashinfer-ai/flashinfer

Length of output: 216


🏁 Script executed:

#!/bin/bash
# Check if there are other similar files (Hopper version) that might provide context
wc -l include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
# Show lines around load_warp_barrier in Hopper version too
sed -n '110,130p' include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp

Repository: flashinfer-ai/flashinfer

Length of output: 918


🏁 Script executed:

#!/bin/bash
# Check if ClusterBarrier is used elsewhere in the SM120 file (besides declaration)
rg 'ClusterBarrier|arrive_and_wait' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp

Repository: flashinfer-ai/flashinfer

Length of output: 132


🏁 Script executed:

#!/bin/bash
# Check the Hopper file's operator() to see if load_warp_barrier is used there
rg -A 150 'operator\(\)' include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp | grep -E '(load_warp|storage\.)' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 794


Remove load_warp_barrier from SharedStorage — it's declared but never initialized or used.

The member cutlass::arch::ClusterBarrier load_warp_barrier (line 109) appears nowhere else in the codebase. All other storage members in SharedStorage are initialized and consumed in operator(), but this one is dead code. It wastes 16 bytes of aligned shared memory. The same unused member exists in the Hopper version of the file as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 108 - 110, Remove the unused ClusterBarrier field from the
SharedStorage struct: delete the declaration of cutlass::arch::ClusterBarrier
load_warp_barrier so SharedStorage no longer reserves the unused 16 bytes;
verify operator() and related code references (none expected) compile cleanly
after removal and mirror the same change in the Hopper variant of the file to
keep both versions consistent.

@johnnynunez
Contributor

I think that they have the same SMEM.
Can you confirm it again? @eugr

i am with other things now

@guangyunh-nv
Collaborator Author

@yzh119 @jiahanc This is basically a downgrade from SM90. SM90 implements a 3-stage pipeline: the LD/ST warps in one WG, state-unrelated computation in a second WG, and state-related computation in a third group of WGs. For SM120, we merged the computation WGs into one group (2 WGs) due to register pressure from the HMMA operands. Otherwise, the two are largely the same.

NOTE: this kernel can (and should) be further downgraded from TMA to LDGSTS with no warp specialization; then it could run on SM80 and SM89 (Ampere and Ada) devices. This is left as homework for the community :)

Collaborator

@kaixih kaixih left a comment

Maybe this is out of scope for this PR, but I'll share it here. The current sm90 prefill path looks like:

flowchart LR
    q -->|l2norm| qfi
    k -->|l2norm| kfi
    g -->|"exp + .float32"| alpha
    beta -->|".float32"| beta_fi
    indices -->|"where(≥0, idx, last_slot)\n.int64()"| idx64
    idx64 --> gather["gather + .float32"]
    pool["ssm_states\n(pool)"] --> gather
    gather --> init
    qfi & kfi & v & alpha & beta_fi & init --> K["chunk_gated_delta_rule"]
    K --> out_state -->|"index_copy_(idx64)"| pool2["ssm_states\n(pool)"]

Ideally, we could at least fuse the pool indexing and exp(g) parts.
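The unfused gather/scatter around the kernel can be sketched in plain Python (the kernel call is stubbed out and all names are illustrative; the real path uses torch tensors and `index_copy_`):

```python
def prefill_state_roundtrip(pool, indices, new_states):
    """Gather initial states from the pool, run the (stubbed) kernel, and
    scatter the resulting states back -- the unfused steps in the flowchart."""
    last_slot = len(pool) - 1
    idx = [i if i >= 0 else last_slot for i in indices]   # where(idx >= 0, idx, last_slot)
    init = [pool[i] for i in idx]                         # gather + .float32
    # ... chunk_gated_delta_rule(q_norm, k_norm, v, exp(g), beta, init) would run here ...
    for i, s in zip(idx, new_states):                     # out_state index_copy_(idx)
        pool[i] = s
    return init
```

Fusing the slot resolution and `exp(g)` into the kernel would remove two of these elementwise passes over the pool.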

Comment thread flashinfer/gdn_prefill.py
if is_sm90a_supported(device):
module = gen_gdn_prefill_sm90_module().build_and_load()
elif is_sm120a_supported(device):
module = gen_gdn_prefill_sm120_module().build_and_load()
Collaborator

In the test, we mention that SM120 GDN prefill requires CUDA 13+, but here we only check is_sm120a_supported, which checks 12.8 under the hood. These seem to contradict each other. What is the minimum recommended CUDA version for the SM120 prefill kernels?

Collaborator Author

I am actually not quite sure what the recommended version is. According to https://docs.nvidia.com/cuda/archive/12.8.1/blackwell-compatibility-guide/index.html

CUDA applications built using CUDA Toolkit 12.8 are compatible with Blackwell architecture as long as they are built to include kernels in native cubin (compute capability 10.0) or PTX form or both.

So I think we are safe to relax the requirement to 12.8.

Collaborator Author

Relaxed to 12.8 in test

Comment thread flashinfer/gdn_prefill.py
@eugr

eugr commented Apr 20, 2026

> I think that they have the same SMEM. Can you confirm it again? @eugr

Yes, SM120/121 have the same 100 KB of SMEM (99 KB usable per block, I think?).

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (1)

109-110: ⚠️ Potential issue | 🟡 Minor

Unused load_warp_barrier wastes 16 B of aligned smem.

cutlass::arch::ClusterBarrier load_warp_barrier is declared here but never referenced in operator() or anywhere else in the codebase. Safe to delete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 109 - 110, The struct contains an unused member alignas(16)
cutlass::arch::ClusterBarrier load_warp_barrier which wastes 16 bytes of shared
memory; remove the declaration of load_warp_barrier from the struct (the member
declared alongside other smem members) and rebuild to confirm no references
remain (search for load_warp_barrier and operator() to ensure it isn't used).
include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (1)

1158-1197: ⚠️ Potential issue | 🟡 Minor

BlkSeqKV shadow + ckpt_blk_interval == 0 div-by-zero still present.

Pre-existing issue carried over from the SM90 path (and acknowledged in the prior discussion as slated for a separate PR): in the SM90 variant, the local BlkSeqKV at Line 1158 is actually get<0>(TileShape{}), and the % ckpt_blk_interval uses at Lines 1181 and 1192 have no guard against a zero divisor when kEnableCheckpointing is true but checkpoint_every_n_tokens < BlkSeqKV (or is simply not a multiple of it). The SM120 version here uses the correct class-level BlkSeqKV, but the divisor guard is still missing. Worth tracking alongside the SM90 cleanup so this kernel doesn't ship with latent UB if a caller ever passes a misconfigured checkpoint stride.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 1158 - 1197, The code computes ckpt_blk_interval from
checkpoint_every_n_tokens / BlkSeqKV and then uses it as a divisor in modulo
checks (in the loop calling compute_loop_body and kv_checkpoint_store), which
can be zero and cause UB; change the logic around ckpt_blk_interval (and its use
sites) so that when kEnableCheckpointing is true you either (a) compute
ckpt_blk_interval with a safe floor of 1 (e.g., if checkpoint_every_n_tokens <
BlkSeqKV then set ckpt_blk_interval = 1) or (b) guard every use with an explicit
check that ckpt_blk_interval > 0 before performing (blk+1) % ckpt_blk_interval
or num_blocks % ckpt_blk_interval; update the initialization near BlkSeqKV and
the checkpointing branches that call kv_checkpoint_store to use the safe
non-zero value or the guard so no modulo by zero can occur.
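The validation being asked for can be sketched host-side as follows (pure-Python illustration of the guard, with an illustrative function name; the real check would live in can_implement or the Python wrapper):

```python
def ckpt_block_interval(checkpoint_every_n_tokens, blk_seq_kv):
    """Reject checkpoint strides that are not positive multiples of the
    KV block size, so the kernel never computes `% 0` and checkpoints
    land exactly where the caller allocated them."""
    if checkpoint_every_n_tokens <= 0:
        raise ValueError("checkpoint_every_n_tokens must be positive")
    if checkpoint_every_n_tokens % blk_seq_kv != 0:
        raise ValueError("checkpoint stride must be a multiple of BlkSeqKV")
    return checkpoint_every_n_tokens // blk_seq_kv  # always >= 1
```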
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 394-402: The StoreO branch (when ldst_warp_role ==
LdStWarpRole::StoreO) currently calls scheduler.get_next_work(...) once and
directly invokes CollectiveMainloop::store, which diverges from sibling roles
and fails to iterate or guard on work_desc.is_valid(...). Change this to mirror
the other roles: obtain the first work_desc via scheduler.get_next_work(...),
then drain the scheduler with a for loop like for (; work_desc.is_valid(...);
work_desc = scheduler.get_next_work(...)) and call
collective_mainloop.store(...) inside that loop (ensuring any o_pipeline
producer/consumer handshakes remain inside the loop) so StoreO processes every
valid task and skips invalid descriptors.

---

Duplicate comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 1158-1197: The code computes ckpt_blk_interval from
checkpoint_every_n_tokens / BlkSeqKV and then uses it as a divisor in modulo
checks (in the loop calling compute_loop_body and kv_checkpoint_store), which
can be zero and cause UB; change the logic around ckpt_blk_interval (and its use
sites) so that when kEnableCheckpointing is true you either (a) compute
ckpt_blk_interval with a safe floor of 1 (e.g., if checkpoint_every_n_tokens <
BlkSeqKV then set ckpt_blk_interval = 1) or (b) guard every use with an explicit
check that ckpt_blk_interval > 0 before performing (blk+1) % ckpt_blk_interval
or num_blocks % ckpt_blk_interval; update the initialization near BlkSeqKV and
the checkpointing branches that call kv_checkpoint_store to use the safe
non-zero value or the guard so no modulo by zero can occur.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`:
- Around line 109-110: The struct contains an unused member alignas(16)
cutlass::arch::ClusterBarrier load_warp_barrier which wastes 16 bytes of shared
memory; remove the declaration of load_warp_barrier from the struct (the member
declared alongside other smem members) and rebuild to confirm no references
remain (search for load_warp_barrier and operator() to ensure it isn't used).


📥 Commits

Reviewing files that changed from the base of the PR and between ba18e9f and f809316.

📒 Files selected for processing (5)
  • csrc/gdn_prefill_sm120_kernel_inst.jinja
  • flashinfer/gdn_prefill.py
  • include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
  • include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
  • tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (2)
  • tests/gdn/test_prefill_delta_rule.py
  • csrc/gdn_prefill_sm120_kernel_inst.jinja
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gdn_prefill.py

Comment on lines +394 to +402
```cpp
} else if (ldst_warp_role == LdStWarpRole::StoreO) {
  auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
  DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
              work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
  auto tile_shape = typename CollectiveMainloop::TileShape{};
  collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
                            params.problem_size, tile_shape, work_desc, o_pipeline,
                            o_smem_pipe_read, storage.tensors.mainloop.smem_o);
}
```
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Confirm the scheduler-loop asymmetry and check the hopper counterpart for reference.
rg -nP -C2 'get_next_work|is_valid' include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
echo "---- hopper equivalent ----"
fd -a 'flat_kernel_tma_warpspecialized_delta_rule.hpp' include/flashinfer/flat/hopper | xargs -I{} sh -c 'echo "=== {} ==="; rg -nP -C2 "get_next_work|is_valid" {}'
```

Repository: flashinfer-ai/flashinfer

Length of output: 6062


StoreO warp does not iterate the scheduler, unlike all sibling roles.

LoadQKV (Lines 352–365), LoadBeta (Lines 366–379), LoadAlpha (Lines 380–393) and the math warp groups (Lines 407–419) all drain the scheduler via for (; work_desc.is_valid(...); work_desc = scheduler.get_next_work(...)). The StoreO branch calls scheduler.get_next_work exactly once and then invokes collective_mainloop.store(...) without a loop.

With the current non-persistent scheduler this is observationally fine (one work descriptor per block), but:

  • Behavior silently diverges the moment any persistent / multi-tile scheduler is wired in — LoadQKV/math produce outputs for every task while StoreO only drains the first one, so later tasks deadlock on o_pipeline.producer_commit (consumer never arrives) and/or publish wrong data.
  • It also means StoreO doesn't even check work_desc.is_valid(...) before consuming the first descriptor, so on an empty grid assignment it would operate on an invalid work_desc.

Mirror the other roles' scheduler loop here (and guard on is_valid).

🛠️ Suggested fix
```diff
       } else if (ldst_warp_role == LdStWarpRole::StoreO) {
-        auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
-        DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
-                    work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
-        auto tile_shape = typename CollectiveMainloop::TileShape{};
-        collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
-                                  params.problem_size, tile_shape, work_desc, o_pipeline,
-                                  o_smem_pipe_read, storage.tensors.mainloop.smem_o);
+        auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size);
+        CUTE_NO_UNROLL
+        for (; work_desc.is_valid(params.scheduler);
+             work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) {
+          DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n",
+                      work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len);
+          auto tile_shape = typename CollectiveMainloop::TileShape{};
+          collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps,
+                                    params.problem_size, tile_shape, work_desc, o_pipeline,
+                                    o_smem_pipe_read, storage.tensors.mainloop.smem_o);
+        }
       }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`
around lines 394 - 402, The StoreO branch (when ldst_warp_role ==
LdStWarpRole::StoreO) currently calls scheduler.get_next_work(...) once and
directly invokes CollectiveMainloop::store, which diverges from sibling roles
and fails to iterate or guard on work_desc.is_valid(...). Change this to mirror
the other roles: obtain the first work_desc via scheduler.get_next_work(...),
then drain the scheduler with a for loop like for (; work_desc.is_valid(...);
work_desc = scheduler.get_next_work(...)) and call
collective_mainloop.store(...) inside that loop (ensuring any o_pipeline
producer/consumer handshakes remain inside the loop) so StoreO processes every
valid task and skips invalid descriptors.

@guangyunh-nv
Collaborator Author

@kaixih As for exp(g), I'd like to address it together with the Hopper version. Otherwise, we would need to handle it differently in the Python wrapper before calling into the kernel launch wrapper.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (1)

416-432: ⚠️ Potential issue | 🟠 Major

Reject invalid checkpoint intervals before device modulus.

checkpoint_every_n_tokens is documented as a multiple of BlkSeqKV, but can_implement does not enforce it. Values like 1..63 make ckpt_blk_interval == 0, and Lines 1180/1191 can execute % ckpt_blk_interval; non-multiples can also make the kernel write checkpoints at a different cadence than the caller likely allocated.

Proposed fix
```diff
   static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
     auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                      ? problem_size.num_q_heads / problem_size.num_v_heads
                      : problem_size.num_v_heads / problem_size.num_q_heads;

     constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;
+    bool checkpoint_interval_valid = true;
+    if constexpr (kEnableCheckpointing) {
+      checkpoint_interval_valid = args.checkpoint_every_n_tokens > 0 &&
+                                  args.checkpoint_every_n_tokens % int(BlkSeqKV) == 0;
+    }

     bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
                        (problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
                        (problem_size.num_q_heads == ratio * problem_size.num_v_heads);
@@
     return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
            (problem_size.head_size <= get<2>(TileShape{})) &&
-           ((problem_size.head_size % Alignment) == 0);
+           ((problem_size.head_size % Alignment) == 0) && checkpoint_interval_valid;
   }
@@
       compute_loop_body(blk, /*is_first_block_=*/cute::false_type{},
                         /*is_final_block_=*/cute::false_type{});
       if constexpr (kEnableCheckpointing) {
-        if ((blk + 1) % ckpt_blk_interval == 0) {
+        if (ckpt_blk_interval > 0 && (blk + 1) % ckpt_blk_interval == 0) {
           kv_checkpoint_store(ckpt_count++);
         }
       }
@@
       // is always available via output_state from kv_store() below.
       if constexpr (kEnableCheckpointing) {
-        if (num_blocks % ckpt_blk_interval == 0) {
+        if (ckpt_blk_interval > 0 && num_blocks % ckpt_blk_interval == 0) {
           kv_checkpoint_store(ckpt_count);
         }
       }
```

Also applies to: 1157-1193

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 416 - 432, can_implement must validate that
checkpoint_every_n_tokens is a positive multiple of BlkSeqKV before any
device-modulus or kernel-related checks; add an early guard in can_implement
that reads the checkpoint_every_n_tokens from Arguments (or the appropriate arg
field), computes ckpt_blk_interval = checkpoint_every_n_tokens / BlkSeqKV and
returns false if checkpoint_every_n_tokens == 0 or (checkpoint_every_n_tokens %
BlkSeqKV) != 0 (or ckpt_blk_interval == 0), so the function rejects invalid
intervals before performing the rest of the head/Alignment/device checks (refer
to can_implement, Arguments, checkpoint_every_n_tokens, BlkSeqKV, and
ckpt_blk_interval).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 924-980: The three state-transfer lambdas kv_load, kv_store, and
kv_checkpoint_store currently construct tiled_copy_kv with
Copy_Atom<AutoVectorizingCopy, Element>{}, but Element may be fp16/bf16 while
the source/destination global buffers (ptr_input_state, ptr_output_state,
ptr_state_checkpoints) and accumulator tKVrKV are float; change the copy atom to
use float (i.e. Copy_Atom<AutoVectorizingCopy, float>{}) in each occurrence so
the memory-to-accumulator transfers use the float copy path and avoid
truncation; apply the identical change in the hopper variant’s corresponding
lambdas as well.

---

Duplicate comments:
In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`:
- Around line 416-432: can_implement must validate that
checkpoint_every_n_tokens is a positive multiple of BlkSeqKV before any
device-modulus or kernel-related checks; add an early guard in can_implement
that reads the checkpoint_every_n_tokens from Arguments (or the appropriate arg
field), computes ckpt_blk_interval = checkpoint_every_n_tokens / BlkSeqKV and
returns false if checkpoint_every_n_tokens == 0 or (checkpoint_every_n_tokens %
BlkSeqKV) != 0 (or ckpt_blk_interval == 0), so the function rejects invalid
intervals before performing the rest of the head/Alignment/device checks (refer
to can_implement, Arguments, checkpoint_every_n_tokens, BlkSeqKV, and
ckpt_blk_interval).


📥 Commits

Reviewing files that changed from the base of the PR and between f809316 and b17d688.

📒 Files selected for processing (1)
  • include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp

Comment on lines +924 to +980
```cpp
auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA {
  DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx,
              v_head_idx);
  int num_state_heads = problem_size.num_sab_heads;
  int state_head_idx = work_desc.o_head_idx();
  auto gKV = make_tensor(make_gmem_ptr(params.ptr_input_state),
                         make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{},
                                                num_state_heads, problem_size.num_seqs)))(
      _, _, state_head_idx, seq_idx);  // (KDim, VDim), K-contiguous

  auto tiled_copy_kv =
      make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
  auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

  auto tKVgKV = thr_copy_kv.partition_S(select_tensor<1, 0>(gKV));
  copy(tiled_copy_kv, tKVgKV, tKVrKV);
};

auto kv_store = [&]() INLINE_LAMBDA {  // tKVrKV is carried over whole mainloop
  DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx,
              v_head_idx);
  int num_state_heads = problem_size.num_sab_heads;
  int state_head_idx = work_desc.o_head_idx();  // num_o_heads == num_sab_heads
  auto gKV = make_tensor(make_gmem_ptr(params.ptr_output_state),
                         make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{},
                                                num_state_heads, problem_size.num_seqs)))(
      _, _, state_head_idx, seq_idx);  // (KDim, VDim), K-contiguous

  auto tiled_copy_kv =
      make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
  auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

  auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV));
  copy(tiled_copy_kv, tKVrKV, tKVgKV);
};

auto kv_checkpoint_store = [&](int checkpoint_idx) INLINE_LAMBDA {
  if constexpr (kEnableCheckpointing) {
    DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> checkpoint[%d]\n", seq_idx, q_head_idx,
                k_head_idx, v_head_idx, checkpoint_idx);
    int num_state_heads = problem_size.num_sab_heads;
    int state_head_idx = work_desc.o_head_idx();
    int64_t ckpt_offset = params.checkpoint_cu_starts[seq_idx] + checkpoint_idx;

    // Layout: [total_checkpoints, num_sab_heads, HeadSizeQK, HeadSizeV] LayoutLeft
    auto gKV =
        make_tensor(make_gmem_ptr(params.ptr_state_checkpoints +
                                  ckpt_offset * num_state_heads * HeadSizeQK * HeadSizeV +
                                  state_head_idx * HeadSizeQK * HeadSizeV),
                    make_layout(make_shape(Int<HeadSizeQK>{}, Int<HeadSizeV>{})));

    auto tiled_copy_kv =
        make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
    auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx);

    auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV));
    copy(tiled_copy_kv, tKVrKV, tKVgKV)
```
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Description: Compare state pointer types and state copy atom types in delta-rule code.
# Expected: copies to/from float state tensors should use float/accumulator copy typing, not Element.

rg -nP -C4 'ptr_(input|output)_state|ptr_state_checkpoints|Copy_Atom<AutoVectorizingCopy,\s*(Element|float|ElementAccumulatorKV)' -g '*delta_rule*'
```

Repository: flashinfer-ai/flashinfer

Length of output: 30808


(Additional collapsed verification scripts omitted; they inspect the Element/ElementAccumulatorKV typing, the tKVrKV accumulator fragment, the Hopper variant, and the kernel instantiations before reaching the conclusion below.)


Use float copy atom for float state tensors.

The state transfer operations (kv_load, kv_store, kv_checkpoint_store) copy between float global memory pointers (ptr_input_state, ptr_output_state, ptr_state_checkpoints) and the float accumulator fragment tKVrKV. Using Copy_Atom<AutoVectorizingCopy, Element>{} is incorrect because Element can be fp16/bf16 at compile time, which would select an incompatible copy type and cause data truncation when transferring float data.

Update all three occurrences (lines 935, 953, 976) to use float:

Proposed fix
```diff
       auto tiled_copy_kv =
-          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
@@
       auto tiled_copy_kv =
-          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+          make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
@@
        auto tiled_copy_kv =
-            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, Element>{}, kv_tiled_mma);
+            make_tiled_copy_C(Copy_Atom<AutoVectorizingCopy, float>{}, kv_tiled_mma);
```

The same issue exists in the hopper variant at lines 800, 818, 841 of include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`
around lines 924 - 980, The three state-transfer lambdas kv_load, kv_store, and
kv_checkpoint_store currently construct tiled_copy_kv with
Copy_Atom<AutoVectorizingCopy, Element>{}, but Element may be fp16/bf16 while
the source/destination global buffers (ptr_input_state, ptr_output_state,
ptr_state_checkpoints) and accumulator tKVrKV are float; change the copy atom to
use float (i.e. Copy_Atom<AutoVectorizingCopy, float>{}) in each occurrence so
the memory-to-accumulator transfers use the float copy path and avoid
truncation; apply the identical change in the hopper variant’s corresponding
lambdas as well.

```cpp
bool needs_beta = beta != nullptr;
bool needs_alpha = alpha != nullptr;
bool init_state = input_state != nullptr;
bool enable_ckpt = checkpoint_every_n_tokens > 0;
```

As far as I can tell, this is only supported by the SM120 kernel, not the SM90 kernel; is that correct?

Collaborator Author


They are all supported.
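For reference, the optional-input flags in the snippet above can be sketched host-side as a simple null-check mapping (pure-Python illustration; the function name is hypothetical and mirrors the C++ booleans):

```python
def derive_kernel_flags(beta=None, alpha=None, input_state=None,
                        checkpoint_every_n_tokens=0):
    """A flag is set iff the corresponding optional pointer/argument
    is provided, matching the nullptr checks in the C++ snippet."""
    return {
        "needs_beta": beta is not None,
        "needs_alpha": alpha is not None,
        "init_state": input_state is not None,
        "enable_ckpt": checkpoint_every_n_tokens > 0,
    }
```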



6 participants