
feat: support ling-2.6-flash #1043

Merged

cjx0709 merged 39 commits into main from cjx/gla-kernel on May 14, 2026

Conversation


@cjx0709 cjx0709 commented May 8, 2026

Summary

This PR adds Data Parallelism (DP) support to LightningAttnBackend for GLA (Gated Linear Attention), includes a comprehensive test suite, and provides end-to-end validation of Ling-2.6-Flash (BailingMoeHybrid) on TPU v6e-16.

Key Changes

  • LightningAttnBackend: DP-aware state management with shard_map

    • Support DP>1 with P("data", "tensor", None, None) sharding for recurrent state
    • DP-local indexing for state gather/scatter operations
    • Varlen format with DP-sharded cu_seqlens in extend mode
    • Per-layer slope storage in tp_slope dict
  • LinearRecurrentAttnBackend: DP metadata computation

    • cu_q_lens: 1D array [dp_size * (per_dp_bs_size+1)] for DP>1
    • recurrent_indices: Sharded along P("data") axis
  • Test Coverage: 32 tests total across 4 files

    • 7 DP correctness tests (new file: test_lightning_backend_dp.py)
    • 18 end-to-end integration tests
    • 5 backend unit tests
    • 2 model-level tests
  • Cleanup: Removed torch dependency tests (HF parity tests)

    • JAX reference implementation provides sufficient validation
  • Eval coverage: added test/srt/eval/simple_eval_{aime26,csimpleqa}.py (driven by the existing test/srt/run_eval.py runner) for the validation below.
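To make the `cu_q_lens` layout above concrete, here is a small illustrative sketch (plain Python, not the repo's actual JAX code) of how a `[dp_size * (per_dp_bs_size + 1)]` array can be built: each DP shard carries its own local cumulative-length prefix starting at 0, and the shards are concatenated into one 1D array.

```python
# Illustrative sketch (not the repo's implementation) of the DP>1 cu_q_lens
# layout: one local cumulative prefix per DP shard, concatenated into a
# single 1D array of length dp_size * (per_dp_bs + 1).

def build_cu_q_lens(seq_lens, dp_size):
    """seq_lens: per-request query lengths, split evenly across DP shards."""
    per_dp_bs = len(seq_lens) // dp_size
    out = []
    for d in range(dp_size):
        shard = seq_lens[d * per_dp_bs:(d + 1) * per_dp_bs]
        cu = [0]
        for n in shard:  # local cumulative sum, restarting at 0 per shard
            cu.append(cu[-1] + n)
        out.extend(cu)
    return out  # length dp_size * (per_dp_bs + 1)

# Example: 4 requests split over dp_size=2 -> two local prefixes
print(build_cu_q_lens([3, 5, 2, 7], dp_size=2))  # [0, 3, 8, 0, 2, 9]
```

Because each prefix restarts at 0, every DP shard can index its own slice without knowing the other shards' totals, which is what allows the metadata to stay sharded along the `P("data")` axis.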

Testing

All 32 tests passing on TPU v6e-4.

Documentation

See updated RFC in docs/design/bailing_moe_linear_attention.md for detailed DP implementation and test coverage.

Validation on Ling-2.6-Flash (BailingMoeHybrid, TPU v6e-16)

End-to-end accuracy on the inclusionAI/Ling-2.6-flash model, served by this branch. Numbers compared to the official table on https://www.modelscope.cn/models/inclusionAI/Ling-2.6-flash.

Server environment

  • Hardware: TPU v6e-16, 4 nodes × 4 chips (nnodes=4, one rank per node).
  • Model: /models/Ling-2.6-flash (BailingMoeV2.5 hybrid attention).
  • Repo path on each node: /sglang-jax, branch cjx/gla-kernel.
  • Python: /opt/venv/bin/python (sgl_jax installed).

Launch command

Run the following on each rank with NODE_RANK ∈ {0,1,2,3} and MASTER_ADDR set to rank-0's reachable address. All four ranks should be started concurrently (do not wait for any rank to become ready before starting the next).

PYTHONPATH=/sglang-jax/python \
/opt/venv/bin/python -u -m sgl_jax.launch_server \
  --model-path /models/Ling-2.6-flash \
  --trust-remote-code \
  --tp-size 16 --dp-size 4 --ep-size 16 \
  --moe-backend epmoe \
  --nnodes 4 --node-rank "${NODE_RANK}" \
  --dist-init-addr "${MASTER_ADDR}:10011" \
  --host 0.0.0.0 --port 30271 \
  --page-size 128 \
  --context-length 65536 \
  --chunked-prefill-size 2048 \
  --dtype bfloat16 \
  --mem-fraction-static 0.90 \
  --max-running-requests 256 \
  --attention-backend fa \
  --disable-radix-cache \
  --skip-server-warmup \
  --log-level info

Wait for 'The server is fired up and ready to roll!' in the log on rank 0 (typical: ~5–6 min).
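The per-rank commands above differ only in `--node-rank`; a hypothetical helper (not part of the repo; the `MASTER_ADDR` value is a placeholder) can generate all four so they can be dispatched concurrently to the nodes, e.g. via ssh in the background:

```python
# Hypothetical helper that builds the four per-rank launch commands shown
# above. MASTER_ADDR is an assumed placeholder; replace with rank-0's
# reachable address. In practice, run each command concurrently on its node.

MASTER_ADDR = "10.0.0.1"  # placeholder, not a real address from this PR

def launch_cmd(node_rank):
    return (
        "PYTHONPATH=/sglang-jax/python /opt/venv/bin/python -u -m sgl_jax.launch_server "
        "--model-path /models/Ling-2.6-flash --trust-remote-code "
        "--tp-size 16 --dp-size 4 --ep-size 16 --moe-backend epmoe "
        f"--nnodes 4 --node-rank {node_rank} "
        f"--dist-init-addr {MASTER_ADDR}:10011 "
        "--host 0.0.0.0 --port 30271"
        # ...remaining flags identical to the launch command above
    )

cmds = [launch_cmd(r) for r in range(4)]
for c in cmds:
    print(c)
```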

Eval results (vs official)

| Category  | Benchmark                                        | Ling-2.6-flash (official) | This branch | Δ     |
|-----------|--------------------------------------------------|---------------------------|-------------|-------|
| Knowledge | C-SimpleQA (judge_accuracy, same-endpoint judge) | 60.23                     | 60.87       | +0.64 |
| Math      | AIME26 (30 problems)                             | 73.85                     | 76.67       | +2.82 |

Notes:

  • C-SimpleQA uses the official LivingFutureLab/ChineseSimpleQA GRADER_TEMPLATE; the judge runs against the same served endpoint (a potential bias relative to the official GPT-4o judge, though the result lands within 1pp of it here).
  • AIME26 has only 30 problems, so per-question variance is high: one problem is worth ~3.33pp, and the +2.82pp delta is less than one problem's weight and within sampling noise.
  • Both runs: 0 request errors / 0 truncations.
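The variance note above can be sanity-checked with quick arithmetic: with 30 problems, each one is worth 100/30 percentage points, which exceeds the observed delta.

```python
# Sanity check for the AIME26 variance note: each of the 30 problems is
# worth 100/30 ≈ 3.33 percentage points, so the observed +2.82pp gap is
# smaller than a single problem's weight.

n_problems = 30
pp_per_problem = 100 / n_problems
delta_pp = 76.67 - 73.85

print(round(pp_per_problem, 2))   # 3.33
print(delta_pp < pp_per_problem)  # True
```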

Eval commands

Run from any node that can reach the served port (here we hit the local rank-0 server at 127.0.0.1:30271). Both evals download their dataset from HuggingFace on first run.

AIME26 (full 30 problems, parallel=30, max_tokens=32768 — requires --context-length 65536)

cd /sglang-jax && \
/opt/venv/bin/python test/srt/run_eval.py \
  --base-url http://127.0.0.1:30271 \
  --eval-name aime26 \
  --num-threads 30 \
  --temperature 0.6 \
  --max-tokens 32768

Chinese-SimpleQA (full 3000 questions, parallel=256, with same-endpoint judge)

cd /sglang-jax && \
/opt/venv/bin/python test/srt/run_eval.py \
  --base-url http://127.0.0.1:30271 \
  --eval-name csimpleqa \
  --num-threads 256 \
  --temperature 0.0 \
  --max-tokens 1024

Closes #1032


@cjx0709 cjx0709 force-pushed the cjx/gla-kernel branch 6 times, most recently from fd7f3d4 to ac4bfea, on May 9, 2026 01:35
@cjx0709 cjx0709 changed the title from "feat(gla): Add Data Parallelism support for LightningAttnBackend" to "feat(gla): support LightningAttnBackend with dp" on May 9, 2026
@cjx0709 cjx0709 force-pushed the cjx/gla-kernel branch 2 times, most recently from 9b208c6 to a3b7cec, on May 9, 2026 07:00
Comment thread on python/sgl_jax/test/test_lightning_backend.py (outdated)
@JamesBrianD JamesBrianD changed the title from "feat(gla): support LightningAttnBackend with dp" to "feat: support BailingMoeHybrid" on May 13, 2026
Comment thread on python/sgl_jax/srt/layers/attention/linear/lightning_backend.py (outdated)
@cjx0709 cjx0709 force-pushed the cjx/gla-kernel branch 2 times, most recently from 2aa8b8a to 7806f2d, on May 13, 2026 08:49
cjx0709 and others added 15 commits May 14, 2026 13:31
Extract LightningAttnBackend from BailingMoeV2_5LinearAttention, following
the KDA pattern where the model layer is pure projection/norm logic and
attention dispatch goes through a backend.

- Add LightningAttnBackend with tp_slope pre-computation in __init__
- Add production-scale shard tests (H=16, K=128, TP=4)
- Consolidate 6 test files into 4 (test_gla.py, test_gla_backend.py,
  test_gla_shard.py, plus cross-framework removed)
- Move scatter/gather metadata from fla/ to linear/
- Trim redundant test cases (40→27 in test_gla_shard.py)

Co-authored-by: Rodrian7 <weiyuxin@zju.edu.cn>
…ix wrapper

Vendor the pallas-kernel #313 implementation that handles varlen sequence
padding internally via seq_real_lens parameter, eliminating the need for
scatter/gather metadata plumbing in the backend.

- Kernel pads sequences to chunk_size multiples internally
- Remove GLAMetadata and scatter/gather infrastructure
- Add RadixLightningAttention wrapper (mirrors KDA's RadixLinearAttention)
- Update BailingMoeV2_5LinearAttention to use Radix dispatcher

Co-authored-by: Rodrian7 <weiyuxin@zju.edu.cn>
- Remove backend parameter from BailingMoeV2_5LinearAttention constructor
- Wire backend via _make_forward_batch helper
- Use backend.tp_slope[layer_id] instead of module.slope
- Simplify _copy_weights_across_meshes (no backend detach/reattach)
- Add linear_recurrent_layer_ids to LightningAttnBackend construction
- Remove chunk-aligned padding from numpy reference (kernel handles via seq_real_lens)
- Add FlashAttention import guard with pytest.skip
- LightningAttnBackend: Add DP-aware state management with shard_map
  - Support DP>1 with P("data", "tensor", None, None) sharding for recurrent state
  - DP-local indexing for state gather/scatter operations
  - Varlen format with DP-sharded cu_seqlens in extend mode
  - Per-layer slope storage in tp_slope dict (indexed by layer_id)

- LinearRecurrentAttnBackend: Add DP metadata computation
  - cu_q_lens: 1D array [dp_size * (per_dp_bs_size+1)] for DP>1
  - recurrent_indices: Sharded along P("data") axis
  - Support both DP=1 (backward compat) and DP>1 paths

- Test suite: 32 tests total across 4 files
  - test_lightning_backend_dp.py: 7 DP correctness tests
    * Single/multi-request decode and extend with DP=2
    * DP=2+TP=2 hybrid parallelism validation
    * State isolation and varlen padding verification
  - test_gla.py: 2 model-level tests (slopes, placeholder)
  - test_lightning_backend.py: 18 end-to-end integration tests
  - test_gla_backend.py: 5 backend unit tests

- Remove torch dependency tests
  - Delete TestModuleLevelMockKernel (HF parity with torch)
  - Delete TestTPConsistency (2 TP tests with loose tolerances)
  - JAX reference implementation provides sufficient validation

- RFC update: Document DP implementation and test coverage
  - Add Data Parallelism section with metadata layout and state management
  - Update test coverage summary with all 32 tests
  - Mark core dependencies as Ready
  - Update work breakdown with completed tasks

All tests passing on TPU v6e-4.
* fix(gla): keep recurrent metadata data-sharded

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* little fix

* feat(gla): wire Bailing hybrid recurrent config

Enable Ling/Bailing hybrid configs to use the existing recurrent state and hybrid KV pool path, and remove obsolete GLA backend unit coverage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(gla): wire recurrent state dtype through config

Use a shared recurrent state dtype object for Bailing hybrid pool sizing and allocation so SSM state remains fp32 by default.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* refactor(gla): use recurrent state params dataclass

Route Bailing hybrid recurrent pool sizing and allocation through a typed state params object instead of storing dtype in the legacy linear attention config dict.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(gla): align Bailing linear state heads

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
* test(gla): move native reference and drop legacy gla test

* test(gla): remove stale backend suite entry

* test(gla): fold shard coverage into lightning backend

* test(gla): slice padded outputs on host

* test(gla): avoid asserting padded slot state value
JamesBrianD and others added 19 commits May 14, 2026 13:31
  SGLANG_JAX_GLA_BACKEND=native \
  /opt/venv/bin/python -m sgl_jax.launch_server \
    --model-path /models/Ling-2.6-flash \
    --host 0.0.0.0 --port 30271 \
    --tp-size 16 \
    --dp-size 4 \
    --ep-size 4 --moe-backend epmoe \
    --attention-backend fa \
    --disable-radix-cache --disable-overlap-schedule \
    --disable-precompile --skip-server-warmup \
    --max-running-requests 16
The HF config has norm_topk_prob=False, but both the HF reference and sglang
(via fused_topk_deepseek) always renormalize. Hardcode it to True
instead of reading it from the config, to match actual upstream behavior.
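What renormalization means here can be shown with a minimal sketch (illustrative only; this is not the fused_topk_deepseek kernel): the top-k router probabilities are rescaled so that the selected experts' weights sum to 1.

```python
# Illustrative sketch of top-k router-weight renormalization (what
# norm_topk_prob=True does conceptually); not the actual
# fused_topk_deepseek implementation.

def topk_renormalize(probs, k):
    """Pick the top-k expert probabilities and rescale them to sum to 1."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in top)
    return {i: probs[i] / total for i in top}

weights = topk_renormalize([0.5, 0.2, 0.2, 0.1], k=2)
print(weights)  # top-2 experts 0 and 1, rescaled to sum to 1 (~0.714, ~0.286)
```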
Non-hybrid models (e.g. Qwen3-8B) have recurrent_indices=None at
runtime. The dummy batch unconditionally set recurrent_indices and
has_initial_state to np.zeros, causing pytree structure mismatch
between precompile and runtime → JIT cache miss.

Add is_hybrid_recurrent flag to CompilationManager so recurrent
fields are only populated when the model uses HybridReqToTokenPool.
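A pure-Python analogue (not JAX itself) of the cache-miss mechanism described above: JIT compilation caches are keyed on the pytree *structure* of the arguments, so a field that is an array in the precompiled dummy batch but None at runtime yields two different structures and the precompiled entry is never reused.

```python
# Pure-Python analogue of the pytree-structure mismatch: a field that is
# populated in the dummy batch but None at runtime produces two different
# tree structures, so the precompiled cache entry can never be hit.

def tree_structure(batch):
    """Toy stand-in for jax.tree_util.tree_structure: which fields hold leaves."""
    return tuple(sorted(k for k, v in batch.items() if v is not None))

dummy_batch   = {"q": [0.0], "recurrent_indices": [0], "has_initial_state": [0]}
runtime_batch = {"q": [0.0], "recurrent_indices": None, "has_initial_state": None}

print(tree_structure(dummy_batch) == tree_structure(runtime_batch))    # False -> cache miss

# With an is_hybrid_recurrent-style guard, the dummy batch leaves the
# recurrent fields as None for non-hybrid models, and the structures match:
guarded_dummy = {"q": [0.0], "recurrent_indices": None, "has_initial_state": None}
print(tree_structure(guarded_dummy) == tree_structure(runtime_batch))  # True -> cache hit
```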
@cjx0709 cjx0709 changed the title from "feat: support BailingMoeHybrid" to "feat: support ling-2.6-flash" on May 14, 2026
JamesBrianD
JamesBrianD previously approved these changes May 14, 2026
@JamesBrianD JamesBrianD left a comment

Apologies—let's wait until Kimi linear has been merged in first before merging this.

@JamesBrianD JamesBrianD dismissed their stale review May 14, 2026 08:36

Resolved 5 conflicts:
- hf_transformers_utils.py: register both BailingHybridConfig and KimiLinearConfig
- tp_worker.py: unify on has_recurrent_state (using linear_recurrent_config is not None)
- compilation_manager.py: drop is_hybrid_recurrent alias, keep has_recurrent_state
- model_runner_kv_cache_mixin.py: keep cjx kimi_linear_config + lightning_config + aggregated
  linear_recurrent_config; upgrade kimi_linear_config to detect hf_config.linear_attn_config so
  KDA dispatch from main works
- hybrid_linear_attn_backend.py: keep two-way dispatch (KDA + Lightning) with main's simpler
  KDA import
lightning_config is a @property on ModelRunnerKVCacheMixin and is always
available on ModelRunner instances; the getattr default is dead code.
Decorate the 6 hot paths in the linear-attention stack so xprof traces
show readable spans instead of anonymous shard_map closures:

- lightning_decode / lightning_extend
- kda_decode / kda_extend
- short_conv_decode / short_conv_extend
@cjx0709 cjx0709 merged commit 4a738c9 into main May 14, 2026
21 checks passed
@cjx0709 cjx0709 deleted the cjx/gla-kernel branch May 14, 2026 09:49

Development

Successfully merging this pull request may close these issues.

[RFC] GLA Kernel + LightningAttnBackend Integration
