feat: add Gated DeltaNet linear-attention layer for Qwen3-5 #1059
Conversation
Hello, thanks for your contribution. Structure-wise, could we align this more closely with the KDA PR (#1051)? (You could also take a look to review it; that PR will be merged in the next one or two days.)
For GDN, some common classes and logic can be reused:
@cjx0709 Thanks for your comments! I think they make sense to me and I changed according to these points in my latest commit. Please check again.
For:
I think following it up with refactoring might be better since:
@hlFu I see your new commits. That's all right; the point that relying on PRs that haven't been merged yet is not a good idea is indeed valid.
Yes, RadixLinearAttention should be used. I planned to wire it in once the full Qwen3-5 model lands. As you pointed out, the forward signatures differ between KDA and GDN, so I think it's better to make this a follow-up. I left a TODO there.


Standalone GDN stack — depthwise causal conv1d into a gated delta-rule recurrence — wired up for the Qwen3-5 checkpoint layout. Decoupled from the full model assembly so it can land independently.
- `MergedColumnParallelLinear` (srt/layers/linear.py): column-parallel linear with N logical outputs fused into one weight; per-device block-concat layout matches sglang/vLLM so each component shards cleanly on its own head dim (GQA-safe).
- GDN kernels + backend (srt/layers/attention/linear/): `gated_delta.py` primitives (l2norm, single delta step, ragged + decode kernels, causal conv1d prefill/update); `GDNAttnBackend` runs conv + recurrence under `jax.shard_map`; `GDNAttnMetadata` carries packed-batch boundaries through jit; `Qwen3_5GatedDeltaNet` glues four projections + gated GemmaRMSNorm + `out_proj`.
- `ForwardBatch`: adds `mamba_cache_indices` and `gdn_metadata` for per-request slot indexing into the recurrent state pool.
- Tests (test/srt/, 45 cases across 6 new files) cover kernels against a Python reference, shard_map dispatch, and end-to-end shape/finiteness; all pass on CPU at TP=1.
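To make the recurrence concrete, here is a straight-Python sketch of the two core primitives named above, in the style of the token-by-token reference the tests compare against. The names `l2norm_ref` and `gated_delta_step_ref`, the shapes, and the exact update form are illustrative assumptions, not the PR's actual signatures; it assumes the standard gated delta rule where the state decays by `exp(g)` and receives a rank-1 correction along the key.

```python
import numpy as np

def l2norm_ref(x, eps=1e-6):
    # Normalize along the last axis (illustrative stand-in for the l2norm primitive).
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)

def gated_delta_step_ref(S, k, v, g, beta):
    """One token of a gated delta rule (illustrative shapes:
    S is (d_k, d_v), k is (d_k,), v is (d_v,), g <= 0, 0 < beta <= 1)."""
    S = np.exp(g) * S                 # gated decay of the recurrent state
    err = v - S.T @ k                 # delta: what the state gets wrong about v
    S = S + beta * np.outer(k, err)   # rank-1 correction along the key
    o = S.T @ k                       # read out the prediction for this key
    return S, o
```

With `g = 0`, `beta = 1`, and an l2-normalized key, a single step from `S = 0` reproduces `v` (up to the normalization epsilon), which is a handy sanity check when writing such a reference.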
Motivation
Refs #1007.
Qwen3-5 uses a hybrid attention architecture where a subset of decoder layers replace full self-attention with a Gated DeltaNet (GDN) linear-attention block (depthwise causal conv1d → gated delta-rule recurrence). This PR lands the standalone GDN stack — JAX kernels, backend, layer, and the `ForwardBatch` plumbing for the recurrent state pool — so the follow-up wiring PRs can build on a stable, independently-tested foundation:
- This PR — kernels + backend + layer + `ForwardBatch` hooks (no caller in production code yet).
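The depthwise causal conv1d half of the block can be sketched in a few lines (a hypothetical reference, not the PR's kernel): each channel has its own K-tap filter, and left-padding by K-1 zeros keeps position t from seeing future tokens.

```python
import numpy as np

def causal_conv1d_prefill_ref(x, w):
    """x: (T, C) token activations; w: (K, C) per-channel filter taps.
    Depthwise: channel c is convolved only with its own taps w[:, c]."""
    T, C = x.shape
    K = w.shape[0]
    xp = np.concatenate([np.zeros((K - 1, C)), x], axis=0)  # causal left-pad
    return np.stack([np.sum(xp[t:t + K] * w, axis=0) for t in range(T)])
```

Note that the K=1 edge case degenerates to a per-channel elementwise scale, which is one of the properties the unit tests below exercise.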
Modifications
See the bullet summary at the top. Two structural points worth calling out explicitly:
- `MergedColumnParallelLinear` produces per-device block-concat `[q_tp | k_tp | v_tp]` activations, and the `conv1d_weight` in `GDNAttnBackend` is loaded with the same stripe so the per-shard conv1d channels line up with the activation layout (matches sglang's `mamba_v2_sharded_weight_loader`). The loader contract is documented in `gdn_backend.py`. The follow-up loader PR is responsible for the actual stripe-rearrange.
- `GDNAttnBackend.__init__` asserts `num_k_heads % TP == 0`, `num_v_heads % TP == 0`, `conv_dim % TP == 0`, and `num_v_heads % num_k_heads == 0` (the GQA repeat factor). These are stricter than `MergedColumnParallelLinear`'s `key_dim % TP` check and catch e.g. `num_k_heads=1, TP=2` configs where `key_dim` is divisible but `num_k_heads // TP` silently floors to 0.
Accuracy Tests
N/A for this PR — the GDN layer is not yet wired into any model forward
(no callers in production code), so it doesn't change the outputs of any
currently-served model.
Kernel-level correctness is covered in the unit tests:
- `test_ragged_gated_delta_rule_ref.py` — kernel vs. a token-by-token straight-Python reference (atol=1e-3, rtol=1e-3 on output; atol=1e-4, rtol=1e-4 on state).
- `test_gated_delta.py` — primitive equality tests (`_l2norm`, `_gated_delta_step` leading-dim-agnostic property, causal conv1d update/prefill with multi-request boundaries, K=1 edge case, short-request left-pad), plus decode-vs-ragged numerical equivalence (including GQA).
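The decode-vs-prefill equivalence that these tests rely on can be demonstrated with a minimal rolling-state sketch (illustrative, not the PR's kernels): decode carries the last K-1 inputs as state, rolls in the new token, and must reproduce the full-sequence causal conv exactly.

```python
import numpy as np

def conv_prefill(x, w):
    # Full-sequence causal depthwise conv: x (T, C), w (K, C).
    T, C = x.shape
    K = w.shape[0]
    xp = np.concatenate([np.zeros((K - 1, C)), x], axis=0)
    return np.stack([np.sum(xp[t:t + K] * w, axis=0) for t in range(T)])

def conv_update(state, x_t, w):
    # Single-token decode step: state holds the previous K-1 inputs.
    window = np.concatenate([state, x_t[None]], axis=0)  # (K, C)
    y_t = np.sum(window * w, axis=0)
    return window[1:], y_t                               # roll the state forward

rng = np.random.default_rng(0)
T, C, K = 7, 4, 3
x, w = rng.normal(size=(T, C)), rng.normal(size=(K, C))
state = np.zeros((K - 1, C))
decode = []
for t in range(T):
    state, y_t = conv_update(state, x[t], w)
    decode.append(y_t)
assert np.allclose(np.stack(decode), conv_prefill(x, w))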
End-to-end accuracy against an HF Qwen3-5 checkpoint will land in the
integration follow-up PR (MMLU / GPQA on a small Qwen3-5 build).
Run locally:
```
JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \
python -m pytest test/srt/test_gated_delta.py \
    test/srt/test_gdn_backend.py \
    test/srt/test_gdn_metadata.py \
    test/srt/test_qwen3_5_gated_delta_net.py \
    test/srt/test_ragged_gated_delta_rule_ref.py \
    test/srt/test_merged_column_parallel_linear.py -v
```
CI: all six files are registered in `unit-test-tpu-v6e-1` in `test/srt/run_suite.py`.
Benchmarking and Profiling
N/A — no model currently dispatches to this code path, so this PR cannot regress (or improve) end-to-end inference speed. The recurrence's `lax.scan` and the conv1d's per-token gather are both shape-dependent on `T`; padding/bucketing of prefill `T` is the caller's responsibility and will be exercised in the integration PR.
Single-layer benchmarking (decode step-time, prefill T-scaling, TP scaling) will be reported in the integration follow-up alongside the first end-to-end run.
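Since both the scan and the gather are shape-specialized on `T`, a caller would typically bucket prefill lengths before dispatch so XLA compiles a bounded set of programs. A minimal sketch of such a helper (the bucket sizes here are arbitrary assumptions, not from this PR):

```python
import math

def pad_to_bucket(t, buckets=(128, 256, 512, 1024)):
    # Round a prefill length up to the next bucket; beyond the largest
    # bucket, fall back to multiples of it to keep the shape set bounded.
    for b in buckets:
        if t <= b:
            return b
    return math.ceil(t / buckets[-1]) * buckets[-1]
```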
Checklist