[feat] add log gate and initial state pool support in blackwell gdn prefill#3167

Open
Observer007 wants to merge 5 commits into flashinfer-ai:main from Observer007:feat/gdn_log_gate

Conversation

Contributor

@Observer007 commented Apr 24, 2026

📌 Description

Add optional log-gate and initial-state-pool support. Both features were requested by the sglang team.

No performance regression was observed after adding the new features.
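The semantics of the two options can be sketched with a naive per-token CPU reference (an illustrative helper, not the flashinfer API): with is_log_gate the gate tensor holds ln(alpha) instead of alpha, and a negative pool index means "start from a zero state", mirroring the kernel behavior described below.

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, g, beta, s_pool=None, s_idx=-1, is_log_gate=False):
    """Naive reference for one sequence. q, k: [T, K]; v: [T, V]; g, beta: [T].
    s_pool: optional [P, V, K] pool of initial states; s_idx < 0 starts from
    zeros (mirrors the kernel's negative-index rule)."""
    T, K = k.shape
    V = v.shape[1]
    S = (s_pool[s_idx].astype(np.float64).copy()
         if s_pool is not None and s_idx >= 0
         else np.zeros((V, K)))
    out = np.empty((T, V))
    for t in range(T):
        alpha = np.exp(g[t]) if is_log_gate else g[t]   # log gate: g = ln(alpha)
        S = alpha * S                                    # decay the running state
        S += beta[t] * np.outer(v[t] - S @ k[t], k[t])   # delta-rule update
        out[t] = S @ q[t]
    return out, S
```

Passing g = ln(alpha) with is_log_gate=True is numerically equivalent to passing alpha directly, which is also how the tests convert the kernel alpha to log space.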

GPU: NVIDIA B200 [Blackwell (SM100)]
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128

Config           Seqlens           h_qk  h_v    FI Blackwell (SM100)   TFLOPS  FLA/Triton   Speedup
---------------------------------------------------------------------------------------------------
397B/122B TP8    1x8192               2    8                  0.330ms    13.0      0.388ms     1.17x +
397B/122B TP8    1x4096               2    8                  0.174ms    12.3      0.275ms     1.58x +
397B/122B TP8    1x2048               2    8                  0.096ms    11.1      0.225ms     2.33x +
397B/122B TP8    6144+2048            2    8                  0.252ms    17.0      0.334ms     1.32x +
397B/122B TP8    4096+4096            2    8                  0.175ms    24.6      0.291ms     1.66x +
397B/122B TP8    2048+6144            2    8                  0.253ms    17.0      0.335ms     1.32x +
397B/122B TP8    1024+7168            2    8                  0.292ms    14.7      0.358ms     1.23x +
397B/122B TP8    2048x4               2    8                  0.098ms    44.0      0.243ms     2.49x +
397B/122B TP8    1024x8               2    8                  0.060ms    71.6      0.247ms     4.12x +

397B/122B TP4    1x8192               4   16                  0.331ms    26.0      0.426ms     1.29x +
397B/122B TP4    1x4096               4   16                  0.175ms    24.5      0.291ms     1.67x +
397B/122B TP4    1x2048               4   16                  0.097ms    22.1      0.285ms     2.94x +
397B/122B TP4    6144+2048            4   16                  0.252ms    34.0      0.375ms     1.48x +
397B/122B TP4    4096+4096            4   16                  0.176ms    48.9      0.329ms     1.87x +
397B/122B TP4    2048+6144            4   16                  0.254ms    33.8      0.374ms     1.47x +
397B/122B TP4    1024+7168            4   16                  0.293ms    29.3      0.396ms     1.35x +
397B/122B TP4    2048x4               4   16                  0.099ms    86.7      0.335ms     3.38x +
397B/122B TP4    1024x8               4   16                  0.064ms   134.9      0.342ms     5.37x +

397B/122B TP2    1x8192               8   32                  0.333ms    51.6      0.588ms     1.77x +
397B/122B TP2    1x4096               8   32                  0.176ms    48.9      0.334ms     1.90x +
397B/122B TP2    1x2048               8   32                  0.098ms    43.9      0.247ms     2.52x +
397B/122B TP2    6144+2048            8   32                  0.255ms    67.5      0.586ms     2.30x +
397B/122B TP2    4096+4096            8   32                  0.178ms    96.7      0.588ms     3.31x +
397B/122B TP2    2048+6144            8   32                  0.256ms    67.1      0.589ms     2.30x +
397B/122B TP2    1024+7168            8   32                  0.295ms    58.2      0.590ms     2.00x +
397B/122B TP2    2048x4               8   32                  0.104ms   165.4      0.595ms     5.73x +
397B/122B TP2    1024x8               8   32                  0.120ms   143.0      0.590ms     4.91x +

397B/122B TP1    1x8192              16   64                  0.335ms   102.4      1.004ms     2.99x +
397B/122B TP1    1x4096              16   64                  0.178ms    96.7      0.528ms     2.97x +
397B/122B TP1    1x2048              16   64                  0.099ms    87.0      0.305ms     3.09x +
397B/122B TP1    6144+2048           16   64                  0.260ms   132.3      0.999ms     3.85x +
397B/122B TP1    4096+4096           16   64                  0.184ms   187.0      1.003ms     5.46x +
397B/122B TP1    2048+6144           16   64                  0.262ms   131.3      1.005ms     3.84x +
397B/122B TP1    1024+7168           16   64                  0.300ms   114.4      1.005ms     3.35x +
397B/122B TP1    2048x4              16   64                  0.200ms   172.0      1.017ms     5.09x +
397B/122B TP1    1024x8              16   64                  0.232ms   148.2      1.009ms     4.35x +

35B/9B/4B TP1    1x8192              16   32                  0.334ms    51.5      0.586ms     1.75x +
35B/9B/4B TP1    1x4096              16   32                  0.177ms    48.6      0.333ms     1.88x +
35B/9B/4B TP1    1x2048              16   32                  0.098ms    43.7      0.241ms     2.45x +
35B/9B/4B TP1    6144+2048           16   32                  0.255ms    67.3      0.586ms     2.30x +
35B/9B/4B TP1    4096+4096           16   32                  0.179ms    96.0      0.588ms     3.29x +
35B/9B/4B TP1    2048+6144           16   32                  0.258ms    66.5      0.588ms     2.28x +
35B/9B/4B TP1    1024+7168           16   32                  0.298ms    57.7      0.589ms     1.98x +
35B/9B/4B TP1    2048x4              16   32                  0.104ms   165.4      0.595ms     5.73x +
35B/9B/4B TP1    1024x8              16   32                  0.121ms   142.2      0.589ms     4.88x +

27B TP1          1x8192              16   48                  0.332ms    77.7      0.834ms     2.51x +
27B TP1          1x4096              16   48                  0.176ms    73.2      0.454ms     2.58x +
27B TP1          1x2048              16   48                  0.099ms    65.4      0.283ms     2.87x +
27B TP1          6144+2048           16   48                  0.256ms   100.7      0.771ms     3.01x +
27B TP1          4096+4096           16   48                  0.179ms   143.6      0.828ms     4.62x +
27B TP1          2048+6144           16   48                  0.258ms   100.0      0.830ms     3.22x +
27B TP1          1024+7168           16   48                  0.297ms    86.8      0.830ms     2.80x +
27B TP1          2048x4              16   48                  0.196ms   131.2      0.783ms     3.98x +
27B TP1          1024x8              16   48                  0.178ms   144.9      0.793ms     4.46x +

2B/0.8B TP1      1x8192              16   16                  0.330ms    26.1      0.423ms     1.28x +
2B/0.8B TP1      1x4096              16   16                  0.175ms    24.6      0.288ms     1.65x +
2B/0.8B TP1      1x2048              16   16                  0.097ms    22.1      0.229ms     2.35x +
2B/0.8B TP1      6144+2048           16   16                  0.252ms    34.1      0.375ms     1.49x +
2B/0.8B TP1      4096+4096           16   16                  0.176ms    48.9      0.331ms     1.88x +
2B/0.8B TP1      2048+6144           16   16                  0.253ms    34.0      0.375ms     1.49x +
2B/0.8B TP1      1024+7168           16   16                  0.292ms    29.4      0.399ms     1.37x +
2B/0.8B TP1      2048x4              16   16                  0.099ms    86.7      0.335ms     3.38x +
2B/0.8B TP1      1024x8              16   16                  0.064ms   134.9      0.344ms     5.41x +

Sym h32          1x8192              32   32                  0.332ms    51.7      0.586ms     1.76x +
Sym h32          1x4096              32   32                  0.176ms    48.8      0.331ms     1.88x +
Sym h32          1x2048              32   32                  0.098ms    43.7      0.245ms     2.50x +
Sym h32          6144+2048           32   32                  0.255ms    67.3      0.586ms     2.30x +
Sym h32          4096+4096           32   32                  0.178ms    96.6      0.588ms     3.31x +
Sym h32          2048+6144           32   32                  0.255ms    67.3      0.589ms     2.31x +
Sym h32          1024+7168           32   32                  0.295ms    58.3      0.589ms     2.00x +
Sym h32          2048x4              32   32                  0.104ms   165.0      0.595ms     5.72x +
Sym h32          1024x8              32   32                  0.121ms   141.9      0.590ms     4.88x +

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Pooled initial-state support via per-sequence initial-state indices.
    • Log-space gate mode to interpret gate inputs as natural logs.
  • Improvements

    • Call sites and kernels updated to accept pooled states and log-gate mode; call sites now use explicit keyword arguments.
    • Tighter numerical tolerances in reference checks.
  • Bug Fixes

    • Validation to prevent using indices without provided initial states.
  • Tests

    • Expanded tests for pooled states, log-gate paths, SM100 cases, and error conditions.

Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

Adds pooled initial-state indexing and an optional log-space gate mode to the GDN prefill path, updating SM100 (Blackwell) kernels, wrappers, public API signatures, reference implementation, tests, and a benchmark call site.

Changes

  • Benchmark call — benchmarks/bench_gdn_prefill.py: refactored a chunk_gated_delta_rule invocation to use explicit keyword arguments (e.g., initial_state, initial_state_indices, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel, output_state).
  • Blackwell kernel implementation — flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py: added compile-time flags is_log_gate and is_initial_state_pool; threaded pooled-index tensors (s_in_indices/mS_init_indices) through host/device paths; updated _load_initial_state to handle pooled indices and zero-initialize on negative indices; adjusted cumprod indexing logic.
  • Blackwell wrapper — flashinfer/gdn_kernels/blackwell/gdn_prefill.py: chunk_gated_delta_rule_sm100 now accepts initial_state_indices and is_log_gate; the compilation key and runtime execution include the optional index tensor and log-gate flag; a runtime guard prevents indices without initial_state.
  • Public API wrapper — flashinfer/gdn_prefill.py: chunk_gated_delta_rule signature extended with initial_state_indices and is_log_gate; validates that indices require initial_state; forwards args to the SM100 path (casting indices to int32) and rejects unsupported SM90/Hopper combinations.
  • Reference implementation — tests/gdn/reference_delta_rule.py: blockwise_delta_rule accepts initial_state and initial_state_indices and initializes per-sequence state from pooled entries (using per-seq indices when provided), with proper dtype/device casting and negative-index handling.
  • Tests — tests/gdn/test_prefill_delta_rule.py: added SM100 tests for is_log_gate and pooled initial-state modes; convert kernel alpha to log-space when testing is_log_gate; tightened tolerances; added a negative test for indices without initial_state; updated many test call sites to keyword form.
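The wrapper-side guard summarized above can be sketched in plain Python (the helper name is illustrative; per the summary, the real wrapper rejects indices without initial_state and casts indices to int32 before dispatch — the extra range checks are the validation proposed in review):

```python
import numpy as np

def check_state_pool_args(initial_state, initial_state_indices, num_seqs):
    """Illustrative stand-in for the wrapper-side guard: indices are only
    meaningful together with a state pool, and the kernel expects int32."""
    if initial_state_indices is None:
        return None
    if initial_state is None:
        raise ValueError("initial_state_indices requires initial_state")
    idx = np.asarray(initial_state_indices, dtype=np.int32)
    if idx.ndim != 1 or idx.size != num_seqs:
        raise ValueError(f"expected a 1D index array with {num_seqs} entries")
    if (idx >= initial_state.shape[0]).any():
        raise IndexError("index outside the state pool")
    # negative entries are legal: the kernel zero-initializes those sequences
    return idx
```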

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant API as chunk_gated_delta_rule API
    participant SM100 as SM100 Wrapper
    participant Compiler as Kernel Compiler
    participant Kernel as Device Kernel

    User->>API: Call with tensors + initial_state, initial_state_indices, is_log_gate
    API->>API: Validate initial_state_indices requires initial_state
    alt SM100 path
        API->>SM100: Forward tensors & flags (s_in, s_in_indices, is_log_gate, ...)
        SM100->>Compiler: Build/cache key (is_log_gate, is_initial_state_pool)
        Compiler->>Kernel: Compile kernel with flags
        SM100->>Kernel: Execute with s_in_indices, is_log_gate
        Kernel->>Kernel: Load initial state via mS_init_indices[batch_idx] (zero if negative)
        Kernel->>Kernel: Preprocess gate (treat as ln if is_log_gate)
        Kernel-->>SM100: Return outputs
    else SM90/Hopper path
        API-->>User: Raise NotImplementedError
    end
    SM100-->>API: Result
    API-->>User: Output tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • aleozlx
  • yongwww
  • cyx-6
  • kahyunnam
  • jimmyzho

Poem

🐇 Pooled states in neat little stacks,
gates hum in log-time, no need to relax,
indices hop in—some true, some nil,
kernels wake, stitch memory still,
a rabbit claps: code hopped the hill!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — coverage is 63.64%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (4 passed)
  • Title check ✅ Passed — the title clearly and specifically describes the main feature additions (log gate and initial state pool support in Blackwell GDN prefill), matching the core changes across all modified files.
  • Description check ✅ Passed — the description includes a clear summary of changes, comprehensive benchmark results demonstrating no performance regression, and completed pre-commit checks; however, the 'All tests are passing' checkbox is unchecked, indicating incomplete test validation.
  • Linked Issues check ✅ Passed — skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed — skipped because no linked issues were found for this pull request.



@gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for initial state pooling and natural-log space gates in the Gated Delta Net (GDN) prefill kernel for SM100 (Blackwell). Key changes include updating the kernel logic to handle indexed initial states, adding a flag to skip internal log2 conversion when gates are already in log space, and extending the Python API and reference implementations accordingly. Feedback suggests using cutlass.const_expr for compile-time branch elimination in the kernel and tightening tolerance values in new tests for consistency.
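The "skip internal log2 conversion" point rests on a simple identity: log2(alpha) = ln(alpha) / ln(2), so a gate already given in natural-log space only needs a constant rescale rather than a fresh logarithm. The sketch below checks that identity; whether the kernel actually accumulates gates in base-2 log space is an assumption based on the summary above, and the constant name is illustrative.

```python
import numpy as np

# exp2 is cheap on GPUs, so kernels often keep log-gates in base 2.
# Converting a natural-log gate is then a single multiply by log2(e).
LOG2_E = 1.4426950408889634  # log2(e) = 1 / ln(2)

alpha = 0.9
g_ln = np.log(alpha)                      # caller's gate under is_log_gate
assert np.isclose(g_ln * LOG2_E, np.log2(alpha))   # rescale == direct log2
assert np.isclose(2.0 ** (g_ln * LOG2_E), alpha)   # round-trip via exp2
```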

Comment thread flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
Comment thread tests/gdn/test_prefill_delta_rule.py Outdated
@coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/gdn_prefill.py (2)

317-323: ⚠️ Potential issue | 🟠 Major

Use the log-space neutral element when g is omitted.

When is_log_gate=True, the neutral gate is 0, not 1. The current fallback makes the no-g case behave like multiplying the state by e every token on SM100 instead of leaving it unchanged.

Suggested fix
-        _g = (
-            g
-            if g is not None
-            else torch.ones(
-                total_seq_len, num_sab_heads, dtype=torch.float32, device=device
-            )
-        )
+        default_gate = 0.0 if is_log_gate else 1.0
+        _g = (
+            g
+            if g is not None
+            else torch.full(
+                (total_seq_len, num_sab_heads),
+                default_gate,
+                dtype=torch.float32,
+                device=device,
+            )
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 317 - 323, The fallback for _g
currently uses ones unconditionally, which is wrong when is_log_gate=True
(neutral element should be 0 in log-space); update the assignment to set _g to a
zeros tensor when is_log_gate is True and to ones otherwise, keeping the same
shape and dtype/device (use total_seq_len, num_sab_heads, dtype=torch.float32,
device=device) and preserve the existing conditional that uses g when provided.

337-353: ⚠️ Potential issue | 🟠 Major

Validate pooled state indices before dispatch.

The new SM100 path only casts initial_state_indices to int32 and then the kernel uses those values as raw pool indices. A negative index, an out-of-range value, or a tensor with the wrong length will turn into a bad device read instead of a clean Python error.

Suggested fix
         _initial_state_indices = (
             initial_state_indices.to(torch.int32)
             if initial_state_indices is not None
             else None
         )
+        if _initial_state_indices is not None:
+            if _initial_state_indices.ndim != 1:
+                raise ValueError(
+                    "initial_state_indices must be a 1D tensor of shape [num_seqs]"
+                )
+            if _initial_state_indices.numel() != num_seqs:
+                raise ValueError(
+                    f"initial_state_indices must have {num_seqs} elements, "
+                    f"got {_initial_state_indices.numel()}"
+                )
+            if torch.any(_initial_state_indices < 0):
+                raise ValueError("initial_state_indices must be non-negative")
+            if initial_state is not None and torch.any(
+                _initial_state_indices >= initial_state.size(0)
+            ):
+                raise IndexError(
+                    "initial_state_indices contains a value outside the state pool"
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 337 - 353, The SM100 dispatch
currently casts initial_state_indices to int32 and passes them straight into
chunk_gated_delta_rule_sm100, which can cause bad device reads for negative,
out-of-range, or incorrectly sized tensors; before calling
chunk_gated_delta_rule_sm100 validate that if initial_state_indices is not None
then (1) it's a 1-D integer tensor on the same device as initial_state, (2) all
values are >= 0 and < initial_state.size(0) (use initial_state.size(0) as the
pool length), and (3) its length matches the expected number of pooled entries
(compare initial_state_indices.numel() to the expected index count for
initial_state/initial_state.shape[0] or the pool dimension you use); raise a
clear Python exception if any check fails, then cast to torch.int32 (store in
_initial_state_indices) and call chunk_gated_delta_rule_sm100 as before.
flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py (1)

3026-3029: ⚠️ Potential issue | 🟠 Major

Upcast pooled initial state to the accumulator dtype before writing TMEM.

This branch currently converts the loaded state to self.state_dtype again, but the TMEM state buffer is the fp32 accumulator path. With bf16 initial_state, the first chunk starts from bf16 precision instead of the intended fp32 state.

Suggested fix
             if cutlass.const_expr(self.state_dtype != self.acc_dtype):
                 tRT_tCrState[None, 0, sub].store(
-                    tGR_tCrState[None, 0, sub].load().to(self.state_dtype)
+                    tGR_tCrState[None, 0, sub].load().to(self.acc_dtype)
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py` around lines
3026 - 3029, The branch that stores the pooled initial state into TMEM is
converting the loaded value to self.state_dtype but TMEM expects the fp32
accumulator path; update the conversion to upcast the loaded tGR_tCrState value
to self.acc_dtype before storing into tRT_tCrState (i.e., in the const_expr
branch that touches tRT_tCrState and tGR_tCrState, replace/to change the
.to(self.state_dtype) call to .to(self.acc_dtype)) so the first chunk uses the
accumulator precision instead of bf16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_gdn_prefill.py`:
- Around line 102-115: The benchmark calls chunk_gated_delta_rule with g created
via F.logsigmoid but passes it to the non-log gate path; update the call site in
bench_gdn_prefill.py so the gate representation matches the API: either pass
is_log_gate=True to chunk_gated_delta_rule (and ensure the implementation path
honors that flag) or convert the log-space gate back to linear with g.exp()
before calling chunk_gated_delta_rule; reference the chunk_gated_delta_rule call
and the g variable to make the change.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 51f7265a-8bdd-4ffa-899a-a91e6895b7c6

📥 Commits

Reviewing files that changed from the base of the PR and between 223f2a4 and f9f68fc.

📒 Files selected for processing (6)
  • benchmarks/bench_gdn_prefill.py
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py
  • flashinfer/gdn_prefill.py
  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py

Comment thread benchmarks/bench_gdn_prefill.py
minor update

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@coderabbitai Bot left a comment

🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)

473-483: The validation order is correct; consider adding _skip_if_not_sm100() for consistency.

The ValueError("initial_state_indices requires...") fires at line ~232 in chunk_gated_delta_rule, before any SM100 capability checks, so this test will correctly raise the expected error on all architectures. However, for consistency with other SM100-only feature tests in this file (test_prefill_kernel_log_gate, _test_prefill_kernel_initial_state_pool), adding _skip_if_not_sm100() would make the test scope clearer and align with the coding guideline to skip SM100-only tests on unsupported hardware.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 473 - 483, Add a call to
_skip_if_not_sm100() at the start of
test_initial_state_indices_without_state_error to make the test explicitly
skipped on non-SM100 hardware; the test currently triggers a ValueError from
chunk_gated_delta_rule (the check at ~line 232) before any SM100 checks, but for
consistency with other SM100-only tests (e.g., test_prefill_kernel_log_gate,
_test_prefill_kernel_initial_state_pool) ensure the test begins by invoking
_skip_if_not_sm100() so it is only executed on supported architectures.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 463800d4-9391-42bb-ab94-1ec32845fc49

📥 Commits

Reviewing files that changed from the base of the PR and between f9f68fc and eec96ed.

📒 Files selected for processing (1)
  • tests/gdn/test_prefill_delta_rule.py

@coderabbitai Bot left a comment

🧹 Nitpick comments (1)
flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py (1)

3010-3034: Reuse the cached s_in_index instead of re-loading mS_init_indices[batch_idx] each iteration.

Line 3020 re-reads mS_init_indices[batch_idx] on every sub iteration of the loop even though its value was already captured into s_in_index on line 3011. Each read is a GMEM load (cacheable, but semantically redundant), and the pool-mode branch on self.is_initial_state_pool is also not wrapped in cutlass.const_expr unlike other const-flag checks in this file (e.g., lines 416, 937, 1265).

♻️ Proposed refactor
-        s_in_index = (
-            mS_init_indices[batch_idx] if self.is_initial_state_pool else batch_idx
-        )
+        if cutlass.const_expr(self.is_initial_state_pool):
+            s_in_index = mS_init_indices[batch_idx]
+        else:
+            s_in_index = batch_idx
         gS_init = cute.flat_divide(
             mS_init[None, None, head_idx, s_in_index],
             (self.mma_tiler_kv[0], self.mma_tiler_kv[1]),
         )[None, None, 0, 0]
         tGR_tCgState = thr_state_r2t.partition_S(gS_init)
         kv_acc_handle = kv_acc_producer.acquire_and_advance()
         for sub in cutlass.range(tRT_tCrState.shape[2]):
-            if self.is_initial_state_pool and mS_init_indices[batch_idx] < 0:
+            if cutlass.const_expr(self.is_initial_state_pool) and s_in_index < 0:
                 tRT_tCrState[None, 0, sub].fill(0.0)
             else:

Separately (not a bug, just a caveat): when s_in_index < 0, the mS_init[..., s_in_index] slice on line 3014 still constructs a tensor view with a negative stride offset. No load occurs because the guard on line 3020 dispatches to fill(0.0), but the code would be more robust if the slice itself was skipped or clamped on the negative path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py` around lines
3010 - 3034, Replace dynamic re-reads of mS_init_indices[batch_idx] inside the
loop with the cached s_in_index and make the pool-mode branch a compile-time
constant check; specifically, use s_in_index (captured above) in the if guard
inside the for sub loop instead of mS_init_indices[batch_idx], and wrap that
guard in cutlass.const_expr(self.is_initial_state_pool) to match other
const-flag checks; also avoid creating the mS_init slice with a negative index
by gating or clamping the initial gS_init construction when s_in_index < 0 so
you don't build a negative-stride view (i.e., only perform the mS_init[...]
slice when s_in_index >= 0 or select a safe default before creating gS_init).
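The negative-index caveat has a natural host-side mirror: clamp before gathering so no negative-offset view is ever formed, then overwrite the missing entries with zeros. This is a NumPy sketch of the intent, not the CuTe kernel code.

```python
import numpy as np

def gather_initial_states(pool, indices):
    """pool: [P, H, V, K]; indices: [num_seqs], negative = no stored state."""
    idx = np.asarray(indices)
    safe = np.clip(idx, 0, pool.shape[0] - 1)   # never index with a negative offset
    states = pool[safe].copy()                   # gather -> [num_seqs, H, V, K]
    states[idx < 0] = 0.0                        # zero-init sequences without state
    return states
```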

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 232c6b8f-f144-48d7-9b63-3a0efe2bb6f1

📥 Commits

Reviewing files that changed from the base of the PR and between eec96ed and ceaef85.

📒 Files selected for processing (1)
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py

@coderabbitai Bot left a comment

🧹 Nitpick comments (4)
tests/gdn/test_prefill_delta_rule.py (3)

473-483: Consider explicitly skipping when the underlying arch is unsupported.

The ValueError is raised in the Python wrapper before any kernel dispatch, so the test should pass regardless of GPU arch — but per repo convention other tests in this file use _skip_if_unsupported() / _skip_if_not_sm100(). Adding _skip_if_unsupported() here keeps behavior uniform and avoids surprises if the wrapper later performs any arch-gated validation ordering:

🧪 Proposed skip
 def test_initial_state_indices_without_state_error():
     """initial_state_indices without initial_state should raise ValueError."""
+    _skip_if_unsupported()
     device = torch.device("cuda")

As per coding guidelines: "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 473 - 483, Add a GPU-arch
skip at the start of the test to match file convention: in
test_initial_state_indices_without_state_error call the repository's skip helper
(e.g., _skip_if_unsupported() or _skip_if_not_sm100()) before invoking
chunk_gated_delta_rule so the test is skipped on unsupported architectures; this
keeps behavior consistent with other tests in this file and avoids
platform-specific surprises if validation ordering changes.

393-416: Make the pool-layout convention a hard contract via an assertion.

The inline comments on lines 397–402 explain that the pool is passed as [pool_size, H, V, K] to the kernel (same layout as output_state) and transposed before the reference. Since this is the only place that documents the layout, a cheap shape assertion right before the kernel call would lock the contract and prevent silent layout regressions if output_state's convention ever changes:

🛡️ Proposed shape assert
+    assert state_pool.shape == (
+        pool_size,
+        num_sab_heads,
+        head_size,
+        head_size,
+    ), "state_pool must match kernel [pool_size, H, V, K] layout"
     chunk_gated_delta_rule(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 393 - 416, Add a sanity
assertion to lock the pool-layout contract so regressions don't silently change
the expected layout: before creating state_pool_ref and before calling
blockwise_delta_rule, assert that state_pool has shape [pool_size, H, K, V] (or
equivalently that state_pool.transpose(-1,-2) matches the kernel/output_state
layout [pool_size, H, V, K]) and raise a clear error if not; locate the
variables state_pool, state_pool_ref, and the blockwise_delta_rule call to add
this check so the test fails loudly if the pool layout convention changes.

100-102: Optional: extract the is_log_gate transform to a small helper.

The (alpha + 1e-10).log() if is_log_gate else alpha expression is duplicated in _test_prefill_kernel and _test_prefill_kernel_initial_state_pool, and the 1e-10 epsilon is also hard-coded in the reference (to_logspace_Gamma_and_gamma). A tiny helper (e.g. _alpha_to_kernel_space(alpha, is_log_gate)) would keep the epsilon consistent and reduce copy/paste as more log-gate tests land.

Also applies to: 362-362

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 100 - 102, Extract the
repeated transform "(alpha + 1e-10).log() if is_log_gate else alpha" into a
small helper function (e.g. _alpha_to_kernel_space(alpha, is_log_gate)) and use
it in _test_prefill_kernel and _test_prefill_kernel_initial_state_pool; also
centralize the epsilon (1e-10) so to_logspace_Gamma_and_gamma uses the same
constant from that helper to keep behavior consistent across tests. Ensure the
helper is imported/defined near the tests and replace the duplicated inline
expressions with calls to _alpha_to_kernel_space.
tests/gdn/reference_delta_rule.py (1)

406-424: Optional: consolidate the three-way branching to remove a duplicate zero-state construction.

The logic is correct and correctly handles the negative-index sentinel (matches the "fix for negative indices" commit). A small refactor removes the duplicated zero-tensor construction and the nested conditionals:

♻️ Proposed simplification
-        if initial_state is not None:
-            if initial_state_indices is not None:
-                if initial_state_indices[seq_idx] >= 0:
-                    pool_idx = initial_state_indices[seq_idx]
-                    state_HKV = initial_state[pool_idx].to(state_dtype).to(q.device)
-                else:
-                    state_HKV = torch.zeros(
-                        (num_sab_heads, head_size, head_size),
-                        dtype=state_dtype,
-                        device=q.device,
-                    )
-            else:
-                state_HKV = initial_state[seq_idx].to(state_dtype).to(q.device)
-        else:
-            state_HKV = torch.zeros(
-                (num_sab_heads, head_size, head_size),
-                dtype=state_dtype,
-                device=q.device,
-            )
+        pool_idx = (
+            int(initial_state_indices[seq_idx])
+            if initial_state_indices is not None
+            else seq_idx
+        )
+        if initial_state is not None and pool_idx >= 0:
+            state_HKV = initial_state[pool_idx].to(device=q.device, dtype=state_dtype)
+        else:
+            state_HKV = torch.zeros(
+                (num_sab_heads, head_size, head_size),
+                dtype=state_dtype,
+                device=q.device,
+            )

Also consider adding a sanity assert 0 <= pool_idx < initial_state.size(0) (when non-negative) so an out-of-range index in a test surfaces as an explicit failure rather than a silent wrap via tensor indexing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/reference_delta_rule.py` around lines 406 - 424, Consolidate the
nested branching around initial_state/initial_state_indices to avoid
constructing the zero-state tensor in three places: compute pool_idx only when
initial_state_indices is not None and initial_state is not None, then if
pool_idx >= 0 assert 0 <= pool_idx < initial_state.size(0) and set state_HKV =
initial_state[pool_idx].to(state_dtype).to(q.device), else set a single
zero-state tensor once (using num_sab_heads, head_size, state_dtype, q.device)
for all other cases; this removes duplicate zero-state construction and adds the
suggested bounds check for pool_idx while keeping the existing negative-index
sentinel behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9a0c22b5-18e5-4deb-81b5-258aacd08733

📥 Commits

Reviewing files that changed from the base of the PR and between ceaef85 and bc63825.

📒 Files selected for processing (2)
  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py

Member

@kahyunnam kahyunnam left a comment

LGTM, just one nit question.

Comment thread flashinfer/gdn_prefill.py
beta: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
Member

I understand this parameter makes more sense to insert right after initial_state; but can we double check if there are any backwards compatibility concerns (breaking positional ordering) here upstream? Not sure if perplexity has picked this up yet.

Contributor Author

Contributor

@kaixih @YAMY1234 @hlu1 do you know if SGLang has used the latest gdn_prefill.py interface? (if so, this would be an interface breakage)

Collaborator

Yes, we are aware of this change. Actually, we requested this improvement in sgl-project/sglang#22921.
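One general way to sidestep positional-ordering breakage in situations like this (a sketch for illustration, not what this PR actually does) is to make newly inserted parameters keyword-only, so existing positional callers keep working; the function name and body below are hypothetical:

```python
from typing import Optional


def chunk_gated_delta_rule_sketch(
    q, k, v,
    beta=None,
    scale: Optional[float] = None,
    initial_state=None,
    *,  # parameters below this marker can only be passed by keyword
    initial_state_indices=None,
):
    # Illustrative body: just echo the optional arguments back
    return (beta, scale, initial_state, initial_state_indices)
```

Callers written as `fn(q, k, v, beta, scale, state)` remain valid, while `initial_state_indices` must always be named explicitly, so inserting it cannot shift any positional argument.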

@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !604 has been created, and the CI pipeline #49407577 is currently running. I'll report back once the pipeline job completes.


5 participants