[feat] add log gate and initial state pool support in blackwell gdn prefill #3167
Observer007 wants to merge 5 commits into flashinfer-ai:main from feat/gdn_log_gate
Conversation
📝 Walkthrough
Adds pooled initial-state indexing and an optional log-space gate mode to the GDN prefill path, updating SM100 (Blackwell) kernels, wrappers, public API signatures, the reference implementation, tests, and a benchmark call site.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant User as User Code
participant API as chunk_gated_delta_rule API
participant SM100 as SM100 Wrapper
participant Compiler as Kernel Compiler
participant Kernel as Device Kernel
User->>API: Call with tensors + initial_state, initial_state_indices, is_log_gate
API->>API: Validate initial_state_indices requires initial_state
alt SM100 path
API->>SM100: Forward tensors & flags (s_in, s_in_indices, is_log_gate, ...)
SM100->>Compiler: Build/cache key (is_log_gate, is_initial_state_pool)
Compiler->>Kernel: Compile kernel with flags
SM100->>Kernel: Execute with s_in_indices, is_log_gate
Kernel->>Kernel: Load initial state via mS_init_indices[batch_idx] (zero if negative)
Kernel->>Kernel: Preprocess gate (treat as ln if is_log_gate)
Kernel-->>SM100: Return outputs
else SM90/Hopper path
API-->>User: Raise NotImplementedError
end
SM100-->>API: Result
    API-->>User: Output tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request introduces support for initial state pooling and natural-log space gates in the Gated Delta Net (GDN) prefill kernel for SM100 (Blackwell). Key changes include updating the kernel logic to handle indexed initial states, adding a flag to skip internal log2 conversion when gates are already in log space, and extending the Python API and reference implementations accordingly. Feedback suggests using cutlass.const_expr for compile-time branch elimination in the kernel and tightening tolerance values in new tests for consistency.
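To make the two new behaviors concrete, here is a small plain-PyTorch sketch of their semantics as described in this review (shapes and sizes are illustrative, not taken from the PR):

```python
import torch

# Log-gate semantics: with is_log_gate=True the kernel receives g = ln(alpha)
# directly; otherwise it receives alpha and performs the log conversion internally.
alpha = torch.rand(6, 4).clamp_min(1e-3)  # linear-space decay in (0, 1)
g_log = (alpha + 1e-10).log()             # what the tests feed when is_log_gate=True
assert torch.allclose(g_log.exp(), alpha, atol=1e-6)

# Initial-state-pool semantics: each sequence selects its starting state from a
# shared pool; a negative index means "start from the zero state".
pool = torch.randn(8, 4, 64, 64)          # [pool_size, H, ...]; layout per the tests below
indices = torch.tensor([3, -1])           # one pool slot per sequence
states = [pool[i] if i >= 0 else torch.zeros_like(pool[0]) for i in indices]
```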
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/gdn_prefill.py (2)
317-323: ⚠️ Potential issue | 🟠 Major
Use the log-space neutral element when `g` is omitted.

When `is_log_gate=True`, the neutral gate is `0`, not `1`. The current fallback makes the no-`g` case behave like multiplying the state by `e` every token on SM100 instead of leaving it unchanged.

Suggested fix
```diff
-    _g = (
-        g
-        if g is not None
-        else torch.ones(
-            total_seq_len, num_sab_heads, dtype=torch.float32, device=device
-        )
-    )
+    default_gate = 0.0 if is_log_gate else 1.0
+    _g = (
+        g
+        if g is not None
+        else torch.full(
+            (total_seq_len, num_sab_heads),
+            default_gate,
+            dtype=torch.float32,
+            device=device,
+        )
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 317 - 323, The fallback for _g currently uses ones unconditionally, which is wrong when is_log_gate=True (neutral element should be 0 in log-space); update the assignment to set _g to a zeros tensor when is_log_gate is True and to ones otherwise, keeping the same shape and dtype/device (use total_seq_len, num_sab_heads, dtype=torch.float32, device=device) and preserve the existing conditional that uses g when provided.
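The neutral-element claim is easy to verify in isolation (a standalone snippet, not code from the PR):

```python
import torch

# In linear space the do-nothing gate is 1; in log space it is 0, since exp(0) == 1.
assert torch.allclose(torch.zeros(4).exp(), torch.ones(4))
# Feeding ones as a log-space gate would instead scale the state by e each step:
assert torch.allclose(torch.ones(4).exp(), torch.full((4,), torch.e))
```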
337-353: ⚠️ Potential issue | 🟠 Major
Validate pooled state indices before dispatch.

The new SM100 path only casts `initial_state_indices` to `int32`, and then the kernel uses those values as raw pool indices. A negative index, an out-of-range value, or a tensor with the wrong length will turn into a bad device read instead of a clean Python error.

Suggested fix
```diff
     _initial_state_indices = (
         initial_state_indices.to(torch.int32)
         if initial_state_indices is not None
         else None
     )
+    if _initial_state_indices is not None:
+        if _initial_state_indices.ndim != 1:
+            raise ValueError(
+                "initial_state_indices must be a 1D tensor of shape [num_seqs]"
+            )
+        if _initial_state_indices.numel() != num_seqs:
+            raise ValueError(
+                f"initial_state_indices must have {num_seqs} elements, "
+                f"got {_initial_state_indices.numel()}"
+            )
+        if torch.any(_initial_state_indices < 0):
+            raise ValueError("initial_state_indices must be non-negative")
+        if initial_state is not None and torch.any(
+            _initial_state_indices >= initial_state.size(0)
+        ):
+            raise IndexError(
+                "initial_state_indices contains a value outside the state pool"
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 337 - 353, The SM100 dispatch currently casts initial_state_indices to int32 and passes them straight into chunk_gated_delta_rule_sm100, which can cause bad device reads for negative, out-of-range, or incorrectly sized tensors; before calling chunk_gated_delta_rule_sm100 validate that if initial_state_indices is not None then (1) it's a 1-D integer tensor on the same device as initial_state, (2) all values are >= 0 and < initial_state.size(0) (use initial_state.size(0) as the pool length), and (3) its length matches the expected number of pooled entries (compare initial_state_indices.numel() to the expected index count for initial_state/initial_state.shape[0] or the pool dimension you use); raise a clear Python exception if any check fails, then cast to torch.int32 (store in _initial_state_indices) and call chunk_gated_delta_rule_sm100 as before.

flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py (1)
3026-3029: ⚠️ Potential issue | 🟠 Major
Upcast pooled initial state to the accumulator dtype before writing TMEM.

This branch currently converts the loaded state to `self.state_dtype` again, but the TMEM state buffer is the fp32 accumulator path. With a bf16 `initial_state`, the first chunk starts from bf16 precision instead of the intended fp32 state.

Suggested fix
```diff
     if cutlass.const_expr(self.state_dtype != self.acc_dtype):
         tRT_tCrState[None, 0, sub].store(
-            tGR_tCrState[None, 0, sub].load().to(self.state_dtype)
+            tGR_tCrState[None, 0, sub].load().to(self.acc_dtype)
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py` around lines 3026 - 3029, The branch that stores the pooled initial state into TMEM is converting the loaded value to self.state_dtype but TMEM expects the fp32 accumulator path; update the conversion to upcast the loaded tGR_tCrState value to self.acc_dtype before storing into tRT_tCrState (i.e., in the const_expr branch that touches tRT_tCrState and tGR_tCrState, replace/to change the .to(self.state_dtype) call to .to(self.acc_dtype)) so the first chunk uses the accumulator precision instead of bf16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_gdn_prefill.py`:
- Around line 102-115: The benchmark calls chunk_gated_delta_rule with g created
via F.logsigmoid but passes it to the non-log gate path; update the call site in
bench_gdn_prefill.py so the gate representation matches the API: either pass
is_log_gate=True to chunk_gated_delta_rule (and ensure the implementation path
honors that flag) or convert the log-space gate back to linear with g.exp()
before calling chunk_gated_delta_rule; reference the chunk_gated_delta_rule call
and the g variable to make the change.
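For concreteness, the two fix options might look like the sketch below. The import path and keyword names follow this review (`chunk_gated_delta_rule` lives in `flashinfer/gdn_prefill.py`); the tensor shapes are stand-ins for the benchmark's existing variables, and any sequence-layout arguments the real API requires are omitted:

```python
import torch
import torch.nn.functional as F
# Assumed import path, based on where this review locates the function.
from flashinfer.gdn_prefill import chunk_gated_delta_rule

# Stand-in shapes; the real benchmark builds these tensors elsewhere.
total_seq_len, num_heads, head_size = 1024, 4, 64
q = torch.randn(total_seq_len, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
beta = torch.rand(total_seq_len, num_heads, device="cuda")
g = F.logsigmoid(torch.randn(total_seq_len, num_heads, device="cuda"))

# Option 1: declare the gate as already being in natural-log space.
out1 = chunk_gated_delta_rule(q, k, v, g=g, beta=beta, is_log_gate=True)

# Option 2: convert back to linear space and keep the default gate path.
out2 = chunk_gated_delta_rule(q, k, v, g=g.exp(), beta=beta)
```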
📒 Files selected for processing (6)
- benchmarks/bench_gdn_prefill.py
- flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
- flashinfer/gdn_kernels/blackwell/gdn_prefill.py
- flashinfer/gdn_prefill.py
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
minor update
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)
473-483: The validation order is correct; consider adding `_skip_if_not_sm100()` for consistency.

The `ValueError("initial_state_indices requires...")` fires at line ~232 in `chunk_gated_delta_rule`, before any SM100 capability checks, so this test will correctly raise the expected error on all architectures. However, for consistency with other SM100-only feature tests in this file (`test_prefill_kernel_log_gate`, `_test_prefill_kernel_initial_state_pool`), adding `_skip_if_not_sm100()` would make the test scope clearer and align with the coding guideline to skip SM100-only tests on unsupported hardware.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 473 - 483, Add a call to _skip_if_not_sm100() at the start of test_initial_state_indices_without_state_error to make the test explicitly skipped on non-SM100 hardware; the test currently triggers a ValueError from chunk_gated_delta_rule (the check at ~line 232) before any SM100 checks, but for consistency with other SM100-only tests (e.g., test_prefill_kernel_log_gate, _test_prefill_kernel_initial_state_pool) ensure the test begins by invoking _skip_if_not_sm100() so it is only executed on supported architectures.
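If such a helper didn't already exist, a typical shape for it would be the following (hypothetical; the real `_skip_if_not_sm100` in the test suite may differ):

```python
import pytest
import torch

def _skip_if_not_sm100():
    # SM100 (Blackwell) reports CUDA compute capability 10.x.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for this test")
    major, _ = torch.cuda.get_device_capability()
    if major != 10:
        pytest.skip("requires an SM100 (Blackwell) GPU")
```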
📒 Files selected for processing (1)
tests/gdn/test_prefill_delta_rule.py
… into feat/gdn_log_gate
🧹 Nitpick comments (1)
flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py (1)
3010-3034: Reuse the cached `s_in_index` instead of re-loading `mS_init_indices[batch_idx]` each iteration.

Line 3020 re-reads `mS_init_indices[batch_idx]` on every `sub` iteration of the loop even though its value was already captured into `s_in_index` on line 3011. Each read is a GMEM load (cacheable, but semantically redundant), and the pool-mode branch on `self.is_initial_state_pool` is also not wrapped in `cutlass.const_expr`, unlike other const-flag checks in this file (e.g., lines 416, 937, 1265).

♻️ Proposed refactor
```diff
-        s_in_index = (
-            mS_init_indices[batch_idx] if self.is_initial_state_pool else batch_idx
-        )
+        if cutlass.const_expr(self.is_initial_state_pool):
+            s_in_index = mS_init_indices[batch_idx]
+        else:
+            s_in_index = batch_idx
         gS_init = cute.flat_divide(
             mS_init[None, None, head_idx, s_in_index],
             (self.mma_tiler_kv[0], self.mma_tiler_kv[1]),
         )[None, None, 0, 0]
         tGR_tCgState = thr_state_r2t.partition_S(gS_init)
         kv_acc_handle = kv_acc_producer.acquire_and_advance()
         for sub in cutlass.range(tRT_tCrState.shape[2]):
-            if self.is_initial_state_pool and mS_init_indices[batch_idx] < 0:
+            if cutlass.const_expr(self.is_initial_state_pool) and s_in_index < 0:
                 tRT_tCrState[None, 0, sub].fill(0.0)
             else:
```

Separately (not a bug, just a caveat): when `s_in_index < 0`, the `mS_init[..., s_in_index]` slice on line 3014 still constructs a tensor view with a negative stride offset. No load occurs because the guard on line 3020 dispatches to `fill(0.0)`, but the code would be more robust if the slice itself was skipped or clamped on the negative path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py` around lines 3010 - 3034, Replace dynamic re-reads of mS_init_indices[batch_idx] inside the loop with the cached s_in_index and make the pool-mode branch a compile-time constant check; specifically, use s_in_index (captured above) in the if guard inside the for sub loop instead of mS_init_indices[batch_idx], and wrap that guard in cutlass.const_expr(self.is_initial_state_pool) to match other const-flag checks; also avoid creating the mS_init slice with a negative index by gating or clamping the initial gS_init construction when s_in_index < 0 so you don't build a negative-stride view (i.e., only perform the mS_init[...] slice when s_in_index >= 0 or select a safe default before creating gS_init).
📒 Files selected for processing (1)
flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
🧹 Nitpick comments (4)
tests/gdn/test_prefill_delta_rule.py (3)
473-483: Consider explicitly skipping when the underlying arch is unsupported.

The ValueError is raised in the Python wrapper before any kernel dispatch, so the test should pass regardless of GPU arch — but per repo convention other tests in this file use `_skip_if_unsupported()`/`_skip_if_not_sm100()`. Adding `_skip_if_unsupported()` here keeps behavior uniform and avoids surprises if the wrapper later performs any arch-gated validation ordering:

🧪 Proposed skip
```diff
 def test_initial_state_indices_without_state_error():
     """initial_state_indices without initial_state should raise ValueError."""
+    _skip_if_unsupported()
     device = torch.device("cuda")
```

As per coding guidelines: "Skip test execution on unsupported GPU architectures using `flashinfer.utils` check functions".
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 473 - 483, Add a GPU-arch skip at the start of the test to match file convention: in test_initial_state_indices_without_state_error call the repository's skip helper (e.g., _skip_if_unsupported() or _skip_if_not_sm100()) before invoking chunk_gated_delta_rule so the test is skipped on unsupported architectures; this keeps behavior consistent with other tests in this file and avoids platform-specific surprises if validation ordering changes.
393-416: Make the pool-layout convention a hard contract via an assertion.

The inline comments on lines 397–402 explain that the pool is passed as `[pool_size, H, V, K]` to the kernel (same layout as `output_state`) and transposed before the reference. Since this is the only place that documents the layout, a cheap shape assertion right before the kernel call would lock the contract and prevent silent layout regressions if `output_state`'s convention ever changes:

🛡️ Proposed shape assert
```diff
+    assert state_pool.shape == (
+        pool_size,
+        num_sab_heads,
+        head_size,
+        head_size,
+    ), "state_pool must match kernel [pool_size, H, V, K] layout"
     chunk_gated_delta_rule(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 393 - 416, Add a sanity assertion to lock the pool-layout contract so regressions don't silently change the expected layout: before creating state_pool_ref and before calling blockwise_delta_rule, assert that state_pool has shape [pool_size, H, K, V] (or equivalently that state_pool.transpose(-1,-2) matches the kernel/output_state layout [pool_size, H, V, K]) and raise a clear error if not; locate the variables state_pool, state_pool_ref, and the blockwise_delta_rule call to add this check so the test fails loudly if the pool layout convention changes.
100-102: Optional: extract the `is_log_gate` transform to a small helper.

The `(alpha + 1e-10).log() if is_log_gate else alpha` expression is duplicated in `_test_prefill_kernel` and `_test_prefill_kernel_initial_state_pool`, and the `1e-10` epsilon is also hard-coded in the reference (`to_logspace_Gamma_and_gamma`). A tiny helper (e.g. `_alpha_to_kernel_space(alpha, is_log_gate)`) would keep the epsilon consistent and reduce copy/paste as more log-gate tests land.

Also applies to: 362-362
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_prefill_delta_rule.py` around lines 100 - 102, Extract the repeated transform "(alpha + 1e-10).log() if is_log_gate else alpha" into a small helper function (e.g. _alpha_to_kernel_space(alpha, is_log_gate)) and use it in _test_prefill_kernel and _test_prefill_kernel_initial_state_pool; also centralize the epsilon (1e-10) so to_logspace_Gamma_and_gamma uses the same constant from that helper to keep behavior consistent across tests. Ensure the helper is imported/defined near the tests and replace the duplicated inline expressions with calls to _alpha_to_kernel_space.
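Such a helper could be as small as the following (a sketch of the suggestion above; `_alpha_to_kernel_space` is the proposed name, not existing code):

```python
import torch

_LOG_GATE_EPS = 1e-10  # shared with the reference's to_logspace_Gamma_and_gamma

def _alpha_to_kernel_space(alpha: torch.Tensor, is_log_gate: bool) -> torch.Tensor:
    """Return the gate in the representation the kernel expects."""
    return (alpha + _LOG_GATE_EPS).log() if is_log_gate else alpha
```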
tests/gdn/reference_delta_rule.py (1)

406-424: Optional: consolidate the three-way branching to remove a duplicate zero-state construction.

The logic is correct and handles the negative-index sentinel properly (matches the "fix for negative indices" commit). A small refactor removes the duplicated zero-tensor construction and the nested conditionals:

♻️ Proposed simplification
```diff
-        if initial_state is not None:
-            if initial_state_indices is not None:
-                if initial_state_indices[seq_idx] >= 0:
-                    pool_idx = initial_state_indices[seq_idx]
-                    state_HKV = initial_state[pool_idx].to(state_dtype).to(q.device)
-                else:
-                    state_HKV = torch.zeros(
-                        (num_sab_heads, head_size, head_size),
-                        dtype=state_dtype,
-                        device=q.device,
-                    )
-            else:
-                state_HKV = initial_state[seq_idx].to(state_dtype).to(q.device)
-        else:
-            state_HKV = torch.zeros(
-                (num_sab_heads, head_size, head_size),
-                dtype=state_dtype,
-                device=q.device,
-            )
+        pool_idx = (
+            int(initial_state_indices[seq_idx])
+            if initial_state_indices is not None
+            else seq_idx
+        )
+        if initial_state is not None and pool_idx >= 0:
+            state_HKV = initial_state[pool_idx].to(device=q.device, dtype=state_dtype)
+        else:
+            state_HKV = torch.zeros(
+                (num_sab_heads, head_size, head_size),
+                dtype=state_dtype,
+                device=q.device,
+            )
```

Also consider adding a sanity assert `0 <= pool_idx < initial_state.size(0)` (when non-negative) so an out-of-range index in a test surfaces as an explicit failure rather than a silent wrap via tensor indexing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/reference_delta_rule.py` around lines 406 - 424, Consolidate the nested branching around initial_state/initial_state_indices to avoid constructing the zero-state tensor in three places: compute pool_idx only when initial_state_indices is not None and initial_state is not None, then if pool_idx >= 0 assert 0 <= pool_idx < initial_state.size(0) and set state_HKV = initial_state[pool_idx].to(state_dtype).to(q.device), else set a single zero-state tensor once (using num_sab_heads, head_size, state_dtype, q.device) for all other cases; this removes duplicate zero-state construction and adds the suggested bounds check for pool_idx while keeping the existing negative-index sentinel behavior.
📒 Files selected for processing (2)
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
```python
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
```
I understand this parameter makes more sense to insert right after initial_state; but can we double-check whether there are any backwards-compatibility concerns (breaking positional ordering) upstream? Not sure if Perplexity has picked this up yet.
Yes, we are aware of this change. Actually, we requested this improvement in sgl-project/sglang#22921.
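For illustration, a toy sketch of the concern raised above (the function and its trailing parameter are hypothetical stand-ins, not the real flashinfer signature):

```python
# Old signature:  def f(initial_state=None, output_final_state=False)
# New signature inserts a parameter in the middle:
def f(initial_state=None, initial_state_indices=None, output_final_state=False):
    return initial_state_indices, output_final_state

# A caller written against the old signature, passing the flag positionally:
idx, flag = f(None, True)
print(idx, flag)  # (True, False): True now binds to initial_state_indices, silently
```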
/bot run
📌 Description
Add optional log-gate and initial-state-pool support. The request comes from the sglang team.
No performance regression after adding the new features.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
- New Features
- Improvements
- Bug Fixes
- Tests