feat: support ling-2.6-flash #1043
Merged
Rodrian7 reviewed May 9, 2026
cjx0709 commented May 13, 2026
Extract LightningAttnBackend from BailingMoeV2_5LinearAttention, following the KDA pattern where the model layer is pure projection/norm logic and attention dispatch goes through a backend.

- Add LightningAttnBackend with tp_slope pre-computation in __init__ (sketched below)
- Add production-scale shard tests (H=16, K=128, TP=4)
- Consolidate 6 test files into 4 (test_gla.py, test_gla_backend.py, test_gla_shard.py, plus cross-framework removed)
- Move scatter/gather metadata from fla/ to linear/
- Trim redundant test cases (40→27 in test_gla_shard.py)

Co-authored-by: Rodrian7 <weiyuxin@zju.edu.cn>
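A minimal sketch of what the tp_slope pre-computation could look like. The ALiBi-style slope formula, the constructor signature, and the head-slicing scheme are all assumptions for illustration, not the actual implementation:

```python
import numpy as np

def compute_slopes(num_heads: int) -> np.ndarray:
    # Geometric per-head decay rates (assumed ALiBi-style schedule).
    start = 2.0 ** (-8.0 / num_heads)
    return start ** np.arange(1, num_heads + 1)

class LightningAttnBackend:  # hypothetical skeleton
    def __init__(self, num_heads, tp_size, tp_rank, linear_recurrent_layer_ids):
        slopes = compute_slopes(num_heads)
        heads_per_rank = num_heads // tp_size
        # Slice out this TP rank's heads once at construction time, so the
        # forward path just indexes tp_slope[layer_id].
        local = slopes[tp_rank * heads_per_rank : (tp_rank + 1) * heads_per_rank]
        self.tp_slope = {lid: local for lid in linear_recurrent_layer_ids}
```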
…ix wrapper

Vendor the pallas-kernel #313 implementation that handles varlen sequence padding internally via the seq_real_lens parameter, eliminating the need for scatter/gather metadata plumbing in the backend.

- Kernel pads sequences to chunk_size multiples internally (illustrated below)
- Remove GLAMetadata and scatter/gather infrastructure
- Add RadixLightningAttention wrapper (mirrors KDA's RadixLinearAttention)
- Update BailingMoeV2_5LinearAttention to use Radix dispatcher

Co-authored-by: Rodrian7 <weiyuxin@zju.edu.cn>
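A toy illustration of the padding contract described above; the chunk size is an assumed value. The vendored kernel rounds each sequence up to a chunk_size multiple and uses seq_real_lens to mask the padded tail internally, which is what lets the backend drop its scatter/gather metadata:

```python
import numpy as np

CHUNK_SIZE = 64  # assumed value, for illustration only

def padded_len(real_len: int, chunk_size: int = CHUNK_SIZE) -> int:
    # Round a sequence length up to the next chunk_size multiple.
    return ((real_len + chunk_size - 1) // chunk_size) * chunk_size

seq_real_lens = np.array([100, 37, 64])
print([padded_len(int(l)) for l in seq_real_lens])  # [128, 64, 64]
```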
- Remove backend parameter from BailingMoeV2_5LinearAttention constructor
- Wire backend via _make_forward_batch helper
- Use backend.tp_slope[layer_id] instead of module.slope
- Simplify _copy_weights_across_meshes (no backend detach/reattach)
- Add linear_recurrent_layer_ids to LightningAttnBackend construction
- Remove chunk-aligned padding from numpy reference (kernel handles via seq_real_lens)
- Add FlashAttention import guard with pytest.skip
- LightningAttnBackend: Add DP-aware state management with shard_map
- Support DP>1 with P("data", "tensor", None, None) sharding for recurrent state
- DP-local indexing for state gather/scatter operations
- Varlen format with DP-sharded cu_seqlens in extend mode
- Per-layer slope storage in tp_slope dict (indexed by layer_id)
- LinearRecurrentAttnBackend: Add DP metadata computation
- cu_q_lens: 1D array [dp_size * (per_dp_bs_size+1)] for DP>1 (layout sketched after this commit message)
- recurrent_indices: Sharded along P("data") axis
- Support both DP=1 (backward compat) and DP>1 paths
- Test suite: 32 tests total across 4 files
- test_lightning_backend_dp.py: 7 DP correctness tests
* Single/multi-request decode and extend with DP=2
* DP=2+TP=2 hybrid parallelism validation
* State isolation and varlen padding verification
- test_gla.py: 2 model-level tests (slopes, placeholder)
- test_lightning_backend.py: 18 end-to-end integration tests
- test_gla_backend.py: 5 backend unit tests
- Remove torch dependency tests
- Delete TestModuleLevelMockKernel (HF parity with torch)
- Delete TestTPConsistency (2 TP tests with loose tolerances)
- JAX reference implementation provides sufficient validation
- RFC update: Document DP implementation and test coverage
- Add Data Parallelism section with metadata layout and state management
- Update test coverage summary with all 32 tests
- Mark core dependencies as Ready
- Update work breakdown with completed tasks
All tests passing on TPU v6e-4.
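The cu_q_lens layout above, as a runnable sketch. The helper name and the even split of requests across DP shards are assumptions; only the [dp_size * (per_dp_bs_size + 1)] shape and the DP-local indexing come from the commit message:

```python
import numpy as np

def build_cu_q_lens(seq_lens: np.ndarray, dp_size: int) -> np.ndarray:
    # One cumulative-length prefix per DP shard, concatenated into a single
    # 1D array of shape [dp_size * (per_dp_bs_size + 1)].
    per_dp_bs = len(seq_lens) // dp_size
    chunks = []
    for d in range(dp_size):
        local = seq_lens[d * per_dp_bs : (d + 1) * per_dp_bs]
        # Each shard's prefix restarts at 0, keeping gather/scatter DP-local.
        chunks.append(np.concatenate([[0], np.cumsum(local)]))
    return np.concatenate(chunks)

# Example: 4 requests split across DP=2 give two independent prefixes.
print(build_cu_q_lens(np.array([5, 3, 7, 2]), dp_size=2))  # [0 5 8 0 7 9]
```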
* fix(gla): keep recurrent metadata data-sharded
* little fix
* feat(gla): wire Bailing hybrid recurrent config
  Enable Ling/Bailing hybrid configs to use the existing recurrent state and hybrid KV pool path, and remove obsolete GLA backend unit coverage.
* fix(gla): wire recurrent state dtype through config
  Use a shared recurrent state dtype object for Bailing hybrid pool sizing and allocation so SSM state remains fp32 by default.
* refactor(gla): use recurrent state params dataclass
  Route Bailing hybrid recurrent pool sizing and allocation through a typed state params object instead of storing dtype in the legacy linear attention config dict (sketched below).
* fix(gla): align Bailing linear state heads

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
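A sketch of what the typed state params object could look like; the class name, fields, and method are assumptions inferred from the commit message, not the actual dataclass:

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp

@dataclass(frozen=True)
class RecurrentStateParams:  # hypothetical name
    """One typed object drives both hybrid pool sizing and allocation,
    replacing the dtype entry in the legacy linear attention config dict."""
    num_heads: int
    head_k_dim: int
    head_v_dim: int
    dtype: Any = jnp.float32  # SSM state stays fp32 by default

    def state_shape(self, num_slots: int):
        return (num_slots, self.num_heads, self.head_k_dim, self.head_v_dim)
```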
* test(gla): move native reference and drop legacy gla test
* test(gla): remove stale backend suite entry
* test(gla): fold shard coverage into lightning backend
* test(gla): slice padded outputs on host
* test(gla): avoid asserting padded slot state value
SGLANG_JAX_GLA_BACKEND=native \
/opt/venv/bin/python -m sgl_jax.launch_server \
--model-path /models/Ling-2.6-flash \
--host 0.0.0.0 --port 30271 \
--tp-size 16 \
--dp-size 4 \
--ep-size 4 --moe-backend epmoe \
--attention-backend fa \
--disable-radix-cache --disable-overlap-schedule \
--disable-precompile --skip-server-warmup \
--max-running-requests 16
HF config has norm_topk_prob=False but HF reference and sglang (via fused_topk_deepseek) always renormalize. Hardcode to True instead of reading from config to match actual upstream behavior.
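What "always renormalize" means concretely, as a minimal sketch; the real path goes through fused_topk_deepseek, and this standalone version is only for illustration:

```python
import jax
import jax.numpy as jnp

def topk_renormalized(router_logits: jnp.ndarray, k: int):
    # Select top-k expert probabilities, then renormalize so each token's
    # selected weights sum to 1 (unconditionally, matching norm_topk_prob=True).
    probs = jax.nn.softmax(router_logits, axis=-1)
    weights, indices = jax.lax.top_k(probs, k)
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, indices
```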
Non-hybrid models (e.g. Qwen3-8B) have recurrent_indices=None at runtime. The dummy batch unconditionally set recurrent_indices and has_initial_state to np.zeros, causing pytree structure mismatch between precompile and runtime → JIT cache miss. Add is_hybrid_recurrent flag to CompilationManager so recurrent fields are only populated when the model uses HybridReqToTokenPool.
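A sketch of the pytree-stability fix described above; the helper and any field beyond recurrent_indices / has_initial_state / is_hybrid_recurrent are assumptions:

```python
import numpy as np

def make_dummy_batch(bs: int, is_hybrid_recurrent: bool) -> dict:
    # Precompile dummy batch: populate recurrent fields only for models
    # backed by HybridReqToTokenPool, so the pytree structure matches what
    # non-hybrid models (recurrent_indices=None) produce at runtime.
    batch = {"input_ids": np.zeros((bs,), dtype=np.int32)}
    if is_hybrid_recurrent:
        batch["recurrent_indices"] = np.zeros((bs,), dtype=np.int32)
        batch["has_initial_state"] = np.zeros((bs,), dtype=np.bool_)
    else:
        # None leaves, not zero arrays: a different leaf set is a different
        # pytree structure and therefore a JIT cache miss.
        batch["recurrent_indices"] = None
        batch["has_initial_state"] = None
    return batch
```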
JamesBrianD previously approved these changes May 14, 2026
JamesBrianD (Collaborator) reviewed May 14, 2026 and left a comment:

Apologies—let's wait until Kimi linear has been merged in first before merging this.
Resolved 5 conflicts:
- hf_transformers_utils.py: register both BailingHybridConfig and KimiLinearConfig
- tp_worker.py: unify on has_recurrent_state (using linear_recurrent_config is not None)
- compilation_manager.py: drop is_hybrid_recurrent alias, keep has_recurrent_state
- model_runner_kv_cache_mixin.py: keep cjx kimi_linear_config + lightning_config + aggregated linear_recurrent_config; upgrade kimi_linear_config to detect hf_config.linear_attn_config so KDA dispatch from main works
- hybrid_linear_attn_backend.py: keep two-way dispatch (KDA + Lightning) with main's simpler KDA import (sketched below)
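The two-way dispatch, as a hedged sketch; the factory function and backend class names are assumptions, and only the kimi_linear_config / lightning_config attributes and the KDA-vs-Lightning split come from the notes above:

```python
class KDAAttnBackend: ...        # stand-ins; the real classes live in sgl_jax
class LightningAttnBackend: ...

def select_linear_attn_backend(model_runner):
    # Two-way dispatch: Kimi linear models take the KDA path from main,
    # Bailing/Ling hybrid models take this PR's Lightning path.
    if model_runner.kimi_linear_config is not None:
        return KDAAttnBackend()
    if model_runner.lightning_config is not None:
        return LightningAttnBackend()
    raise ValueError("model has no linear-recurrent layers")
```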
lightning_config is a @property on ModelRunnerKVCacheMixin and is always available on ModelRunner instances; the getattr default is dead code.
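Why the default is dead, as a minimal sketch assuming the mixin shape:

```python
class ModelRunnerKVCacheMixin:
    @property
    def lightning_config(self):
        return self._lightning_config  # set during KV-cache init (assumed)

# Because the property always exists on the class, the lookup
# getattr(runner, "lightning_config", <default>) always succeeds via the
# property and can never fall through to <default>.
```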
Decorate the 6 hot paths in the linear-attention stack so xprof traces show readable spans instead of anonymous shard_map closures:
- lightning_decode / lightning_extend
- kda_decode / kda_extend
- short_conv_decode / short_conv_extend
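How such a span annotation looks in JAX; the function body here is a toy decode step, not the real kernel:

```python
import jax
import jax.numpy as jnp

def lightning_decode(state, q, k, v, slope):
    # The named scope makes this show up as "lightning_decode" in xprof
    # rather than as an anonymous shard_map closure.
    with jax.named_scope("lightning_decode"):
        state = state * slope[:, None, None] + jnp.einsum("hk,hv->hkv", k, v)
        out = jnp.einsum("hk,hkv->hv", q, state)
        return out, state
```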
aolemila approved these changes May 14, 2026
Summary
This PR adds Data Parallelism (DP) support to LightningAttnBackend for GLA (Gated Linear Attention) and includes a comprehensive test suite, plus end-to-end validation of Ling-2.6-Flash (BailingMoeHybrid) on TPU v6e-16.
Key Changes
- LightningAttnBackend: DP-aware state management with shard_map; P("data", "tensor", None, None) sharding for recurrent state
- LinearRecurrentAttnBackend: DP metadata computation; cu_q_lens as a 1D [dp_size * (per_dp_bs_size+1)] array for DP>1; recurrent_indices sharded along the P("data") axis
- Test Coverage: 32 tests total across 4 files
- Cleanup: removed torch dependency tests (HF parity tests)
- Eval coverage: added test/srt/eval/simple_eval_{aime26,csimpleqa}.py (driven by the existing test/srt/run_eval.py runner) for the validation below.

Testing
All 32 tests passing on TPU v6e-4.
Documentation
See updated RFC in
docs/design/bailing_moe_linear_attention.md for detailed DP implementation and test coverage.

Validation on Ling-2.6-Flash (BailingMoeHybrid, TPU v6e-16)
End-to-end accuracy on the inclusionAI/Ling-2.6-flash model, served from this branch. Numbers are compared against the official table at https://www.modelscope.cn/models/inclusionAI/Ling-2.6-flash.
Server environment
- TPU v6e-16 (nnodes=4, one rank per node).
- Model: /models/Ling-2.6-flash (BailingMoeV2.5 hybrid attention).
- Repo: /sglang-jax, branch cjx/gla-kernel.
- Python: /opt/venv/bin/python (sgl_jax installed).

Launch command
Run the following on each rank with NODE_RANK ∈ {0,1,2,3} and MASTER_ADDR set to rank-0's reachable address. All four ranks should be started concurrently (do not wait for any rank to become ready before starting the next).

Wait for 'The server is fired up and ready to roll!' in the log on rank 0 (typical: ~5–6 min).

Eval results (vs official)
Notes:
- Chinese-SimpleQA dataset: LivingFutureLab/ChineseSimpleQA.
- The judge uses GRADER_TEMPLATE and runs against the same served endpoint (biased vs GPT-4o, but lands within 1pp here).
Run from any node that can reach the served port (here we hit the local rank-0 server at 127.0.0.1:30271). Both evals download their dataset from HuggingFace on first run.

AIME26 (full 30 problems, parallel=30, max_tokens=32768; requires --context-length 65536)

Chinese-SimpleQA (full 3000 questions, parallel=256, with same-endpoint judge)
Closes #1032