Skip to content

Commit ffa86d4

Browse files
authored
fix(dispatch): gate Kimi-Linear detection through get_kimi_linear_config (#1085)
Both Kimi-Linear and Bailing-hybrid hf_configs carry a linear_attn_config attribute, so the prior 'has linear_attn_config?' check matched Bailing too. attn_backend_wrapper then routed Bailing into KDAAttnBackend instead of LightningAttnBackend, which crashed at the first forward with: AttributeError: 'RadixLightningAttention' object has no attribute 'q_conv1d' (KDAAttnBackend reads layer.q_conv1d.weight, which only exists on the KDA attention module, not on Lightning's RadixLightningAttention.) Add a top-level get_kimi_linear_config() factory in configs/kimi_linear.py that mirrors the existing configs/bailing_hybrid.py:get_bailing_hybrid_config helper (model_type guard + architectures fallback). Then make the ModelRunnerKVCacheMixin.kimi_linear_config property dispatch through that helper, so the two linear-recurrent paths are detected by symmetric module- local helpers instead of magic strings in the mixin.
1 parent 4a738c9 commit ffa86d4

2 files changed

Lines changed: 31 additions & 5 deletions

File tree

python/sgl_jax/srt/configs/kimi_linear.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/kimi_linear.py
22
# (which itself is adapted from vllm's kimi_linear config).
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
37
from transformers.configuration_utils import PretrainedConfig
48

59

@@ -143,3 +147,27 @@ def linear_layer_ids(self):
143147
@property
def full_attention_layer_ids(self):
    """Indices of layers using full attention (i.e. every non-KDA layer)."""
    return [
        layer_id
        for layer_id in range(self.num_hidden_layers)
        if not self.is_kda_layer(layer_id)
    ]
150+
151+
152+
def _is_kimi_linear_config(hf_config: Any) -> bool:
153+
if getattr(hf_config, "model_type", None) == "kimi_linear":
154+
return True
155+
architectures = getattr(hf_config, "architectures", None) or []
156+
return any(str(arch).startswith("KimiLinear") for arch in architectures)
157+
158+
159+
def get_kimi_linear_config(hf_config: Any) -> KimiLinearConfig | None:
    """Coerce *hf_config* to a ``KimiLinearConfig`` when it is Kimi-Linear.

    Returns ``None`` when the config is not detected as Kimi-Linear or when it
    carries no ``linear_attn_config``. Mirrors
    ``configs.bailing_hybrid.get_bailing_hybrid_config`` so the dispatch layer
    can detect the two linear-recurrent families through symmetric,
    module-local helpers instead of inline model_type string comparisons.
    """
    # Guard clauses: reject non-Kimi configs and configs lacking the linear
    # attention section in one place.
    if not _is_kimi_linear_config(hf_config):
        return None
    if getattr(hf_config, "linear_attn_config", None) is None:
        return None

    # Already the right type — hand it back untouched.
    if isinstance(hf_config, KimiLinearConfig):
        return hf_config

    # Otherwise rebuild from the config's dict representation.
    if hasattr(hf_config, "to_dict"):
        config_kwargs = hf_config.to_dict()
    else:
        config_kwargs = dict(vars(hf_config))
    return KimiLinearConfig(**config_kwargs)

python/sgl_jax/srt/model_executor/model_runner_kv_cache_mixin.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,9 @@ def init_memory_pool(
618618

619619
@property
def kimi_linear_config(self: ModelRunner):
    """Kimi-Linear config for the current model, or None when not Kimi-Linear.

    Dispatches through the module-local detection helper so that Kimi-Linear
    and Bailing-hybrid (which also carries a ``linear_attn_config`` attribute)
    are told apart by their own config modules rather than by attribute
    sniffing here.
    """
    # Local import keeps the configs module out of this mixin's import cycle.
    from sgl_jax.srt.configs.kimi_linear import get_kimi_linear_config

    hf_config = self.model_config.hf_config
    return get_kimi_linear_config(hf_config)
626624

627625
@property
628626
def lightning_config(self: ModelRunner):

0 commit comments

Comments
 (0)