feat(kimi-linear): port e2e wire-up onto upstream kimi_linear branch#1072

Merged
aolemila merged 1 commit into sgl-project:feat/support_kimi_linear_model from MokusMokun:feat/kda-e2e-kimi-linear-v2
May 14, 2026

Conversation

@MokusMokun
Contributor

@MokusMokun MokusMokun commented May 13, 2026

Motivation

Wire up Kimi-Linear (KDA + MLA hybrid) end-to-end so the server can actually launch and run on top of feat/support_kimi_linear_model. The base branch ships KDA kernels (#1051) and the KimiLinear model port, but several glue points are still stubbed or unwired. Without these fixes the server either crashes at startup or, on DP runs, silently produces garbage tokens.

Modifications

  • HybridLinearAttnBackend.attn_backend_wrapper: build the real KDA sub-backend and route by layer_id (a routing sketch follows this list). The upstream stub returned full_attn_backend unchanged, so KDA layers were never dispatched → server crash.
  • ModelRunner.linear_recurrent_config: detect KimiLinearConfig by the presence of linear_attn_config on hf_config. Upstream property unconditionally returned None, so hybrid pool init was never triggered.
  • KDAAttnBackend.set_ssm_state / set_conv_state: suppress writes at idx == 0, the per-rank dummy slot (see the guard sketch after this list). Padding rows carry idx=0; without the guard their scattered values pollute the dummy slot and leak back as initial state, collapsing tp4dp4 mmlu_pro OVERALL from ~0.66 to ~0.27.
  • CompilationManager: gain a has_recurrent_state flag (plumbed from tp_worker via model_runner.linear_recurrent_config); only fill recurrent_indices / has_initial_state in the dummy precompile batch when it's set, so non-recurrent backends are unaffected.
  • KDAAttnBackend.__call__: rename pool kwarg to recurrent_state_pool so it matches the call site in HybridLinearAttnBackend.__call__ (recurrent_state_pool=pool).
  • New: layers/attention/fla/gated_rmsnorm.py — gated RMSNorm helper used by KimiLinear.

HybridLinearAttnBackend.__call__ is kept upstream-clean (no pool kwarg aliasing).
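
To make the two load-bearing changes concrete, here are minimal, hedged sketches. Class and function names, signatures, and the pool layout are illustrative assumptions, not the exact code in this PR.

Per-layer routing in the hybrid wrapper: KDA layers get the KDA sub-backend, all other layers keep full (MLA) attention.

class HybridDispatchSketch:
    # Illustrative only; the real logic lives in HybridLinearAttnBackend.attn_backend_wrapper.
    def __init__(self, full_attn_backend, kda_backend, linear_layer_ids):
        self.full_attn_backend = full_attn_backend
        self.kda_backend = kda_backend
        self.linear_layer_ids = set(linear_layer_ids)  # layer indices that use KDA

    def backend_for_layer(self, layer_id: int):
        if layer_id in self.linear_layer_ids:
            return self.kda_backend
        return self.full_attn_backend

Dummy-slot guard: drop any state write whose destination is the per-rank dummy slot (idx == 0), so padding rows cannot pollute it.

import jax.numpy as jnp

def guarded_state_write(pool, idx, new_state):
    # Hedged sketch; the real guard sits inside set_ssm_state / set_conv_state.
    # Padding rows carry idx == 0; redirecting them to an out-of-bounds index makes
    # the scatter drop those writes instead of corrupting the dummy slot.
    safe_idx = jnp.where(idx == 0, pool.shape[0], idx)
    return pool.at[safe_idx].set(new_state, mode="drop")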

Accuracy Tests

Validated on the parent branch (feat/kda-e2e-kimi-linear, same logical changes) via full mmlu_pro on Kimi-Linear-48B-A3B-Instruct, T=1.0, bs=32:

Config    OVERALL   Wall
tp16dp1   0.6609    8h23m47s
tp4dp4    0.6574    9h12m51s

Pre-fix tp4dp4 collapses to ~0.27 — the dummy-slot guard is the load-bearing fix.

A full re-run of mmlu_pro on this rebased branch is pending; smoke + limit=20 sanity to follow.

Benchmarking and Profiling

No throughput delta expected vs base branch — same kernels, same forward path; only wire-up + a single-int dummy guard inside scatter shard_maps.

Checklist

  • Please use English, otherwise it will be closed.
  • The purpose of the PR, or link existing issues this PR will resolve.
  • The test plan, such as providing test command.
  • (Optional) The necessary documentation update.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Kimi Delta Attention (KDA) support into the upstream kimi_linear branch. The changes focus on enabling hybrid recurrent model support, improving state management to prevent memory pollution, and ensuring the compilation pipeline correctly handles recurrent state requirements without impacting non-recurrent backends.

Highlights

  • KDA State Management: Implemented a dummy-slot pollution guard in set_ssm_state and set_conv_state to prevent padding rows from leaking garbage into the recurrent state.
  • Hybrid Backend Integration: Updated HybridLinearAttnBackend to dynamically build the KDA sub-backend and correctly route layers based on the model configuration.
  • Configuration Detection: Enabled automatic detection of KimiLinearConfig via the model's hf_config attribute, replacing the previous stub implementation.
  • Compilation Manager Updates: Added a has_recurrent_state flag to the compilation manager to ensure dummy batch filling only occurs for relevant recurrent backends.
  • Gated RMSNorm: Introduced a new GatedRMSNorm helper module required for KimiLinear model normalization.
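
For reference, a gated RMSNorm normalizes the hidden state and then scales it by an activated gate. The JAX sketch below is a hedged illustration only; the gate placement and activation in this PR's gated_rmsnorm.py may differ, and the function name is hypothetical.

import jax
import jax.numpy as jnp

def gated_rms_norm(x, gate, weight, eps=1e-6):
    # Illustrative: RMS-normalize x in float32, scale by the learned weight and a
    # SiLU-activated gate, then cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(variance + eps)
    out = normed * weight * jax.nn.silu(gate.astype(jnp.float32))
    return out.astype(x.dtype)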

Squash of the local kda-e2e branch onto origin/feat/support_kimi_linear_model:
- KDA dummy-slot pollution guard (set_ssm_state, set_conv_state) — without
  this, DP runs (tp4dp4) collapse from ~0.66 to ~0.27 OVERALL on mmlu_pro.
- HybridLinearAttnBackend.attn_backend_wrapper builds real KDA sub-backend
  (upstream stub returned full_attn_backend unchanged → server crash).
- ModelRunner.linear_recurrent_config detects KimiLinearConfig by hf_config's
  linear_attn_config attribute (upstream property was a stub returning None).
- compilation_manager dummy batch fills recurrent_indices/has_initial_state
  only when has_recurrent_state is set, so non-recurrent backends are
  unaffected (CompilationManager grows a has_recurrent_state flag, plumbed
  from tp_worker via model_runner.linear_recurrent_config).
- gated_rmsnorm helper (used by KimiLinear).

HybridLinearAttnBackend.__call__ kept upstream-clean (no pool kwarg aliasing).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@MokusMokun MokusMokun force-pushed the feat/kda-e2e-kimi-linear-v2 branch from 95d48ea to 0c12f5a on May 13, 2026 at 11:28
@MokusMokun
Contributor Author

MokusMokun commented May 13, 2026

E2E test results

UTC 2026-05-13 12:04:28 → 2026-05-13 20:10:03, wall 8h02m45s (eval first prefill at 12:07:18; pod TZ = UTC).

+---------+-----------+-----------------+------------------+-------+---------+---------+
| Model   | Dataset   | Metric          | Subset           |   Num |   Score | Cat.0   |
+=========+===========+=================+==================+=======+=========+=========+
|         | mmlu_pro  | AverageAccuracy | computer science |   410 |  0.6854 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | math             |  1351 |  0.7809 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | chemistry        |  1132 |  0.7323 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | engineering      |   969 |  0.5387 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | law              |  1101 |  0.4124 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | biology          |   717 |  0.795  | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | health           |   818 |  0.6064 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | physics          |  1299 |  0.7236 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | business         |   789 |  0.7148 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | philosophy       |   499 |  0.5832 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | economics        |   844 |  0.718  | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | other            |   924 |  0.6039 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | psychology       |   798 |  0.7043 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | history          |   381 |  0.5669 | default |
+---------+-----------+-----------------+------------------+-------+---------+---------+
|         | mmlu_pro  | AverageAccuracy | OVERALL          | 12032 |  0.6602 | -       |
+---------+-----------+-----------------+------------------+-------+---------+---------+

Reproducibility

Environment

  • TPU v6e 4x4 slice (4 hosts × 4 chips, tp=16 dp=4)
  • Base image: python:3.12
  • Model checkpoint: /models/Kimi-Linear-48B-A3B-Instruct/ (mounted from object storage)

Per-host bootstrap

pip install -U --pre jax jaxlib libtpu \
  -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install uv
uv venv /tmp/venv && source /tmp/venv/bin/activate
uv pip install pip
git clone https://github.com/sgl-project/sglang-jax.git
cd sglang-jax && git checkout <this PR's HEAD>
uv pip install -e python[all]
uv pip install evalscope==0.17.1

Server launch (run on each of the 4 hosts; $i = node rank 0..3, $POD0_IP = rank-0 host IP)

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python -u -m sgl_jax.launch_server \
    --model-path /models/Kimi-Linear-48B-A3B-Instruct/ --trust-remote-code \
    --tp-size 16 --dp-size 4 \
    --device tpu --dtype bfloat16 --mem-fraction-static 0.95 \
    --chunked-prefill-size 512 --page-size 256 --max-running-requests 64 \
    --skip-server-warmup --disable-radix-cache --disable-overlap-schedule \
    --precompile-bs-paddings 1 16 64 --precompile-token-paddings 16 128 512 \
    --nnodes 4 --node-rank $i --dist-init-addr $POD0_IP:29500 \
    --host 0.0.0.0 --port 30000

Wait for "The server is fired up and ready to roll!" on rank-0 before launching the eval.

Full mmlu_pro eval (rank-0 host)

evalscope eval \
  --model /models/Kimi-Linear-48B-A3B-Instruct/ \
  --api-url http://127.0.0.1:30000/v1 \
  --api-key EMPTY \
  --eval-type service \
  --datasets mmlu_pro \
  --eval-batch-size 32 \
  --generation-config '{"temperature": 1.0}' \
  --timeout 10000000

@aolemila aolemila merged commit 13e6792 into sgl-project:feat/support_kimi_linear_model May 14, 2026
1 check passed
aolemila pushed a commit that referenced this pull request May 14, 2026
@aolemila aolemila mentioned this pull request May 14, 2026
aolemila added a commit that referenced this pull request May 14, 2026
* add KimiLinearForCausalLM. Co-authored-by: zhengke.zhou.dev@gmail.com

* feat(kimi-linear): port e2e wire-up onto upstream kimi_linear branch (#1072)


* test(kda): wrap raw KDA backend to accept `pool=` kwarg

KDAAttnBackend.__call__ takes `recurrent_state_pool=`, while
RadixLinearAttention.__call__ passes `pool=` (HybridLinearAttnBackend's
calling convention). Production routes through that wrapper which
translates pool→recurrent_state_pool; the unit tests bypass it by
assigning a raw KDAAttnBackend as `forward_batch.attn_backend`, so the
kwarg falls into **kwargs and `recurrent_state_pool` is unbound →
TypeError. Replicate the translation in a test-only shim.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(kda): hoist KDAAttnBackendForTest shim to test_utils.py

Both KDA test files were carrying an identical copy of the `pool=` →
`recurrent_state_pool=` translation shim added in 466afff. Move it to
test_utils.py and import from both, dropping the local underscore prefix
since it's now a shared helper.

Co-Authored-By: zhengke.zhou.dev@gmail.com
Co-authored-by: Mirope Yuhao Hu <miropehu@gmail.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
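
A hedged sketch of the pool → recurrent_state_pool translation described in the test commits above. The shared helper is KDAAttnBackendForTest in test_utils.py; the generic wrapper below is illustrative only and does not mirror its exact signature.

def wrap_pool_kwarg(kda_backend):
    # Illustrative test-only shim: forward calls to a raw KDA backend, renaming the
    # hybrid wrapper's `pool=` kwarg to the backend's `recurrent_state_pool=`.
    def call(*args, pool=None, **kwargs):
        if pool is not None:
            kwargs["recurrent_state_pool"] = pool
        return kda_backend(*args, **kwargs)
    return call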