
feat: add Gated DeltaNet linear-attention layer for Qwen3-5 #1059

Merged
hlFu merged 6 commits into sgl-project:main from hlFu:feat/GDN-dev on May 14, 2026

Conversation

@hlFu (Collaborator) commented May 12, 2026

Standalone GDN stack — depthwise causal conv1d into a gated delta-rule
recurrence — wired up for the Qwen3-5 checkpoint layout. Decoupled from
the full model assembly so it can land independently.

  • MergedColumnParallelLinear (srt/layers/linear.py): column-parallel
    linear with N logical outputs fused into one weight; per-device
    block-concat layout matches sglang/vLLM so each component shards
    cleanly on its own head dim (GQA-safe).
  • GDN kernels + backend (srt/layers/attention/linear/): gated_delta.py
    primitives (l2norm, single delta step, ragged + decode kernels, causal
    conv1d prefill/update); GDNAttnBackend runs conv + recurrence under
    jax.shard_map; GDNAttnMetadata carries packed-batch boundaries
    through jit; Qwen3_5GatedDeltaNet glues four projections + gated
    GemmaRMSNorm + out_proj.
  • ForwardBatch: adds mamba_cache_indices and gdn_metadata for
    per-request slot indexing into the recurrent state pool (slot
    indexing sketched after this list).
  • Tests (test/srt/, 45 cases across 6 new files) cover kernels against a
    Python reference, shard_map dispatch, and end-to-end shape/finiteness;
    all pass on CPU at TP=1.
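
A minimal sketch of the per-request slot indexing that the ForwardBatch fields above enable; the pool shape and the update are illustrative, not the actual state-pool implementation:

```python
import jax.numpy as jnp

# Hypothetical pool of per-slot recurrent states: [num_slots, d_k, d_v].
recurrent_state_pool = jnp.zeros((64, 8, 16))

# ForwardBatch.mamba_cache_indices: one slot per request in the packed batch.
mamba_cache_indices = jnp.array([3, 17])

states = recurrent_state_pool[mamba_cache_indices]   # gather: [2, d_k, d_v]
new_states = states + 1.0                             # stand-in for the recurrence update
recurrent_state_pool = recurrent_state_pool.at[mamba_cache_indices].set(new_states)
```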

Motivation

Refs #1007.

Qwen3-5 uses a hybrid attention architecture in which a subset of decoder
layers replaces full self-attention with a Gated DeltaNet (GDN) linear-
attention block (depthwise causal conv1d → gated delta-rule recurrence).
This PR lands the standalone GDN stack — JAX kernels, backend, layer, and
the ForwardBatch plumbing for the recurrent state pool — so the
follow-up wiring PRs can build on a stable, independently tested
foundation:

  • This PR — kernels + backend + layer + ForwardBatch hooks (no
    caller in production code yet).
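
For reviewers new to the delta rule, a minimal single-head sketch of one gated delta-rule step and the scan over time, assuming l2-normalized keys and scalar gates; the names and signatures are illustrative and do not mirror the gated_delta.py kernels:

```python
import jax
import jax.numpy as jnp

def gated_delta_step(state, inputs):
    # state: [d_k, d_v] recurrent memory for one head.
    q, k, v, alpha, beta = inputs        # q, k: [d_k]; v: [d_v]; alpha, beta: scalars
    pred = state.T @ k                   # [d_v] what the memory currently predicts for key k
    # Gated decay of the old state, then a delta-rule (error-correcting) write.
    state = alpha * state + beta * jnp.outer(k, v - alpha * pred)
    out = state.T @ q                    # [d_v] read with the query
    return state, out

def gated_delta_scan(state0, q, k, v, alpha, beta):
    # q, k: [T, d_k]; v: [T, d_v]; alpha, beta: [T]
    return jax.lax.scan(gated_delta_step, state0, (q, k, v, alpha, beta))

d_k, d_v, T = 8, 16, 4
key = jax.random.PRNGKey(0)
q, k = jax.random.normal(key, (2, T, d_k))
k = k / jnp.linalg.norm(k, axis=-1, keepdims=True)   # keys l2-normalized, as the l2norm primitive does
v = jax.random.normal(key, (T, d_v))
alpha = jnp.full((T,), 0.9)   # decay gate
beta = jnp.full((T,), 0.5)    # write-strength gate
final_state, outs = gated_delta_scan(jnp.zeros((d_k, d_v)), q, k, v, alpha, beta)
print(outs.shape)             # (4, 16)
```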

Modifications

See the bullet summary at the top. Two structural points worth calling
out explicitly:

  • Per-shard layout convention. MergedColumnParallelLinear produces
    per-device block-concat [q_tp | k_tp | v_tp] activations, and the
    conv1d_weight in GDNAttnBackend is loaded with the same stripe so
    the per-shard conv1d channels line up with the activation layout
    (matches sglang's mamba_v2_sharded_weight_loader). The loader
    contract is documented in gdn_backend.py. The follow-up loader PR is
    responsible for the actual stripe-rearrange.
  • Divisibility guards. GDNAttnBackend.__init__ asserts
    num_k_heads % TP == 0, num_v_heads % TP == 0, conv_dim % TP == 0,
    and num_v_heads % num_k_heads == 0 (the GQA repeat factor). These
    are stricter than MergedColumnParallelLinear's key_dim % TP check
    and catch e.g. num_k_heads=1, TP=2 configs where key_dim is
    divisible but num_k_heads // TP silently floors to 0.
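
A minimal sketch of the guard logic described in the second bullet; the argument names are illustrative, not the exact GDNAttnBackend.__init__ signature:

```python
def _check_gdn_tp_divisibility(num_k_heads, num_v_heads, conv_dim, tp_size):
    # Stricter than a key_dim % TP check: with num_k_heads=1 and TP=2,
    # key_dim can divide evenly while num_k_heads // tp_size floors to 0.
    assert num_k_heads % tp_size == 0, f"num_k_heads={num_k_heads} vs TP={tp_size}"
    assert num_v_heads % tp_size == 0, f"num_v_heads={num_v_heads} vs TP={tp_size}"
    assert conv_dim % tp_size == 0, f"conv_dim={conv_dim} vs TP={tp_size}"
    assert num_v_heads % num_k_heads == 0, "GQA repeat factor must be an integer"

# e.g. an illustrative config with 16 k-heads / 32 v-heads sharded over TP=8 passes:
_check_gdn_tp_divisibility(num_k_heads=16, num_v_heads=32, conv_dim=4096, tp_size=8)
```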

Accuracy Tests

N/A for this PR — the GDN layer is not yet wired into any model forward
(no callers in production code), so it doesn't change the outputs of any
currently-served model.

Kernel-level correctness is covered in the unit tests:

  • test_ragged_gated_delta_rule_ref.py — kernel vs. a token-by-token
    straight-Python reference (atol=1e-3, rtol=1e-3 on output; atol=1e-4,
    rtol=1e-4 on state).
  • test_gated_delta.py — primitive equality tests (_l2norm,
    _gated_delta_step leading-dim-agnostic property, causal conv1d
    update/prefill with multi-request boundaries, K=1 edge case,
    short-request left-pad), plus decode-vs-ragged numerical equivalence
    (including GQA).
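
For reference, a minimal sketch of the depthwise causal conv1d decode-update step those conv tests exercise, assuming a [D, K] rolling conv state per request; the names and shapes are illustrative, not the gated_delta.py interface:

```python
import jax.numpy as jnp

def causal_conv1d_update(conv_state, x_t, weight):
    # conv_state: [D, K] holds the last K inputs per channel; x_t: [D] is the new token.
    conv_state = jnp.concatenate([conv_state[:, 1:], x_t[:, None]], axis=1)
    y_t = jnp.sum(conv_state * weight, axis=1)   # depthwise: one length-K filter per channel
    return conv_state, y_t

D, K = 6, 4
conv_state = jnp.zeros((D, K))            # zero history, i.e. an implicit left-pad for short requests
weight = jnp.ones((D, K)) / K             # dummy depthwise filters
conv_state, y_t = causal_conv1d_update(conv_state, jnp.ones((D,)), weight)
print(y_t)                                # each channel mixes only the current and previous K-1 tokens
```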

End-to-end accuracy against an HF Qwen3-5 checkpoint will land in the
integration follow-up PR (MMLU / GPQA on a small Qwen3-5 build).

Run locally:

JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \
python -m pytest \
    test/srt/test_gated_delta.py \
    test/srt/test_gdn_backend.py \
    test/srt/test_gdn_metadata.py \
    test/srt/test_qwen3_5_gated_delta_net.py \
    test/srt/test_ragged_gated_delta_rule_ref.py \
    test/srt/test_merged_column_parallel_linear.py -v

CI: these test files are registered in unit-test-tpu-v6e-1 in
test/srt/run_suite.py.

Benchmarking and Profiling

N/A — no model currently dispatches to this code path, so this PR
cannot regress (or improve) end-to-end inference speed. The
recurrence's lax.scan and the conv1d's per-token gather are both
shape-dependent on T; padding/bucketing of prefill T is the
caller's responsibility and will be exercised in the integration PR.
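
To illustrate that caller-side responsibility, a sketch of bucketing the packed prefill length T so a jitted kernel compiles once per bucket rather than once per distinct T; the bucket sizes are illustrative:

```python
import jax.numpy as jnp

PREFILL_BUCKETS = (128, 512, 2048, 8192)

def pad_to_bucket(x):
    # x: [T, D] packed prefill activations; pad T up to the next bucket with zero rows.
    t = x.shape[0]
    t_bucket = next(b for b in PREFILL_BUCKETS if b >= t)
    return jnp.pad(x, ((0, t_bucket - t), (0, 0))), t_bucket

x = jnp.ones((300, 64))
x_padded, t_bucket = pad_to_bucket(x)
print(x_padded.shape, t_bucket)   # (512, 64) 512
```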

Single-layer benchmarking (decode step-time, prefill T-scaling, TP
scaling) will be reported in the integration follow-up alongside the
first end-to-end run.

Checklist

  • Please use English, otherwise it will be closed.
  • The purpose of the PR, or link existing issues this PR will resolve.
  • The test plan, such as providing test command.
  • (Optional) The necessary documentation update.


@cjx0709 (Collaborator) commented May 12, 2026

Hello, thanks for your contribution. Structure-wise, could we align this more closely with the KDA PR (#1051)? (You may also want to review that PR; it should be merged in the next day or two.)
KDA separates the layers as:

  • low-level kernels/reference kernels under python/sgl_jax/srt/kernels/kda/
  • reusable recurrent-attention backend under python/sgl_jax/srt/layers/attention/linear/kda_backend.py
  • shared short conv helper under python/sgl_jax/srt/layers/attention/linear/short_convolution.py
  • unit tests under python/sgl_jax/test/...

For GDN, some common classes and conventions can be reused:

  • make GDNAttnBackend inherit from / use LinearRecurrentAttnBackend and its metadata instead of adding a parallel GDN-specific metadata path,
  • like KDA, we recommend recurrent_state_pool for the call input args instead of using conv_state_in and recurrent_state_in

@hlFu (Collaborator, Author) commented May 12, 2026

@cjx0709 Thanks for your comments! They make sense to me, and I have addressed the following points in my latest commit. Please check again.

  • low-level kernels/reference kernels under python/sgl_jax/srt/kernels/kda/
  • reusable recurrent-attention backend under python/sgl_jax/srt/layers/attention/linear/kda_backend.py
  • unit tests under python/sgl_jax/test/...
  • make GDNAttnBackend inherit from / use LinearRecurrentAttnBackend and its metadata instead of adding a parallel GDN-specific metadata path,
  • like KDA, we recommend recurrent_state_pool for the call input args instead of using conv_state_in and recurrent_state_in

For:

  • shared short conv helper under python/sgl_jax/srt/layers/attention/linear/short_convolution.py

I think it is better to handle that in a follow-up refactor, since:

  1. It introduces a dependency on another open PR.
  2. The implementation and interface are slightly different:
    • The current GDN conv implementation takes mixed qkv in shape [D, T] while short_conv takes [T, D]. The consideration on the GDN side is that the GPU conv kernel takes [D, T], so keeping the same shape makes it easier to swap in later.
    • The conv pool and indices are passed into the conv kernel individually, allowing both reads and writes to happen inside the kernel.

@cjx0709 (Collaborator) commented May 13, 2026

@hlFu I see your new commits. That's all right; you are correct that relying on PRs that haven't been merged yet is not a good idea.
I also have a new question, shown in the screenshot below:
[screenshot]
To align with the other attention backend designs in sglang / sglang-jax, we should use RadixLinearAttention rather than just an XXXAttnBackend. (That said, the RadixLinearAttention function signature doesn't fully match for fused vs. split qkv, so if you want to merge in the current state, that's OK. We can leave a TODO issue for later: after Qwen3.5 and Kimi Linear are merged, we can refactor the files following the class structure of sglang's attention backends and share some common basic kernels, which will help unify our code.)
What do you think?

Comment thread on python/sgl_jax/srt/kernels/gdn/gated_delta.py
@hlFu (Collaborator, Author) commented May 13, 2026


Yes, RadixLinearAttention should be used. I planned to wire it up when the full Qwen3-5 model lands. As you pointed out, the signature differs between the KDA and GDN forward, so I think it's better to make that a follow-up; I've added a TODO comment there.

@cjx0709 (Collaborator) left a comment

LGTM

@hlFu hlFu merged commit 2a67426 into sgl-project:main May 14, 2026
18 checks passed