Background
This RFC covers the KDA (Kimi Delta Attention) inference components for sglang-jax, enabling hybrid linear-attention model serving (e.g. Kimi-Linear for Issue#937). Building on the hybrid memory pool system (PR#1034) and linear attention dispatch infrastructure (PR#1041), this PR implements the KDA-specific kernel, backend, and supporting layers needed for end-to-end inference.
KDA is a delta-rule linear attention mechanism with per-element gating and short causal convolutions on Q/K/V. It replaces standard softmax attention on designated layers in hybrid architectures. During decode each request carries a fixed-size recurrent state, so per-token memory and compute are O(1) in sequence length, while preserving model quality.
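To make the recurrence concrete, the following is a minimal NumPy sketch of one decode step of a gated delta rule for a single head. It is an illustration under stated assumptions, not the kernel in this PR: the name kda_step is hypothetical, the gate g is assumed to be a per-channel log-decay (so exp(g) lies in (0, 1]), and query scaling, L2 normalization, and the short convolutions are left out.

```python
import numpy as np

def kda_step(S, q, k, v, g, beta):
    """One recurrent step for a single head (hypothetical helper).

    S: [K, V] recurrent state, q/k/g: [K], v: [V], beta: scalar write strength.
    """
    # Per-element gating: each key channel decays at its own rate.
    S = np.exp(g)[:, None] * S
    # Delta rule: correct the state by the error between v and the
    # value the decayed state currently returns for k.
    pred = k @ S                              # [V]
    S = S + beta * np.outer(k, v - pred)
    # Read-out for the query; the state stays [K, V] however long the sequence is.
    o = q @ S                                 # [V]
    return S, o

rng = np.random.default_rng(0)
K, V = 4, 3
S = rng.normal(size=(K, V))
k = rng.normal(size=K)
k /= np.linalg.norm(k)                        # unit-norm key, as an L2 norm would give
v = rng.normal(size=V)
g = -np.abs(rng.normal(size=K))               # assumed log-decay, always <= 0
S_new, o = kda_step(S, q=k, k=k, v=v, g=g, beta=1.0)
```

With a unit-norm key and beta = 1, reading the state back with the same key returns v, which is the property the delta rule is built around.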
Prerequisites
- RecurrentStatePool + HybridReqToTokenPool + MemoryPools merged to main (PR#1034: feat(mem_cache): Hybrid Memory Pool system #1034)
- HybridLinearAttnBackend + LinearRecurrentAttnBackend + RadixLinearAttention merged to main (PR#1041: feat: Linear Attention Dispatch #1041)

Scope
In scope
- KDA Pallas kernel (srt/kernels/kda/kda.py), adapted from https://github.com/primatrix/pallas-kernel
- Naive recurrent kernel (srt/kernels/kda/naive.py)
- Short causal convolution (srt/layers/attention/linear/short_convolution.py): cu_seqlens, [B, D, K-1] rolling cache
- KDAAttnBackend (srt/layers/attention/linear/kda_backend.py): LinearRecurrentAttnBackend (from PR#1041), shard_map, RecurrentStatePool interface, has_initial_state (from feat(mem_cache): Hybrid Memory Pool system #1034)
- DP adaptation (details in Design Details)
Out of scope
- HybridLinearAttnBackend, LinearRecurrentAttnBackend, RadixLinearAttention: already merged in PR#1041

Code Changes
New files
- srt/kernels/kda/__init__.py: chunk_kda, naive_recurrent_kda
- srt/kernels/kda/kda.py
- srt/kernels/kda/naive.py
- srt/layers/attention/linear/__init__.py: KDAAttnBackend
- srt/layers/attention/linear/kda_backend.py: KDAAttnBackend: conv -> L2 norm -> recurrent kernel dispatch, state management; l2_normalize utility
- srt/layers/attention/linear/short_convolution.py
- test/test_short_conv.py
- test/test_kda_attention.py
- test/test_kda_attention_dp.py

Design Details
Interface alignment with Hybrid Memory Pool (PR#1034) and Linear Attention Dispatch (PR#1041)
- RecurrentStatePool.get_linear_recurrent_layer_cache(layer_id) -> (jax.Array, list[jax.Array]): KDAAttnBackend.get_state() unwraps the length-1 conv list
- LinearRecurrentAttnBackendMetadata (from PR#1041): cu_q_lens, recurrent_indices, has_initial_state with P("data") sharding
- LinearRecurrentAttnBackend.get_forward_metadata() (from PR#1041): cu_q_lens, transfers metadata to device
- MemoryPools.replace_all(dict): {"token_to_kv_pool": ..., "recurrent_state_pool": (recurrent_bufs, conv_bufs)}
- RadixLinearAttention (from PR#1041): forward_batch.attn_backend()

State management
has_initial_state is a per-request bool array ([total_bs], True = continuing request, False = new request). PR#1034 stores it on ModelWorkerBatch (CPU); PR#1041's LinearRecurrentAttnBackend.get_forward_metadata() transfers it to LinearRecurrentAttnBackendMetadata with P("data") sharding so it's available on device.
kda_backend.py: KDAAttnBackend.get_state() uses it to zero stale recurrent state for new requests.

DP adaptation
kda_backend.py: both forward paths (_forward_extend, _forward_decode) wrap kernel calls in shard_map for DP support.
naive.py: naive_recurrent_kda itself is not modified. shard_map ensures each rank receives per-shard local tensors, so the kernel's existing P(None, "tensor", ...) annotation remains correct.

Metadata computation (provided by PR#1041)
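As a rough illustration of what a per-rank cumsum means, the NumPy sketch below builds one cu_seqlens segment per DP rank and concatenates the segments into a flat [dp*(N_per_dp+1)] layout. The helper name per_rank_cu_seqlens is hypothetical, and padding short ranks by repeating the last offset is an assumption made for the example, not a description of PR#1041.

```python
import numpy as np

def per_rank_cu_seqlens(lens_per_rank):
    """Build per-rank cumulative offsets (hypothetical helper).

    lens_per_rank: one list of sequence lengths per DP rank.
    Returns a flat int32 array of shape [dp * (N_per_dp + 1)].
    """
    n_per_dp = max(len(lens) for lens in lens_per_rank)
    segments = []
    for lens in lens_per_rank:
        # Offsets restart at 0 on every rank: each rank only sees its own tokens.
        cu = np.concatenate([[0], np.cumsum(lens)]).astype(np.int32)
        # Assumed padding: repeat the last offset so padded requests have length 0.
        pad = np.full(n_per_dp + 1 - len(cu), cu[-1], dtype=np.int32)
        segments.append(np.concatenate([cu, pad]))
    return np.concatenate(segments)

# Four DP ranks with uneven request counts.
cu = per_rank_cu_seqlens([[128], [64], [32, 64], [128]])
```

Reshaping cu to [dp, N_per_dp+1] recovers one row of offsets per rank, which is the per-shard view a P("data") sharding hands to each rank.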
LinearRecurrentAttnBackend.get_forward_metadata() (from PR#1041) computes per-rank cumsum for cu_q_lens and transfers all metadata (cu_q_lens, recurrent_indices, has_initial_state) to device with P("data") sharding. This PR's KDAAttnBackend inherits this behavior.

kda_backend.py: _forward_* shard_map shape specs
Tensors are reshaped to [dp, ...] before shard_map; the B=1 unsqueeze happens inside shard_map (matches FA's 3D-in pattern).

| Tensor | Global → shard_map shape | P-spec | Per-shard |
| --- | --- | --- | --- |
| q, k, v, g | [dp, T_per_dp, H, K] | P("data", None, "tensor", None) | [T_per_dp, H_local, K] |
| beta | [dp, T_per_dp, H] | P("data", None, "tensor") | [T_per_dp, H_local] |
| initial_state | [dp, N_per_dp, H, K, V] | P("data", None, "tensor", None, None) | [N_per_dp, H_local, K, V] |
| cu_seqlens | [dp*(N_per_dp+1)] | P("data") | [N_per_dp+1] |
| A_log | [H] | P("tensor") | [H_local] |
| dt_bias | [H*K] | P("tensor") | [H_local*K] |

Pool gather/scatter
Following FA: pool buffers flow through shard_map as inputs/outputs. Buffer and indices are consistently P("data", ...) sharded with per-rank local slot indices (from PR#1034), so gather/scatter are per-shard independent.

| Buffer | Sharding | Shape |
| --- | --- | --- |
| recurrent_buffers | P("data", "tensor", None, None) | [total_slots, H, K, V] |
| conv_buffers | P("data", "tensor", None) | [total_slots, proj_size, K_conv-1] |
| recurrent_indices | P("data") | [total_bs], per-rank local |

State clearing under DP
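State clearing reduces to a per-request mask applied after the pool gather. The following is a minimal NumPy sketch for a single rank's local pool shard. The helper name gather_and_clear is hypothetical, and the actual KDAAttnBackend.get_state() code is not reproduced here.

```python
import numpy as np

def gather_and_clear(recurrent_buffer, recurrent_indices, has_initial_state):
    """Gather per-request state and zero it for new requests (hypothetical helper).

    recurrent_buffer:  [slots, H, K, V] local pool shard
    recurrent_indices: [bs] local slot index per request
    has_initial_state: [bs] bool, True = continuing request
    """
    state = recurrent_buffer[recurrent_indices]        # [bs, H, K, V]
    keep = has_initial_state[:, None, None, None]      # broadcast over H, K, V
    # New requests start from zeros, whatever a previous occupant left in the slot.
    return np.where(keep, state, np.zeros_like(state))

pool = np.ones((4, 2, 3, 3), dtype=np.float32)         # pretend every slot holds stale state
idx = np.array([2, 0])
h0 = gather_and_clear(pool, idx, np.array([True, False]))
```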
has_initial_state is P("data") sharded and per-request independent, so there is no cross-rank dependency. Same code as State management.

Baseline for KDA Attention
Manually assembled pipeline, independent of backend internals. Multi-request EXTEND uses a per-sequence loop to ensure varlen packing logic is under test.
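The per-sequence loop can be sketched as follows. This is a NumPy illustration under assumptions rather than the test's actual code: baseline_extend and toy_recurrent are hypothetical names, and toy_recurrent is a running sum standing in for conv + L2 norm + naive_recurrent_kda so that the example stays self-contained.

```python
import numpy as np

def toy_recurrent(x, initial_state):
    """Stand-in for the real per-sequence pipeline (hypothetical).

    x: [1, T, D], initial_state: [1, D].
    Returns (outputs [1, T, D], final_state [1, D]).
    """
    out = initial_state[:, None, :] + np.cumsum(x, axis=1)
    return out, out[:, -1, :]

def baseline_extend(x_packed, cu_seqlens, h0):
    """Run each packed sequence on its own, then re-pack the outputs.

    x_packed: [T_total, D], cu_seqlens: [N + 1], h0: [N, D].
    """
    outs, finals = [], []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        # Slice one request out of the packed batch and add a batch dim of 1;
        # h0[i:i + 1] keeps the matching batch dim on the state.
        o, h = toy_recurrent(x_packed[None, start:end], h0[i:i + 1])
        outs.append(o[0])
        finals.append(h)
    return np.concatenate(outs, axis=0), np.concatenate(finals, axis=0)

x = np.ones((5, 2))
out, final = baseline_extend(
    x, cu_seqlens=[0, 2, 5], h0=np.array([[0.0, 0.0], [10.0, 10.0]])
)
```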
This design ensures:
- test_short_conv.py)
- initial_state=h0[i:i+1] matches how the pool feeds per-request state

Test Plan
Test structure mirrors test_flashattention.py / test_flashattention_dp.py.

1. Short convolution: test_short_conv.py (24 tests, done)
Baseline: nnx.Conv per-sequence with injected history. fp32: rtol=1e-4, atol=1e-4; bf16: rtol=1e-2, atol=1e-2.
- ShortConvolutionTest
- ShortConvolutionBF16Test

2. Backend tests: test_kda_attention.py + test_kda_attention_dp.py (17 tests)
Common setup:
- bf16 (matching FA test convention)
- KDAAttnBackend.__call__() full pipeline
- _baseline_extend / _baseline_decode (nnx.Conv) + L2 norm + naive_recurrent_kda (see Baseline section)
- create_test_data(mode, lens, ...) builds ForwardBatch, ModelWorkerBatch, MockRecurrentStatePool with random Q/K/V, gate g, beta, A_log, dt_bias, conv weights
- test_kda_attention.py: mesh = create_device_mesh(ici_parallelism=[1, 1]) (TP=1, DP=1) for smoke tests; [1, -1] (TP=4, DP=1) for sharded tests
- test_kda_attention_dp.py: set_mesh(tp_size, dp_size)

test_kda_attention.py: class TestKDAAttention, 8 tests
seq_lens: per-request token count. EXTEND processes all tokens; DECODE always processes 1 token per request (state carries history).
- test_single_seq_extend_no_shard: [128]
- test_single_seq_extend: [128]
- test_multi_seq_extend: [64, 128, 32]; run_test's ssm/conv assertions)
- test_extend_short_seqs: [3, 7, 1]
- test_extend_with_initial_state: [64], h0 from prior 64 tokens
- test_single_step_decode: [1], B=1
- test_multi_request_decode: [1, 1, 1], B=3
- test_extend_then_decode: [128] + 4 decode steps

test_kda_attention_dp.py: class TestKDAAttentionDP, 9 tests
- test_extend_dp4_tp1: {0: [128], 1: [64], 2: [32, 64], 3: [128]}
- test_extend_sparse_ranks_dp4: {0: [64], 1: [], 2: [128], 3: []}
- test_extend_dp2_tp2: {0: [128, 64], 1: [32]}
- test_decode_dp4_tp1: {0: [1, 1], 1: [1], 2: [1, 1], 3: [1]}
- test_decode_sparse_ranks_dp4: {0: [], 1: [1, 1], 2: [], 3: [1]}
- test_decode_unbalanced_dp4: {0: [1], 1: [1, 1, 1], 2: [1], 3: [1, 1, 1, 1]}
- test_dp_state_isolation
- test_dp_mixed_new_continuing: {0: new reqs (h0=0), 1: continuing (h0≠0), 2: mix, 3: mix}; has_initial_state cross-rank
- test_dp_extend_then_decode

Summary
- test_short_conv.py: 24 tests
- test_kda_attention.py: 8 tests
- test_kda_attention_dp.py: 9 tests
Pass criteria: 41/41 tests pass.
Known constraint: RecurrentStatePool must use float32 for recurrent state buffers. bfloat16 causes silent truncation when the kernel's float32 output state is written back.

References