feat(kimi-linear): port e2e wire-up onto upstream kimi_linear branch (#1072)
Conversation
Summary of Changes (Gemini Code Assist): This pull request integrates Kimi Delta Attention (KDA) support into the upstream kimi_linear branch. The changes focus on enabling hybrid recurrent model support, improving state management to prevent memory pollution, and ensuring the compilation pipeline correctly handles recurrent-state requirements without impacting non-recurrent backends.
Squash of the local kda-e2e branch onto origin/feat/support_kimi_linear_model:

- KDA dummy-slot pollution guard (`set_ssm_state`, `set_conv_state`) — without this, DP runs (tp4dp4) collapse from ~0.66 to ~0.27 OVERALL on mmlu_pro.
- `HybridLinearAttnBackend.attn_backend_wrapper` builds the real KDA sub-backend (the upstream stub returned `full_attn_backend` unchanged → server crash).
- `ModelRunner.linear_recurrent_config` detects `KimiLinearConfig` by `hf_config`'s `linear_attn_config` attribute (the upstream property was a stub returning `None`).
- compilation_manager dummy batch fills `recurrent_indices`/`has_initial_state` only when `has_recurrent_state` is set, so non-recurrent backends are unaffected (`CompilationManager` grows a `has_recurrent_state` flag, plumbed from tp_worker via `model_runner.linear_recurrent_config`).
- `gated_rmsnorm` helper (used by KimiLinear). `HybridLinearAttnBackend.__call__` kept upstream-clean (no `pool` kwarg aliasing).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
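The dummy-slot guard can be illustrated with a minimal sketch. This is a hypothetical standalone function — the real `set_ssm_state`/`set_conv_state` live inside scatter shard_maps and take different arguments — but it shows the core idea: rows whose index is 0 are padding, so their scatter must be turned into a no-op rather than allowed to overwrite the per-rank dummy slot.

```python
import jax.numpy as jnp

DUMMY_SLOT = 0  # per-rank slot 0 is reserved; padding rows carry idx == 0

def set_ssm_state(pool, idx, new_state):
    """Scatter `new_state` into `pool` at rows `idx`, skipping the dummy slot.

    Without the guard, padding rows (idx == 0) would write their garbage
    values into slot 0, which later leaks back as initial state.
    """
    keep = idx != DUMMY_SLOT
    # Redirect padding rows to the dummy slot, but have them write back the
    # dummy slot's *current* value, making their scatter a no-op.
    safe_idx = jnp.where(keep, idx, DUMMY_SLOT)
    safe_val = jnp.where(keep[:, None], new_state, pool[DUMMY_SLOT])
    return pool.at[safe_idx].set(safe_val)
```

The same masking applies to the conv state; only the trailing-axis broadcast of `keep` changes with the state's rank.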
E2E test results: UTC 2026-05-13 12:04:28 → 2026-05-13 20:10:03, wall 8h02m45s (eval first prefill at 12:07:18; pod TZ = UTC).

Reproducibility

Environment
Per-host bootstrap

```shell
pip install -U --pre jax jaxlib libtpu \
  -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install uv
uv venv /tmp/venv && source /tmp/venv/bin/activate
uv pip install pip
git clone https://github.com/sgl-project/sglang-jax.git
cd sglang-jax && git checkout <this PR's HEAD>
uv pip install -e python[all]
uv pip install evalscope==0.17.1
```

Server launch (run on each of the 4 hosts;
Merged 13e6792 into sgl-project:feat/support_kimi_linear_model
* add KimiLinearForCausalLM. Co-authored-by: zhengke.zhou.dev@gmail.com
* feat(kimi-linear): port e2e wire-up onto upstream kimi_linear branch (#1072) — squash of the local kda-e2e branch onto origin/feat/support_kimi_linear_model (full description above).
* test(kda): wrap raw KDA backend to accept `pool=` kwarg. `KDAAttnBackend.__call__` takes `recurrent_state_pool=`, while `RadixLinearAttention.__call__` passes `pool=` (`HybridLinearAttnBackend`'s calling convention). Production routes through the wrapper, which translates `pool` → `recurrent_state_pool`; the unit tests bypass it by assigning a raw `KDAAttnBackend` as `forward_batch.attn_backend`, so the kwarg falls into `**kwargs`, `recurrent_state_pool` is unbound, and a TypeError results. Replicate the translation in a test-only shim.
* test(kda): hoist `KDAAttnBackendForTest` shim to test_utils.py. Both KDA test files were carrying an identical copy of the `pool=` → `recurrent_state_pool=` translation shim added in 466afff. Move it to test_utils.py and import it from both, dropping the local underscore prefix since it is now a shared helper.

Co-authored-by: Mirope Yuhao Hu <miropehu@gmail.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
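The kwarg-translation shim described above can be sketched as a thin wrapper. This is a hypothetical standalone version (the real `KDAAttnBackendForTest` in test_utils.py wraps the actual backend class); it only shows the `pool=` → `recurrent_state_pool=` renaming that bridges the two calling conventions.

```python
class KDAAttnBackendForTest:
    """Test-only shim: accept the hybrid wrapper's `pool=` kwarg and
    forward it to a raw backend that expects `recurrent_state_pool=`."""

    def __init__(self, raw_backend):
        self._raw = raw_backend

    def __call__(self, *args, pool=None, **kwargs):
        if pool is not None:
            # Translate the hybrid calling convention to the raw one.
            kwargs["recurrent_state_pool"] = pool
        return self._raw(*args, **kwargs)
```

Without the shim, `pool=` lands in the raw backend's `**kwargs` and the required `recurrent_state_pool` parameter stays unbound, which is exactly the TypeError the commit describes.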
Motivation
Wire up Kimi-Linear (KDA + MLA hybrid) end-to-end so the server can actually launch and run on top of
feat/support_kimi_linear_model. The base branch ships the KDA kernels (#1051) and the KimiLinear model port, but several glue points are still stubbed or unwired. Without these fixes the server either crashes at startup or, on DP runs, silently produces garbage tokens.

Modifications
- `HybridLinearAttnBackend.attn_backend_wrapper`: build the real KDA sub-backend and route by `layer_id`. The upstream stub returned `full_attn_backend` unchanged, so KDA layers were never dispatched → server crash.
- `ModelRunner.linear_recurrent_config`: detect `KimiLinearConfig` by the presence of `linear_attn_config` on `hf_config`. The upstream property unconditionally returned `None`, so hybrid pool init was never triggered.
- `KDAAttnBackend.set_ssm_state` / `set_conv_state`: suppress writes at `idx == 0` (per-rank dummy slot). Padding rows carry `idx=0`; without the guard their scattered values pollute the dummy slot and leak back as initial state, collapsing tp4dp4 mmlu_pro OVERALL from ~0.66 to ~0.27.
- `CompilationManager`: gains a `has_recurrent_state` flag (plumbed from `tp_worker` via `model_runner.linear_recurrent_config`); fill `recurrent_indices`/`has_initial_state` in the dummy precompile batch only when the flag is set, so non-recurrent backends are unaffected.
- `KDAAttnBackend.__call__`: rename the `pool` kwarg to `recurrent_state_pool` so it matches the call site in `HybridLinearAttnBackend.__call__` (`recurrent_state_pool=pool`).
- `layers/attention/fla/gated_rmsnorm.py`: gated RMSNorm helper used by KimiLinear.
- `HybridLinearAttnBackend.__call__` is kept upstream-clean (no `pool` kwarg aliasing).

Accuracy Tests

Validated on the parent branch (feat/kda-e2e-kimi-linear, same logical changes) via full mmlu_pro on Kimi-Linear-48B-A3B-Instruct, T=1.0, bs=32. Pre-fix tp4dp4 collapses to ~0.27 — the dummy-slot guard is the load-bearing fix.
A full re-run of mmlu_pro on this rebased branch is pending; smoke + limit=20 sanity to follow.
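The `gated_rmsnorm` helper listed in the modifications can be sketched as follows. This is a minimal, hypothetical formulation assuming the common SiLU-gated form y = RMSNorm(x) · weight · silu(gate); the actual helper in `layers/attention/fla/gated_rmsnorm.py` may fuse the ops or differ in detail.

```python
import jax
import jax.numpy as jnp

def gated_rmsnorm(x, gate, weight, eps=1e-6):
    """Sketch: RMS-normalize x over the last axis, scale by a learned
    weight, then gate elementwise with silu(gate) = gate * sigmoid(gate)."""
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed = x * jax.lax.rsqrt(var + eps)
    return normed * weight * (gate * jax.nn.sigmoid(gate))
```

The gate path is what distinguishes this from a plain RMSNorm: a zero gate fully suppresses the normalized output, since silu(0) = 0.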
Benchmarking and Profiling
No throughput delta expected vs base branch — same kernels, same forward path; only wire-up + a single-int dummy guard inside scatter shard_maps.
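The `has_recurrent_state` gating that keeps non-recurrent backends untouched can be sketched as below. The class and function names here are hypothetical stand-ins for the real `CompilationManager` dummy-batch construction; the point is only that the recurrent fields are materialized conditionally, so batch shapes seen by non-recurrent backends never change.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp

@dataclass
class DummyBatch:
    # Hypothetical minimal stand-in for the precompile dummy batch.
    batch_size: int
    recurrent_indices: Optional[jnp.ndarray] = None
    has_initial_state: Optional[jnp.ndarray] = None

def build_dummy_batch(batch_size, has_recurrent_state):
    """Fill recurrent fields only when the model actually carries
    recurrent state (flag plumbed from the model runner's config)."""
    batch = DummyBatch(batch_size)
    if has_recurrent_state:
        batch.recurrent_indices = jnp.zeros((batch_size,), dtype=jnp.int32)
        batch.has_initial_state = jnp.zeros((batch_size,), dtype=bool)
    return batch
```

With the flag off, the dummy batch is byte-identical to upstream's, which is why precompilation for non-recurrent backends is unaffected.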
Checklist