Commit 2a67426

feat: add Gated DeltaNet linear-attention layer for Qwen3-5 (#1059)
Standalone GDN stack: a depthwise causal conv1d feeding a gated delta-rule recurrence, wired up for the Qwen3-5 checkpoint layout. Decoupled from the full model assembly so it can land independently.

- `MergedColumnParallelLinear` (srt/layers/linear.py)
- GDN kernels + backend (srt/layers/attention/linear/)
1 parent 17db854 commit 2a67426

11 files changed

Lines changed: 3012 additions & 0 deletions

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+"""Gated DeltaNet (GDN) reference kernels.
+
+Public entry points:
+
+* :func:`ragged_gated_delta_rule_ref` — token-by-token ``lax.scan`` over a
+  packed ragged batch (extend / chunked-prefill).
+* :func:`decode_gated_delta_rule_ref` — parallel single-step recurrence
+  across the batch (decode fast path).
+* :func:`jax_causal_conv1d_prefill` / :func:`jax_causal_conv1d_update` —
+  depthwise causal conv1d helpers (ragged prefill + single-token decode).
+"""
+
+from sgl_jax.srt.kernels.gdn.gated_delta import (
+    decode_gated_delta_rule_ref,
+    jax_causal_conv1d_prefill,
+    jax_causal_conv1d_update,
+    ragged_gated_delta_rule_ref,
+)
+
+__all__ = [
+    "decode_gated_delta_rule_ref",
+    "jax_causal_conv1d_prefill",
+    "jax_causal_conv1d_update",
+    "ragged_gated_delta_rule_ref",
+]
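
The docstring above names a token-by-token `lax.scan` path and a batched single-step decode path. The sketch below illustrates that split for one head, assuming the commonly stated gated delta-rule update S_t = α_t S_{t-1} (I − β_t k_t k_tᵀ) + β_t v_t k_tᵀ with output o_t = S_t q_t. The names `gated_delta_step`, `gated_delta_scan`, and `decode_step`, along with the shapes and argument order, are hypothetical and are not taken from the commit; ragged packing, multiple heads, and any normalization the real kernels apply are left out.

```python
import jax
import jax.numpy as jnp
from jax import lax


def gated_delta_step(state, q, k, v, alpha, beta):
    """One recurrence step for a single head (illustrative, not the commit's code).

    state: (d_v, d_k) memory matrix; q, k: (d_k,); v: (d_v,);
    alpha: scalar decay gate; beta: scalar write strength.
    """
    decayed = alpha * state                  # gate the old memory
    predicted = decayed @ k                  # value the memory currently returns for k
    # Delta rule: move the stored value for k toward v by a factor beta.
    new_state = decayed + beta * jnp.outer(v - predicted, k)
    out = new_state @ q                      # read with the query
    return new_state, out


def gated_delta_scan(init_state, q, k, v, alpha, beta):
    """Token-by-token recurrence over a sequence of length T via lax.scan.

    q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,).
    Returns the final state and the per-token outputs of shape (T, d_v).
    """
    def body(state, xs):
        return gated_delta_step(state, *xs)

    return lax.scan(body, init_state, (q, k, v, alpha, beta))


# Decode path: one token per request, mapped over a leading batch axis.
decode_step = jax.vmap(gated_delta_step)
```

With a unit-norm key and `beta = 1`, each step overwrites the value stored under that key, so a query equal to the key reads back the most recent `v`. That property is a quick sanity check for any implementation of this form.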

python/sgl_jax/srt/kernels/gdn/gated_delta.py

Lines changed: 522 additions & 0 deletions
Large diffs are not rendered by default.
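
Since this diff is not rendered, here is a minimal sketch of what a depthwise causal conv1d with a carried state might look like, covering both the prefill and the single-token update cases. The function name `causal_conv1d`, its signature, and the `(T, C)` / `(K, C)` layouts are assumptions for illustration only and may differ from `jax_causal_conv1d_prefill` and `jax_causal_conv1d_update`; bias, activation, and ragged-batch handling are omitted.

```python
import jax.numpy as jnp


def causal_conv1d(x, weight, conv_state=None):
    """Depthwise causal conv over time (illustrative, not the commit's code).

    x: (T, C) new tokens; weight: (K, C), one length-K filter per channel;
    conv_state: (K - 1, C) inputs preceding x, or None for zeros.
    Returns (out, new_state) with out: (T, C) and new_state: (K - 1, C).
    """
    num_taps = weight.shape[0]
    num_tokens = x.shape[0]
    if conv_state is None:
        conv_state = jnp.zeros((num_taps - 1, x.shape[1]), x.dtype)
    # Left-pad with history so position t only sees inputs at or before t.
    padded = jnp.concatenate([conv_state, x], axis=0)   # (T + K - 1, C)
    # out[t] = sum_j weight[j] * padded[t + j], independently per channel.
    out = sum(weight[j] * padded[j:j + num_tokens] for j in range(num_taps))
    # Keep the last K - 1 inputs as history for the next call.
    new_state = padded[num_tokens:]
    return out, new_state
```

Calling it once on a full sequence plays the role of prefill, and calling it with a single row plus the returned state plays the role of the decode update. The two should agree on the final token, which makes a convenient consistency check.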
