Commit 2c5d860

aolemila, MokusMokun, and claude authored
feat: support Kimi Linear model (#1047)
* add KimiLinearForCausalLM.

  Co-authored-by: zhengke.zhou.dev@gmail.com

* feat(kimi-linear): port e2e wire-up onto upstream kimi_linear branch (#1072)

  Squash of the local kda-e2e branch onto origin/feat/support_kimi_linear_model:

  - KDA dummy-slot pollution guard (set_ssm_state, set_conv_state) — without this,
    DP runs (tp4dp4) collapse from ~0.66 to ~0.27 OVERALL on mmlu_pro.
  - HybridLinearAttnBackend.attn_backend_wrapper builds a real KDA sub-backend
    (the upstream stub returned full_attn_backend unchanged → server crash).
  - ModelRunner.linear_recurrent_config detects KimiLinearConfig by hf_config's
    linear_attn_config attribute (the upstream property was a stub returning None).
  - compilation_manager dummy batch fills recurrent_indices/has_initial_state only
    when has_recurrent_state is set, so non-recurrent backends are unaffected
    (CompilationManager grows a has_recurrent_state flag, plumbed from tp_worker
    via model_runner.linear_recurrent_config).
  - gated_rmsnorm helper (used by KimiLinear).

  HybridLinearAttnBackend.__call__ kept upstream-clean (no pool kwarg aliasing).

  Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>

* test(kda): wrap raw KDA backend to accept `pool=` kwarg

  KDAAttnBackend.__call__ takes `recurrent_state_pool=`, while
  RadixLinearAttention.__call__ passes `pool=` (HybridLinearAttnBackend's calling
  convention). Production routes through that wrapper, which translates
  pool → recurrent_state_pool; the unit tests bypass it by assigning a raw
  KDAAttnBackend as `forward_batch.attn_backend`, so the kwarg falls into
  **kwargs and `recurrent_state_pool` is unbound → TypeError. Replicate the
  translation in a test-only shim.

  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(kda): hoist KDAAttnBackendForTest shim to test_utils.py

  Both KDA test files were carrying an identical copy of the `pool=` →
  `recurrent_state_pool=` translation shim added in 466afff. Move it to
  test_utils.py and import from both, dropping the local underscore prefix
  since it is now a shared helper.

  Co-Authored-By: zhengke.zhou.dev@gmail.com

Co-authored-by: Mirope Yuhao Hu <miropehu@gmail.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
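For reference, a minimal sketch of what the shared test shim could look like (the actual test_utils.py helper is not part of this page, and the *args pass-through is an assumption):

# Hypothetical reconstruction of KDAAttnBackendForTest: accept the `pool=`
# kwarg used by HybridLinearAttnBackend's calling convention and forward it
# to the raw backend under the name it expects.
from sgl_jax.srt.layers.attention.linear.kda_backend import KDAAttnBackend


class KDAAttnBackendForTest(KDAAttnBackend):
    def __call__(self, *args, pool=None, **kwargs):
        return super().__call__(*args, recurrent_state_pool=pool, **kwargs)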
1 parent f4d39f2 commit 2c5d860

25 files changed

Lines changed: 1156 additions & 48 deletions
python/sgl_jax/srt/configs/kimi_linear.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/kimi_linear.py
# (which itself is adapted from vllm's kimi_linear config).
from transformers.configuration_utils import PretrainedConfig


class KimiLinearConfig(PretrainedConfig):
    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: int | None = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: int | None = None,
        num_experts_per_token: int | None = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: int | None = None,
        kv_lora_rank: int | None = None,
        qk_nope_head_dim: int | None = None,
        qk_rope_head_dim: int | None = None,
        v_head_dim: int | None = None,
        mla_use_nope: bool | None = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: dict | None = None,
        **kwargs,
    ):
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope
        # moe config
        self.n_routed_experts = self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.moe_router_activation_func = moe_router_activation_func
        assert self.moe_router_activation_func in ("softmax", "sigmoid")
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        if linear_attn_config is not None:
            assert linear_attn_config["kda_layers"] is not None
            assert linear_attn_config["full_attn_layers"] is not None
        self.linear_attn_config = linear_attn_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def is_mla(self):
        return (
            self.q_lora_rank is not None
            or self.kv_lora_rank is not None
            or self.qk_nope_head_dim is not None
            or self.qk_rope_head_dim is not None
            or self.v_head_dim is not None
            or self.mla_use_nope is True
        )

    @property
    def is_moe(self):
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        return not (
            self.linear_attn_config is None
            or (
                isinstance(self.linear_attn_config, dict)
                and self.linear_attn_config["kda_layers"] is not None
                and len(self.linear_attn_config["kda_layers"]) == 0
            )
        )

    def is_kda_layer(self, layer_idx: int):
        return (
            self.linear_attn_config is not None
            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
        )

    @property
    def linear_layer_ids(self):
        return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]

    @property
    def full_attention_layer_ids(self):
        return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
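To illustrate the layer-routing properties (the values below are made up for the example, not taken from any real checkpoint): `kda_layers` entries are 1-based, per the `(layer_idx + 1)` check in `is_kda_layer`, while the two layer-id properties return 0-based indices.

from sgl_jax.srt.configs.kimi_linear import KimiLinearConfig

cfg = KimiLinearConfig(
    num_hidden_layers=4,
    linear_attn_config={"kda_layers": [1, 2, 4], "full_attn_layers": [3]},
)
print(cfg.is_linear_attn)            # True
print(cfg.linear_layer_ids)          # [0, 1, 3]
print(cfg.full_attention_layer_ids)  # [2]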

python/sgl_jax/srt/hf_transformers_utils.py

Lines changed: 7 additions & 1 deletion
@@ -18,10 +18,16 @@
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

+from sgl_jax.srt.configs.kimi_linear import KimiLinearConfig
 from sgl_jax.srt.managers.tiktoken_tokenizer import TiktokenTokenizer
 from sgl_jax.srt.utils.common_utils import is_remote_url, lru_cache_frozenset

-_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {}
+_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
+    cls.model_type: cls
+    for cls in [
+        KimiLinearConfig,
+    ]
+}

 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
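Since the registry is built by a comprehension over the class list, supporting another custom config is a one-line addition to that list. For the single entry above, the comprehension is equivalent to the literal:

assert _CONFIG_REGISTRY == {"kimi_linear": KimiLinearConfig}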
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
"""Gated RMS normalization for linear attention layers.

Computes ``RMSNorm(x) * sigmoid(gate)`` — used by KDA (Kimi Delta Attention)
as the output normalization before the final projection.

GPU reference: ``sglang/srt/layers/attention/fla/fused_norm_gate.py``
(``FusedRMSNormGated`` with ``activation="sigmoid"``).
"""

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Dtype


class GatedRMSNorm(nnx.Module):
    """RMSNorm with a multiplicative sigmoid gate.

    Given input ``x`` and ``gate`` of the same shape, computes::

        output = (x / sqrt(mean(x^2) + eps)) * weight * sigmoid(gate)
    """

    def __init__(
        self,
        num_features: int,
        epsilon: float = 1e-6,
        param_dtype: Dtype = jnp.float32,
    ):
        self.weight = nnx.Param(jnp.ones((num_features,), dtype=param_dtype))
        self.epsilon = epsilon

    def __call__(self, x: jax.Array, gate: jax.Array) -> jax.Array:
        orig_dtype = x.dtype
        x_f32 = x.astype(jnp.float32)
        variance = jnp.mean(jnp.square(x_f32), axis=-1, keepdims=True)
        x_norm = x_f32 * jax.lax.rsqrt(variance + self.epsilon)
        x_norm = x_norm * self.weight[...].astype(jnp.float32)
        return (x_norm * jax.nn.sigmoid(gate.astype(jnp.float32))).astype(orig_dtype)
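A quick standalone sanity check of the gating math (assuming GatedRMSNorm is imported from the module above): with an all-ones input the RMS normalization is (almost exactly) the identity, and a zero gate contributes sigmoid(0) = 0.5, so every output entry is 0.5, returned in the input dtype.

import jax.numpy as jnp

norm = GatedRMSNorm(num_features=8)
x = jnp.ones((2, 8), dtype=jnp.bfloat16)    # RMSNorm(ones) == ones
gate = jnp.zeros((2, 8), dtype=jnp.bfloat16)  # sigmoid(0) == 0.5
out = norm(x, gate)
print(out.dtype, out[0, 0])  # bfloat16 0.5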

python/sgl_jax/srt/layers/attention/hybrid_linear_attn_backend.py

Lines changed: 19 additions & 2 deletions
@@ -227,5 +227,22 @@ def attn_backend_wrapper(
     runner: ModelRunner,
     full_attn_backend: AttentionBackend,
 ):
-    """Wrap full_attn_backend in HybridLinearAttnBackend for hybrid models."""
-    return full_attn_backend
+    """Wrap full_attn_backend in HybridLinearAttnBackend for hybrid models.
+
+    For hybrid recurrent models (e.g. Kimi-Linear: KDA + MLA), build the
+    matching linear sub-backend and route by layer_id. For pure full-attn
+    models, return the full_attn_backend unchanged.
+    """
+    cfg = runner.linear_recurrent_config
+    if cfg is None:
+        return full_attn_backend
+
+    # Only supported linear sub-backend today is KDA.
+    from sgl_jax.srt.layers.attention.linear.kda_backend import KDAAttnBackend
+
+    linear_attn_backend = KDAAttnBackend(mesh=runner.mesh)
+    return HybridLinearAttnBackend(
+        full_attn_backend=full_attn_backend,
+        linear_attn_backend=linear_attn_backend,
+        full_attn_layers=cfg.full_attention_layer_ids,
+    )

python/sgl_jax/srt/layers/attention/linear/kda_backend.py

Lines changed: 26 additions & 8 deletions
@@ -42,11 +42,13 @@ def __call__(
         b: jax.Array,
         layer: RadixLinearAttention,
         forward_batch: ForwardBatch,
-        pool: RecurrentStatePool,
+        recurrent_state_pool: RecurrentStatePool,
         **kwargs,
     ) -> jax.Array:
         recurrent_indices = self.forward_metadata.recurrent_indices
-        ssm_states, conv_states = self.get_state(pool, layer.layer_id, recurrent_indices)
+        ssm_states, conv_states = self.get_state(
+            recurrent_state_pool, layer.layer_id, recurrent_indices
+        )
         q_conv_w = layer.q_conv1d.weight.value
         k_conv_w = layer.k_conv1d.weight.value
         v_conv_w = layer.v_conv1d.weight.value
@@ -132,9 +134,11 @@ def __call__(
         else:
             raise NotImplementedError(f"KDA does not support {forward_batch.forward_mode}")

-        new_ssm_full = self.set_ssm_state(pool, layer.layer_id, recurrent_indices, new_recurrent)
+        new_ssm_full = self.set_ssm_state(
+            recurrent_state_pool, layer.layer_id, recurrent_indices, new_recurrent
+        )
         new_conv_full_list = self.set_conv_state(
-            pool, layer.layer_id, recurrent_indices, new_conv_packed
+            recurrent_state_pool, layer.layer_id, recurrent_indices, new_conv_packed
         )
         return output.reshape(output.shape[0], -1), (new_ssm_full, new_conv_full_list)

@@ -174,11 +178,20 @@ def get_state(self, recurrent_state_pool, layer_id, recurrent_indices):
         return ssm, conv

     def set_ssm_state(self, recurrent_state_pool, layer_id, recurrent_indices, new_recurrent):
-        """Scatter per-request ``new_recurrent`` into the FULL pool buffer."""
+        """Scatter per-request ``new_recurrent`` into the FULL pool buffer.
+
+        Suppress writes at idx==0: padding rows carry idx=0 and would otherwise
+        pollute the per-rank dummy slot, leaking garbage back as initial state.
+        """
         full_recurrent, _ = self.get_layer_cache(recurrent_state_pool, layer_id)

+        def _scatter(buf, idx, val):
+            keep_mask = (idx == 0).reshape(-1, 1, 1, 1)
+            safe_val = jnp.where(keep_mask, buf[idx], val)
+            return buf.at[idx].set(safe_val)
+
         return jax.shard_map(
-            lambda buf, idx, val: buf.at[idx].set(val),
+            _scatter,
             mesh=self.mesh,
             in_specs=(
                 P("data", "tensor", None, None),
@@ -190,13 +203,18 @@ def set_ssm_state(self, recurrent_state_pool, layer_id, recurrent_indices, new_recurrent):
         )(full_recurrent, recurrent_indices, new_recurrent)

     def set_conv_state(self, recurrent_state_pool, layer_id, recurrent_indices, new_conv_packed):
-        """Scatter per-request packed conv state into the FULL pool buffer."""
+        """Scatter per-request packed conv state. Same idx==0 guard as set_ssm_state."""
         _, conv_buffer_list = self.get_layer_cache(recurrent_state_pool, layer_id)
         assert len(conv_buffer_list) == 1
         full_conv = conv_buffer_list[0]

+        def _scatter(buf, idx, val):
+            keep_mask = (idx == 0).reshape(-1, 1, 1)
+            safe_val = jnp.where(keep_mask, buf[idx], val)
+            return buf.at[idx].set(safe_val)
+
         new_conv_full = jax.shard_map(
-            lambda buf, idx, val: buf.at[idx].set(val),
+            _scatter,
             mesh=self.mesh,
             in_specs=(
                 P("data", "tensor", None),

python/sgl_jax/srt/managers/tp_worker.py

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ def __init__(
             max_req_len=self.max_req_len,
             vocab_size=self.model_config.vocab_size,
             multimodal=server_args.multimodal,
+            has_recurrent_state=self.model_runner.linear_recurrent_config is not None,
         )

         self.parent_process = psutil.Process().parent()

python/sgl_jax/srt/model_executor/compilation_manager.py

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,7 @@ def __init__(
         max_req_len: int,
         vocab_size: int,
         multimodal: bool = False,
+        has_recurrent_state: bool = False,
     ):
         self.dp_size = dp_size
         self.tp_size = tp_size
@@ -44,6 +45,7 @@
         self.max_padded_num_tokens = max_padded_num_tokens
         self.vocab_size = vocab_size
         self.multimodal = multimodal
+        self.has_recurrent_state = has_recurrent_state
         self.moe_backend = server_args.moe_backend
         self.enable_static_lora = server_args.enable_static_lora

@@ -309,6 +311,12 @@ def _make_dummy_batch(
             per_dp_bs_size=per_dp_bs_size,
             real_bs_per_dp=[bs] * dp_size,
             logits_indices_selector=np.arange(bs, dtype=np.int32),
+            # Hybrid recurrent backends (e.g. KDA) require these per-batch
+            # arrays even at precompile time; slot 0 is RecurrentStatePool's
+            # per-rank dummy slot, safe to point at. Leave None otherwise so
+            # non-recurrent backends are unaffected.
+            recurrent_indices=(np.zeros(bs, dtype=np.int32) if self.has_recurrent_state else None),
+            has_initial_state=(np.zeros(bs, dtype=np.bool_) if self.has_recurrent_state else None),
         )

     # ---- Lazy compilation tracking ----

python/sgl_jax/srt/model_executor/model_runner_kv_cache_mixin.py

Lines changed: 4 additions & 5 deletions
@@ -602,11 +602,10 @@ def init_memory_pool(

     @property
     def linear_recurrent_config(self: ModelRunner):
-        """Return linear recurrent config if the model has linear attention, else None.
-
-        Currently returns None unconditionally — KimiLinearConfig detection
-        will be wired up when the modeling layer lands.
-        """
+        """Return linear recurrent config if the model has linear attention, else None."""
+        hf_cfg = getattr(self.model_config, "hf_config", None)
+        if hf_cfg is not None and getattr(hf_cfg, "linear_attn_config", None) is not None:
+            return hf_cfg
         return None

     def _kv_pool_layer_count(self: ModelRunner):
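Detection is duck-typed on the linear_attn_config attribute rather than an isinstance check, so any hf_config exposing a non-None linear_attn_config (KimiLinearConfig today) opts in. A small illustration:

from sgl_jax.srt.configs.kimi_linear import KimiLinearConfig

# KimiLinearConfig sets the attribute in __init__, so it is detected;
# configs without the attribute fall through and the property returns None.
cfg = KimiLinearConfig(linear_attn_config={"kda_layers": [1], "full_attn_layers": []})
assert getattr(cfg, "linear_attn_config", None) is not None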

python/sgl_jax/srt/models/bailing_moe.py

Lines changed: 1 addition & 1 deletion
@@ -918,7 +918,7 @@ def __call__(
             output = self.logits_processor(hidden_states, self.lm_head, logits_metadata)
         else:
             output = self.logits_processor(hidden_states, self.model.embed_tokens, logits_metadata)
-        return output, layers_kv_fused, True, layers_topk_ids
+        return output, {"token_to_kv_pool": layers_kv_fused}, True, layers_topk_ids


 class BailingMoeForCausalLM(BailingMoEForCausalLM):
