
[RFC] KDA Components Integration #1038

@MokusMokun


Background

This RFC covers the KDA (Kimi Delta Attention) inference components for sglang-jax, enabling serving of hybrid linear-attention models (e.g., Kimi-Linear, Issue#937). Building on the hybrid memory pool system (PR#1034) and the linear attention dispatch infrastructure (PR#1041), this PR implements the KDA-specific kernel, backend, and supporting layers needed for end-to-end inference.

KDA is a delta-rule linear attention mechanism with per-element gating and short causal convolutions on Q/K/V. In hybrid architectures it replaces standard softmax attention on designated layers. During decode, each request carries a fixed-size recurrent state: O(1) in sequence length, instead of a KV cache that grows with every token. Model quality is preserved.
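As a rough illustration of the mechanism, the NumPy sketch below performs one recurrent step for a single request, in the decay -> delta update -> output order described for the naive kernel. It is not the sglang-jax kernel. `kda_step` is a hypothetical name, and the query scale factor is omitted. The shapes are assumptions: S [H, K, V], q/k [H, K], v [H, V], g [H, K] (log-space decay), beta [H].

```python
import numpy as np

def kda_step(S, q, k, v, g, beta):
    """One naive KDA recurrent step for a single request (illustrative sketch).

    S: [H, K, V] state, q/k: [H, K], v: [H, V],
    g: [H, K] log-space decay (<= 0), beta: [H] write strength.
    """
    S = S * np.exp(g)[:, :, None]                            # 1) per-element decay along K
    v_pred = np.einsum("hk,hkv->hv", k, S)                   # value the state currently recalls for k
    S = S + np.einsum("h,hk,hv->hkv", beta, k, v - v_pred)   # 2) delta-rule correction
    o = np.einsum("hk,hkv->hv", q, S)                        # 3) read-out with the query
    return S, o

H, K, V = 2, 4, 3
rng = np.random.default_rng(0)
S = np.zeros((H, K, V))
for _ in range(32):                                          # state size does not depend on token count
    q = rng.standard_normal((H, K))
    k = rng.standard_normal((H, K))
    k /= np.linalg.norm(k, axis=-1, keepdims=True)           # unit-norm keys, as after L2 normalization
    v = rng.standard_normal((H, V))
    g = -rng.uniform(0.0, 0.1, size=(H, K))
    S, o = kda_step(S, q, k, v, g, beta=np.full(H, 0.5))

# With beta = 1 and a unit-norm key, one more step makes the state recall exactly v for k.
S1, _ = kda_step(S, q, k, v, g, beta=np.ones(H))
recalled = np.einsum("hk,hkv->hv", k, S1)
```

The state stays [H, K, V] however many tokens have been processed, which is where the O(1) decode cost comes from. The last check shows the "delta" part of the rule: the update corrects the state toward v rather than simply adding to it.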

Prerequisites

Scope

In scope

  1. KDA Pallas kernel (srt/kernels/kda/kda.py), adapted from https://github.com/primatrix/pallas-kernel

    • Chunked forward pass for variable-length sequences (varlen)
    • Four-stage pipeline: gate cumsum -> intra-chunk delta-rule solve -> inter-chunk state propagation -> output computation
    • Varlen alignment helpers for chunk-size padding
  2. Naive recurrent kernel (srt/kernels/kda/naive.py)

    • Step-by-step reference implementation: decay -> delta update -> output
    • Used for decode path and as test baseline
  3. Short causal convolution (srt/layers/attention/linear/short_convolution.py)

    • Stateless depthwise conv1d with per-sequence cache
    • EXTEND path: variable-length packed prefill with cu_seqlens
    • DECODE path: single-token step with [B, D, K-1] rolling cache
  4. KDAAttnBackend (srt/layers/attention/linear/kda_backend.py)

    • Extends LinearRecurrentAttnBackend (from PR#1041)
    • EXTEND: Pallas chunked kernel via shard_map
    • DECODE: naive recurrent kernel (Pallas decode TBD)
    • Conv state packing/unpacking (Q/K/V concatenated along channel dim)
    • State management: get/set through RecurrentStatePool interface
    • State clearing for new requests via has_initial_state (from PR#1034, feat(mem_cache): Hybrid Memory Pool system)
  5. DP adaptation — details in Design Details
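The conv state packing in item 4 can be pictured with the NumPy sketch below. A single conv buffer stores the Q, K and V caches side by side along the channel dim, so unpacking is a split at fixed channel offsets. Several details are assumptions for illustration: the function names, the Q, K, V ordering, and the channel sizes (D_q = D_k = H*K, D_v = H*V).

```python
import numpy as np

def pack_conv_state(q_state, k_state, v_state):
    """Concatenate per-request Q/K/V conv caches along the channel dim.

    Each input: [B, D_x, K_conv - 1] -> output: [B, D_q + D_k + D_v, K_conv - 1].
    """
    return np.concatenate([q_state, k_state, v_state], axis=1)

def unpack_conv_state(packed, d_q, d_k):
    """Split a packed [B, proj_size, K_conv - 1] cache back into its Q, K, V parts."""
    q_state, k_state, v_state = np.split(packed, [d_q, d_q + d_k], axis=1)
    return q_state, k_state, v_state

B, H, K, V, K_conv = 3, 2, 4, 5, 4
rng = np.random.default_rng(0)
q_s = rng.standard_normal((B, H * K, K_conv - 1))
k_s = rng.standard_normal((B, H * K, K_conv - 1))
v_s = rng.standard_normal((B, H * V, K_conv - 1))
packed = pack_conv_state(q_s, k_s, v_s)          # what one conv buffer row holds per request
q_r, k_r, v_r = unpack_conv_state(packed, H * K, H * K)
```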

Out of scope

  • Kimi-Linear model skeleton (modeling layer, weight loading) — separate PR
  • Pallas decode kernel — future optimization
  • Kernel performance — future optimization
  • HybridLinearAttnBackend, LinearRecurrentAttnBackend, RadixLinearAttention — already merged in PR#1041

Code Changes

New files

| File | Description |
|---|---|
| srt/kernels/kda/__init__.py | Exports chunk_kda, naive_recurrent_kda |
| srt/kernels/kda/kda.py | Chunked Pallas kernel: gate cumsum, intra-chunk solve, inter-chunk state propagation, output computation |
| srt/kernels/kda/naive.py | Naive step-by-step recurrent reference (test baseline + decode path) |
| srt/layers/attention/linear/__init__.py | Exports KDAAttnBackend |
| srt/layers/attention/linear/kda_backend.py | KDAAttnBackend: conv -> L2 norm -> recurrent kernel dispatch, state management; l2_normalize utility |
| srt/layers/attention/linear/short_convolution.py | Stateless depthwise causal conv1d with per-sequence cache |
| test/test_short_conv.py | Conv correctness tests (EXTEND/DECODE against nnx.Conv baseline, fp32 + bf16) |
| test/test_kda_attention.py | Backend-level tests: KDAAttnBackend vs manual baseline (TP only, no DP) |
| test/test_kda_attention_dp.py | Backend-level tests with DP sharding |
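kda_backend.py also hosts the l2_normalize utility applied to q and k after the convolution. A minimal NumPy version might look like the sketch below. The eps value and where it is added are assumptions, not taken from the PR.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-6):
    """Scale vectors along `axis` (the head dim) to unit L2 norm.

    eps keeps all-zero vectors (e.g. padding tokens) at zero instead of producing NaN.
    """
    sq_norm = np.sum(x * x, axis=axis, keepdims=True)
    return x * (1.0 / np.sqrt(sq_norm + eps))

rng = np.random.default_rng(0)
q = rng.standard_normal((5, 2, 128))              # [T, H, K]
q[0] = 0.0                                        # an all-zero (padding) token
q_n = l2_normalize(q)
norms = np.linalg.norm(q_n, axis=-1)
```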

Design Details

Interface alignment with Hybrid Memory Pool (PR#1034) and Linear Attention Dispatch (PR#1041)

| Interface | Usage in this PR |
|---|---|
| RecurrentStatePool.get_linear_recurrent_layer_cache(layer_id) -> (jax.Array, list[jax.Array]) | KDAAttnBackend.get_state() unwraps the length-1 conv list |
| LinearRecurrentAttnBackendMetadata (from PR#1041) | Contains cu_q_lens, recurrent_indices, has_initial_state with P("data") sharding |
| LinearRecurrentAttnBackend.get_forward_metadata() (from PR#1041) | Computes per-rank cumsum for cu_q_lens, transfers metadata to device |
| MemoryPools.replace_all(dict) | Caller passes {"token_to_kv_pool": ..., "recurrent_state_pool": (recurrent_bufs, conv_bufs)} |
| RadixLinearAttention (from PR#1041) | Holds conv1d weight references, A_log, dt_bias, scale; calls forward_batch.attn_backend() |

State management

has_initial_state is a per-request bool array ([total_bs], True = continuing request, False = new request). PR#1034 stores it on ModelWorkerBatch (CPU); PR#1041's LinearRecurrentAttnBackend.get_forward_metadata() transfers it to LinearRecurrentAttnBackendMetadata with P("data") sharding so it's available on device.

kda_backend.py: KDAAttnBackend.get_state() uses it to zero stale recurrent state for new requests:

```python
ssm = recurrent_buffer[recurrent_indices]                           # gather: [total_bs, H, K, V]
ssm = jnp.where(has_initial_state[:, None, None, None], ssm, 0.0)  # zero state for new requests
```
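The NumPy sketch below mimics the full get/set round trip for one layer:

1. Gather the request slots.
2. Zero the slots of new requests, which may still hold a previous request's state.
3. Scatter the updated states back.

get_state and set_state here are hypothetical free functions, not the KDAAttnBackend methods. In JAX the scatter would be a functional .at[...].set(...) update.

```python
import numpy as np

def get_state(recurrent_buffer, recurrent_indices, has_initial_state):
    """Gather [B, H, K, V] states for this batch, zeroing slots of new requests."""
    ssm = recurrent_buffer[recurrent_indices]
    return np.where(has_initial_state[:, None, None, None], ssm, 0.0)

def set_state(recurrent_buffer, recurrent_indices, new_ssm):
    """Write the kernel's final states back into their pool slots (returns a new buffer)."""
    out = recurrent_buffer.copy()
    out[recurrent_indices] = new_ssm
    return out

total_slots, H, K, V = 6, 1, 2, 2
# Fill every slot with its own index so stale contents are easy to spot.
pool = np.broadcast_to(np.arange(total_slots, dtype=np.float32)[:, None, None, None],
                       (total_slots, H, K, V)).copy()
indices = np.array([3, 1])                   # slot 3: continuing request, slot 1: new request
has_initial_state = np.array([True, False])
ssm = get_state(pool, indices, has_initial_state)
pool = set_state(pool, indices, ssm + 10.0)  # pretend the kernel advanced both states
```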

DP adaptation

kda_backend.py: both forward paths (_forward_extend, _forward_decode) wrap kernel calls in shard_map for DP support. naive.py: naive_recurrent_kda itself is not modified — shard_map ensures each rank receives per-shard local tensors, so the kernel's existing P(None, "tensor", ...) annotation remains correct.

Metadata computation (provided by PR#1041)

LinearRecurrentAttnBackend.get_forward_metadata() (from PR#1041) computes per-rank cumsum for cu_q_lens and transfers all metadata (cu_q_lens, recurrent_indices, has_initial_state) to device with P("data") sharding. This PR's KDAAttnBackend inherits this behavior.
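The per-rank cumsum can be illustrated with the NumPy sketch below. It lays out cu_seqlens as [dp * (N_per_dp + 1)], so each P("data") shard holds offsets that start at 0 for that rank's own tokens. build_per_rank_cu_seqlens is a hypothetical helper. Padding unused request slots with zero-length sequences (repeating the last offset) is an assumption about how uneven ranks are handled.

```python
import numpy as np

def build_per_rank_cu_seqlens(seq_lens_per_rank, n_per_dp):
    """seq_lens_per_rank: one list of request lengths per DP rank -> flat int32 array."""
    rows = []
    for lens in seq_lens_per_rank:
        cu = np.zeros(n_per_dp + 1, dtype=np.int32)
        cu[1:len(lens) + 1] = np.cumsum(np.asarray(lens, dtype=np.int32))  # local offsets restart at 0
        cu[len(lens) + 1:] = cu[len(lens)]   # empty request slots become zero-length sequences
        rows.append(cu)
    return np.concatenate(rows)             # [dp * (n_per_dp + 1)]

# seq_lens_dict of test_extend_dp4_tp1: {0: [128], 1: [64], 2: [32, 64], 3: [128]}
cu_seqlens = build_per_rank_cu_seqlens([[128], [64], [32, 64], [128]], n_per_dp=2)
per_shard = cu_seqlens.reshape(4, 3)        # what each data shard sees: [N_per_dp + 1]
```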

kda_backend.py: _forward_* — shard_map shape specs

Tensors are reshaped to [dp, ...] before shard_map. Inside shard_map, each shard is handled as a 3D per-rank tensor (matching FA's 3D-in pattern), and a B=1 batch dim is unsqueezed before the kernel call.

| Tensor | Global shape (as passed to shard_map) | P-spec | Per-shard shape |
|---|---|---|---|
| q, k, v, g | [dp, T_per_dp, H, K] | P("data", None, "tensor", None) | [T_per_dp, H_local, K] |
| beta | [dp, T_per_dp, H] | P("data", None, "tensor") | [T_per_dp, H_local] |
| initial_state | [dp, N_per_dp, H, K, V] | P("data", None, "tensor", None, None) | [N_per_dp, H_local, K, V] |
| cu_seqlens | [dp*(N_per_dp+1)] | P("data") | [N_per_dp+1] |
| A_log | [H] | P("tensor") | [H_local] |
| dt_bias | [H*K] | P("tensor") | [H_local*K] |
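The table can be checked with a small NumPy simulation of what one device receives. The sketch does not call shard_map. It slices a [dp, T_per_dp, H, K] array the way P("data", None, "tensor", None) would, then adds the B=1 dim. The sizes (dp=2, tp=2, toy dims) are illustrative, and a head-major flattening of dt_bias is an assumption.

```python
import numpy as np

dp, tp = 2, 2
T_per_dp, H, K = 8, 4, 16
H_local = H // tp

# Packed tokens are already grouped per DP rank (padded to T_per_dp), so the reshape is free.
q_flat = np.arange(dp * T_per_dp * H * K, dtype=np.float32).reshape(dp * T_per_dp, H, K)
q = q_flat.reshape(dp, T_per_dp, H, K)

def local_block(x, data_rank, tensor_rank):
    """Slice that P("data", None, "tensor", None) assigns to device (data_rank, tensor_rank)."""
    heads = slice(tensor_rank * H_local, (tensor_rank + 1) * H_local)
    return x[data_rank, :, heads, :]          # [T_per_dp, H_local, K]

blk = local_block(q, data_rank=1, tensor_rank=0)
kernel_in = blk[None]                          # B=1 unsqueeze: [1, T_per_dp, H_local, K]

# dt_bias [H*K] under P("tensor"): a contiguous 1/tp slice, i.e. whole heads if head-major.
dt_bias = np.arange(H * K, dtype=np.float32)
dt_local = np.split(dt_bias, tp)[0]            # tensor_rank 0: [H_local*K]
```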

Pool gather/scatter

Following FA (the FlashAttention backend), pool buffers flow through shard_map as inputs and outputs. Buffers and indices are consistently P("data", ...) sharded, and slot indices are per-rank local (from PR#1034), so gather/scatter on each shard is independent.

| Buffer | Sharding | Shape |
|---|---|---|
| recurrent_buffers | P("data", "tensor", None, None) | [total_slots, H, K, V] |
| conv_buffers | P("data", "tensor", None) | [total_slots, proj_size, K_conv-1] |
| recurrent_indices | P("data") | [total_bs], per-rank local |
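The NumPy simulation below shows why per-rank local indices make gather/scatter shard-independent:

- The pool's slot axis is split into dp contiguous blocks.
- Each rank scatters into its own block using local indices.
- Reassembling the blocks lands every write in the expected global row (local slot 3 on rank 1 is global row 7).

The simulation assumes an even slot split (total_slots divisible by dp) and uses toy sizes.

```python
import numpy as np

dp, slots_per_rank, H, K, V = 2, 4, 1, 2, 2
pool = np.zeros((dp * slots_per_rank, H, K, V), dtype=np.float32)  # global [total_slots, H, K, V]

# Per-rank local slot indices and the new states each rank wants to write.
local_indices = [np.array([0, 3]), np.array([3])]
new_states = [np.stack([np.full((H, K, V), 1.0), np.full((H, K, V), 2.0)]),
              np.full((1, H, K, V), 3.0)]

def scatter_shard(pool_shard, idx, states):
    """Per-shard scatter: only touches this rank's [slots_per_rank, ...] block."""
    out = pool_shard.copy()
    out[idx] = states
    return out

shards = pool.reshape(dp, slots_per_rank, H, K, V)   # what P("data", ...) hands to each rank
pool = np.concatenate([scatter_shard(shards[r], local_indices[r], new_states[r])
                       for r in range(dp)], axis=0)
# Per-shard gather with the same local indices reads the states back.
gathered = [pool.reshape(dp, slots_per_rank, H, K, V)[r][local_indices[r]] for r in range(dp)]
```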

State clearing under DP

has_initial_state is P("data") sharded and independent per request, so there is no cross-rank dependency. The clearing code is the same as in State management.

Baseline for KDA Attention

The baseline is a manually assembled pipeline, independent of backend internals. Multi-request EXTEND uses a per-sequence loop so that the backend's varlen packing logic is actually under test.

```python
import jax
import jax.numpy as jnp

# EXTEND baseline (multi-request): one independent pipeline per packed sequence.
outputs, final_states = [], []
for i in range(N):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    T_i = end - start
    # Short causal conv per sequence (nnx.Conv), then split the channel dim into heads.
    q_i = _baseline_extend(q_packed[start:end], ...).reshape(T_i, H, K)
    k_i = _baseline_extend(k_packed[start:end], ...).reshape(T_i, H, K)
    v_i = _baseline_extend(v_packed[start:end], ...).reshape(T_i, H, V)
    q_i, k_i = l2_normalize(q_i), l2_normalize(k_i)
    # Gate activation: A_log is per head [H], dt_bias is flattened [H*K], g is [T, H, K].
    g_act_i = -jnp.exp(A_log)[:, None] * jax.nn.softplus(g[start:end] + dt_bias.reshape(H, K))
    beta_i = beta[start:end]                                    # [T_i, H]
    o_i, s_i = naive_recurrent_kda(
        q_i[None], k_i[None], v_i[None], g_act_i[None], beta_i[None],
        initial_state=h0[i:i + 1], output_final_state=True,
    )
    outputs.append(o_i[0])
    final_states.append(s_i)
o_ref = jnp.concatenate(outputs, axis=0)           # [T, H, V]
state_ref = jnp.concatenate(final_states, axis=0)  # [N, H, K, V]

# DECODE baseline (direct batched call): one token per request.
q_conv = _baseline_decode(q_step, histories, nnx_conv)  # [B, H*K]
k_conv = _baseline_decode(k_step, histories, nnx_conv)  # [B, H*K]
v_conv = _baseline_decode(v_step, histories, nnx_conv)  # [B, H*V]
q = l2_normalize(q_conv.reshape(B, H, K))
k = l2_normalize(k_conv.reshape(B, H, K))
v = v_conv.reshape(B, H, V)
g_act = -jnp.exp(A_log)[:, None] * jax.nn.softplus(g + dt_bias.reshape(H, K))  # g: [B, H, K]
o_ref, state_ref = naive_recurrent_kda(
    q[:, None], k[:, None], v[:, None], g_act[:, None], beta[:, None],
    initial_state=h0, output_final_state=True,
)
```

This design ensures:

  • Conv correctness: nnx.Conv baseline (already validated by test_short_conv.py)
  • Varlen packing: per-seq loop ensures the backend's packed varlen handling is tested end-to-end
  • State isolation: per-seq initial_state=h0[i:i+1] matches how the pool feeds per-request state
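Both baselines apply the same gate activation, g_act = -exp(A_log) * softplus(g + dt_bias). The NumPy sketch below spells out the broadcasting. It assumes g is [T, H, K], A_log is [H], and dt_bias is the flattened [H*K] vector from the shard-spec table. Because g_act < 0, exp(g_act) lies in (0, 1) and acts as a per-channel decay. kda_gate is a hypothetical name.

```python
import numpy as np

def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return np.logaddexp(0.0, x)

def kda_gate(g, A_log, dt_bias):
    """g: [T, H, K] raw gate, A_log: [H], dt_bias: [H*K] -> log-space decay [T, H, K]."""
    H, K = g.shape[-2:]
    return -np.exp(A_log)[:, None] * softplus(g + dt_bias.reshape(H, K))

T, H, K = 3, 2, 4
rng = np.random.default_rng(0)
g_act = kda_gate(rng.standard_normal((T, H, K)), rng.standard_normal(H), rng.standard_normal(H * K))
zero_case = kda_gate(np.zeros((1, 1, 1)), np.zeros(1), np.zeros(1))  # -exp(0) * softplus(0) = -log 2
```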

Test Plan

Test structure mirrors test_flashattention.py / test_flashattention_dp.py.

1. Short convolution — test_short_conv.py (24 tests, done)

Baseline: nnx.Conv per-sequence with injected history. fp32: rtol=1e-4, atol=1e-4; bf16: rtol=1e-2, atol=1e-2.
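The property these conv tests rely on can be sketched in NumPy: a prefill over the whole sequence must match a shorter prefill followed by DECODE steps on the rolling cache. This is not the nnx.Conv baseline itself. The weight layout [D, K_conv], the cache layout [D, K_conv - 1] (one request of the [B, D, K-1] cache) and the function names are assumptions, and bias/activation are left out.

```python
import numpy as np

def causal_conv_extend(x, w, history):
    """x: [T, D] one sequence, w: [D, K_conv] (K_conv >= 2), history: [D, K_conv - 1] prior inputs."""
    K_conv = w.shape[1]
    x_ext = np.concatenate([history.T, x], axis=0)                # [K_conv - 1 + T, D]
    out = np.stack([(x_ext[t:t + K_conv] * w.T).sum(0) for t in range(x.shape[0])])  # [T, D]
    new_history = x_ext[-(K_conv - 1):].T                          # last K_conv - 1 inputs
    return out, new_history

def causal_conv_decode(x_t, w, cache):
    """x_t: [D] single token, cache: [D, K_conv - 1] rolling window of previous inputs."""
    window = np.concatenate([cache, x_t[:, None]], axis=1)        # [D, K_conv]
    return (window * w).sum(1), window[:, 1:]                      # output [D], shifted cache

rng = np.random.default_rng(0)
D, K_conv, T = 6, 4, 10
w = rng.standard_normal((D, K_conv))
x = rng.standard_normal((T, D))
zeros = np.zeros((D, K_conv - 1))                                  # new request: empty history
full, full_hist = causal_conv_extend(x, w, zeros)
part, cache = causal_conv_extend(x[:T - 3], w, zeros)
steps = []
for t in range(T - 3, T):
    y, cache = causal_conv_decode(x[t], w, cache)
    steps.append(y)
streamed = np.concatenate([part, np.stack(steps)])
```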

| Class | # Tests | Description |
|---|---|---|
| ShortConvolutionTest | 12 (5 EXTEND + 2 DECODE + 4 bias/activation + 1 error) | Conv correctness vs nnx.Conv baseline (fp32) |
| ShortConvolutionBF16Test | 12 (inherited) | Same tests at bf16 precision |

2. Backend tests — test_kda_attention.py + test_kda_attention_dp.py (17 tests)

Common setup:

  • Input dtype: bf16 (matching FA test convention)
  • Test target: KDAAttnBackend.__call__() full pipeline
  • Baseline: per-sequence loop — _baseline_extend / _baseline_decode (nnx.Conv) + L2 norm + naive_recurrent_kda (see Baseline section)
  • Helper create_test_data(mode, lens, ...) builds ForwardBatch, ModelWorkerBatch, and MockRecurrentStatePool with random Q/K/V, gate g, beta, A_log, dt_bias, and conv weights

| File | Tolerance | DP setup | Baseline reference |
|---|---|---|---|
| test_kda_attention.py | rtol=2e-2, atol=1e-2 | mesh = create_device_mesh(ici_parallelism=[1, 1]) (TP=1, DP=1) for smoke tests; ici_parallelism=[1, -1] (TP=4, DP=1) for sharded tests | Direct baseline |
| test_kda_attention_dp.py | rtol=2e-2, atol=2e-2 | set_mesh(tp_size, dp_size) | Per-DP-rank baseline assembled into global padded layout |

test_kda_attention.py — class TestKDAAttention, 8 tests

seq_lens: per-request token count. EXTEND processes all tokens; DECODE always processes 1 token per request (the recurrent state carries the history).

| # | Test | DP | TP | seq_lens | Description |
|---|---|---|---|---|---|
| 1 | test_single_seq_extend_no_shard | 1 | 1 | [128] | Single sequence prefill, no sharding (smoke test) |
| 2 | test_single_seq_extend | 1 | 4 | [128] | Single sequence prefill |
| 3 | test_multi_seq_extend | 1 | 4 | [64, 128, 32] | Multi-sequence packed prefill (also covers the per-request final-state stack via run_test's ssm/conv assertions) |
| 4 | test_extend_short_seqs | 1 | 4 | [3, 7, 1] | Short sequences (< chunk_size=64) |
| 5 | test_extend_with_initial_state | 1 | 4 | [64], h0 from prior 64 tokens | Chunked prefill continuation (has_initial_state=True) |
| 6 | test_single_step_decode | 1 | 4 | [1], B=1 | Single-step decode |
| 7 | test_multi_request_decode | 1 | 4 | [1, 1, 1], B=3 | Batched decode, no cross-request leakage |
| 8 | test_extend_then_decode | 1 | 4 | extend [128] + 4 decode steps | State continuity: EXTEND writes pool, DECODE reads |

test_kda_attention_dp.py — class TestKDAAttentionDP, 9 tests

| # | Test | DP | TP | seq_lens_dict | Description |
|---|---|---|---|---|---|
| 1 | test_extend_dp4_tp1 | 4 | 1 | {0: [128], 1: [64], 2: [32, 64], 3: [128]} | Multi-rank EXTEND |
| 2 | test_extend_sparse_ranks_dp4 | 4 | 1 | {0: [64], 1: [], 2: [128], 3: []} | Empty DP ranks don't crash |
| 3 | test_extend_dp2_tp2 | 2 | 2 | {0: [128, 64], 1: [32]} | Combined DP+TP |
| 4 | test_decode_dp4_tp1 | 4 | 1 | {0: [1, 1], 1: [1], 2: [1, 1], 3: [1]} | Multi-rank DECODE |
| 5 | test_decode_sparse_ranks_dp4 | 4 | 1 | {0: [], 1: [1, 1], 2: [], 3: [1]} | Empty DP ranks in decode |
| 6 | test_decode_unbalanced_dp4 | 4 | 1 | {0: [1], 1: [1, 1, 1], 2: [1], 3: [1, 1, 1, 1]} | Unbalanced batch across DP |
| 7 | test_dp_state_isolation | 4 | 1 | per-rank distinct h0 | No cross-DP state leakage |
| 8 | test_dp_mixed_new_continuing | 4 | 1 | {0: new reqs (h0=0), 1: continuing (h0≠0), 2: mix, 3: mix} | has_initial_state handling across ranks |
| 9 | test_dp_extend_then_decode | 4 | 1 | extend per-rank + 4 decode steps | State continuity under DP: pool scatter→gather |

Summary

| File | Tests | Tolerance | Status |
|---|---|---|---|
| test_short_conv.py | 24 | fp32: 1e-4; bf16: 1e-2 | Done |
| test_kda_attention.py | 8 | rtol=2e-2, atol=1e-2 | To implement |
| test_kda_attention_dp.py | 9 | rtol=2e-2, atol=2e-2 | To implement |
| Total | 41 | | |

Pass criteria: 41/41 tests pass.

Known constraint: RecurrentStatePool must use float32 for recurrent state buffers — bfloat16 causes silent truncation when the kernel's float32 output state is written back.
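The effect is easy to reproduce without JAX. The sketch below emulates a bfloat16 round-trip on float32 values and applies many small updates to a state, writing it back to "bf16" each time as a bf16 buffer would. The emulation rounds to nearest even on the discarded 16 bits and omits NaN/Inf handling. to_bf16 is a hypothetical helper, and the update size is arbitrary.

```python
import numpy as np

def to_bf16(x):
    """Round float32 values to the nearest bfloat16 value, returned as float32."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & np.uint32(1)) + np.uint32(0x7FFF)   # ties to even
    return ((bits + rounding_bias) & np.uint32(0xFFFF0000)).view(np.float32)

update = np.float32(1e-3)
state_fp32 = np.ones(4, dtype=np.float32)
state_bf16 = to_bf16(np.ones(4, dtype=np.float32))
for _ in range(100):
    state_fp32 = state_fp32 + update             # float32 pool keeps the accumulated updates
    state_bf16 = to_bf16(state_bf16 + update)    # each write-back rounds the state to bf16
```

Near 1.0 the bf16 spacing is 2^-7, so every update smaller than half of that is rounded away on write-back. The bf16 state never moves, and nothing reports an error.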
