feat(spec): MultiLayerEAGLEWorker/MultiLayerDraftWorker (#1053 P1-4)#1089

Open
zorrofox wants to merge 9 commits into sgl-project:epic/mtp-refactor-phase1 from zorrofox:feat/p1-4-rebased

Conversation

@zorrofox
Contributor

@zorrofox zorrofox commented May 14, 2026

Part of #1053 Phase 1 (P1-4). Stacked on epic/mtp-refactor-phase1 (#1080 merged).

What

3-layer MTP for MiMo-V2.5-Pro via MultiLayerEAGLEWorker + MultiLayerDraftWorker:

  • N independent draft ModelWorkers (one per mtp_layer_idx), hidden states chained layer→layer in draft_extend_*
  • draft_forward step i uses runner i (per-layer KV pool, per-layer forward_metadata)
  • Verify/orchestration reused from BaseSpecWorker via draft_worker injection
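The per-layer chaining above can be sketched with stand-in runners (everything here is illustrative; the real MultiLayerDraftWorker operates on JAX arrays, per-layer KV pools, and forward_metadata rather than plain lists):

```python
class StubRunner:
    """Hypothetical stand-in for one per-layer draft model runner."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

    def forward(self, hidden):
        # Pretend each MTP layer transforms the hidden state.
        return [h + self.layer_idx + 1 for h in hidden]


def draft_forward(runners, hidden):
    """Step i uses runner i; hidden states chain layer -> layer."""
    trace = []
    for runner in runners:  # one runner per mtp_layer_idx
        hidden = runner.forward(hidden)
        trace.append(list(hidden))
    return hidden, trace


runners = [StubRunner(i) for i in range(3)]  # 3-layer MTP
final, trace = draft_forward(runners, hidden=[0.0])
```

The point of the sketch is only the control flow: each draft step consumes the previous layer's hidden state and dispatches to its own runner.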

Two commits; together they include:

  • SpeculativeAlgorithm.NEXTN + scheduler dispatch
  • model_config: draft arch→MiMoV2MTPForCausalLM, num_hidden_layers=1, quant ignore eh_proj/o_proj (BF16 in checkpoint)
  • mimo_v2_nextn: 4-tuple __call__ + concat reshard; mimo_v2_flash: get_embed_and_head
  • flashattention_backend: TARGET_VERIFY custom_mask page-align pad + swa_page_indices for hybrid
  • rpa_v3: _fetch_mask aligned mask stride + pl.multiple_of hints (kernel side of mask pad)
  • spec prefill via padded get_model_worker_batch (epmoe padding-sensitivity, see below)

E2E (v6e-64, V2.5-Pro, --ep-size 64 --moe-backend epmoe, bs=1 topk=1 greedy)

  • crash-free ✓ prefill / 3×draft_extend / draft_forward / verify / decode loop
  • accept-len: warm 8-round mean 2.53 (5-prompt mix: natural / code / periodic)
  • output vs nospec: first token matches; first ~13–30 tokens match per prompt

Known limitation: spec ≠ nospec on MoE+EP after ~13–30 tokens

Verify-stage divergence is not a verify-logic bug; it's upstream EPMoE bf16 padding-sensitivity. get_spec_model_worker_batch doesn't pad tokens/bs (unlike nospec's get_model_worker_batch), so the same real tokens land at different row indices inside per-expert GMM groups after _permute's argsort(topk_ids); the bf16 accumulation order then differs, and the argmax occasionally flips. Padded prefill (this PR) makes the first token match; verify (4 tokens) and nospec decode (1 token) are inherently different shapes. Dense targets (P1-0 Llama) are unaffected. Tracked in #1090; likely needs a router-side padding mask.
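The accumulation-order sensitivity can be reproduced in miniature. This sketch uses float16 (NumPy has no native bf16), but the mechanism is the same: permuting the rows of a low-precision reduction changes the rounding, which can flip a downstream argmax.

```python
import numpy as np


def seq_sum_f16(values):
    # Left-to-right accumulation at float16 precision, mimicking an
    # order-dependent low-precision reduction inside a GMM group.
    acc = np.float16(0.0)
    for v in values:
        acc = np.float16(acc + np.float16(v))
    return float(acc)


# Same real values, two row orders (as after a different argsort permutation):
a = seq_sum_f16([2048.0, 1.0, 1.0])  # each +1 is lost to rounding
b = seq_sum_f16([1.0, 1.0, 2048.0])  # the ones combine first
```

Here `a` is 2048.0 (2049 rounds back to 2048 in float16) while `b` is 2050.0; when two logits are this close, the reduction order alone decides the argmax.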

TODO (follow-up, not blocking)

  • refactor(spec): split EAGLEWorker into BaseSpecWorker/BaseDraftWorker… #1080 review (3): run_spec_decode_precompile only warms layer 0 via _worker delegation (currently --disable-precompile in E2E)
  • draft_forward per-layer semantics: tested runner(i+1) @ step-0 position — accept-len dropped (3.5→2.0), reverted; current runner(i) @ step-i is empirically best but the layer/token-shift contract vs MiMo training needs confirming
  • EagleDraftWorker.__init__ extract _init_common() (review (4) full version) — MultiLayerDraftWorker.__init__ still copies ~20 lines

Test plan

  • v6e-64 E2E: 3-layer MTP load + serve + 6-prompt generate, no crash
  • accept-len ≥2.5 (warm)
  • spec output = nospec for first N≥13 tokens
  • CI test_speculative_decoding.py (P1-8 will add NEXTN case)

Test User added 2 commits May 14, 2026 13:35
…1053 P1-4)

3-layer MTP for MiMo-V2.5-Pro: N draft model runners (one per mtp_layer_idx),
hidden states chained layer→layer in draft_extend_*; draft_forward step i uses
runner(i). Verify/orchestration reused from EAGLEWorker via draft_worker
injection.

Includes:
- NEXTN SpeculativeAlgorithm + scheduler dispatch
- model_config: draft arch=MiMoV2MTPForCausalLM, num_hidden_layers=1, quant
  ignore eh_proj/o_proj (bf16 in checkpoint)
- mimo_v2_nextn: 4-tuple __call__ + concat reshard
- mimo_v2_flash: get_embed_and_head
- fa_backend: TARGET_VERIFY custom_mask page-align pad + swa_page_indices
- rpa_v3: _fetch_mask aligned mask stride + pl.multiple_of hints
- spec prefill via padded get_model_worker_batch (epmoe padding-sensitivity)
- EAGLEWorker.__init__ accepts draft_worker (fold sgl-project#1080 review (4))

E2E v6e-64: bs=1 topk=1 greedy accept-len ~2.5 (5-prompt mix), output matches
nospec for first N>=13 tokens (later divergence is upstream epmoe bf16
padding-sensitivity, tracked separately).
…ject#1080 review (1))

verify(), forward_target_extend(), forward_batch_speculative_generation()
only touch BaseSpecWorker state (target_worker, draft_worker, mesh,
speculative_num_*); moving them up makes EAGLEWorker and
MultiLayerEAGLEWorker thin draft_worker-injection wrappers.

BaseDraftWorker gains explicit abstract draft_extend_for_{prefill,decode}
+ draft_model_runner so the contract is visible.
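A minimal sketch of what that explicit contract could look like (method names from the text above; the rest is illustrative):

```python
from abc import ABC, abstractmethod


class BaseDraftWorker(ABC):
    """Minimal sketch of the draft-worker contract described above."""

    @abstractmethod
    def draft_extend_for_prefill(self, batch): ...

    @abstractmethod
    def draft_extend_for_decode(self, batch): ...

    @property
    @abstractmethod
    def draft_model_runner(self): ...


# A subclass that forgets part of the contract fails at construction time:
class Incomplete(BaseDraftWorker):
    def draft_extend_for_prefill(self, batch):
        return batch
```

Making the contract abstract turns a missing override into an immediate TypeError instead of a latent AttributeError deep in the spec-decode loop.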

(3) precompile multi-layer warm-up left as TODO (currently
--disable-precompile in E2E; only layer 0 would be warmed via
_worker delegation).

@zorrofox zorrofox marked this pull request as ready for review May 15, 2026 00:10

@zorrofox zorrofox closed this May 15, 2026
@zorrofox zorrofox reopened this May 15, 2026
Test User added 4 commits May 15, 2026 01:18
EAGLE/EAGLE3 dense targets don't hit sgl-project#1090 (no MoE-EP) and
EagleDraftWorker.draft_extend_for_prefill doesn't yet handle padded
mwb (sharded [:real_bs] slice → ShardingTypeError). Restores CI
test_speculative_decoding (Qwen3-32B EAGLE3).
Kernel _fetch_mask now uses page-aligned mask row stride (cu_kv_lens
delta) so DMA offset/size are 8-divisible. Update ref impl + test
helper to construct masks with the same padded layout (host-side
fa_backend already does). Restores test_flashattention custom_mask
tests.
The host-side mask pad + kernel pl.multiple_of hints were added for the
NEXTN hybrid-SWA verify DMA 8-align crash, but they regress EAGLE3
(dense Qwen3-32B): accept-len 1.5→1.06 and per-round JIT cache_miss
(padded mask shape varies with seq_len → recompile). Restoring the
original unpadded mask path so EAGLE3 CI recovers; the NEXTN hybrid
case will be re-fixed via a hybrid-only path that doesn't touch the
dense kernel contract. swa_page_indices for TARGET_VERIFY is kept
(independent hybrid fix).
# launch draft worker
if self.spec_algorithm is not None and self.spec_algorithm.is_eagle():
    from sgl_jax.srt.speculative.eagle_worker import EAGLEWorker
if self.spec_algorithm.is_nextn():
Collaborator

NEXTN's logic is the same as EAGLE's; the only difference is that NEXTN's draft model params come from the target model's layers, so this check is not correct.

We should adjust this dispatch to check whether there are multi-layer draft model params for different steps.

Contributor Author

is_nextn() is the user's explicit CLI choice (--speculative-algorithm NEXTN). Agree the model-side check belongs at construction — added an assert in MultiLayerDraftWorker.__init__ against hf_config.num_nextn_predict_layers (93f881e). Keeping the dispatch on the algorithm enum so a future single-layer NEXTN can route to EagleDraftWorker by checking num_nextn_predict_layers==1 there if needed.
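A construction-time check along those lines could look like this (the attribute name comes from the comment above; the surrounding class skeleton is illustrative):

```python
from types import SimpleNamespace


class MultiLayerDraftWorker:
    # Illustrative skeleton; only the assert mirrors the described change.
    def __init__(self, hf_config, speculative_num_steps):
        n = getattr(hf_config, "num_nextn_predict_layers", 1)
        assert n == speculative_num_steps, (
            f"model has {n} NEXTN predict layers but "
            f"--speculative-num-steps={speculative_num_steps}"
        )
        self.num_mtp_layers = speculative_num_steps


cfg = SimpleNamespace(num_nextn_predict_layers=3)
worker = MultiLayerDraftWorker(cfg, speculative_num_steps=3)  # passes
```

A mismatched `--speculative-num-steps` then fails loudly at construction instead of producing silently wrong drafts.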

# nospec so target forward sees identical input shapes (#1090).
# Dense EAGLE/EAGLE3 targets don't need this and EagleDraftWorker
# doesn't yet handle padded prefill mwb, so keep the unpadded path.
_get = (
Collaborator

This function-name alias makes the code hard to follow.

Contributor Author

Expanded to explicit if/else in 93f881e.
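Illustratively, the expanded form replaces the alias with an explicit branch (the function and flag names below stand in for the real ones):

```python
def get_batch(schedule_batch, use_padded_prefill):
    # Stand-ins for the real batch constructors; only the branching
    # structure mirrors the change described above.
    def get_model_worker_batch(b):       # pads tokens/bs like nospec
        return ("padded", b)

    def get_spec_model_worker_batch(b):  # unpadded spec path
        return ("unpadded", b)

    if use_padded_prefill:  # MoE+EP targets (#1090)
        return get_model_worker_batch(schedule_batch)
    else:                   # dense EAGLE/EAGLE3 targets
        return get_spec_model_worker_batch(schedule_batch)
```

The explicit if/else makes the padded-vs-unpadded decision visible at the call site instead of hiding it behind a rebound name.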

) -> tuple[jax.Array, list[jax.Array]]:
    embed = self.embed_tokens(forward_batch.input_ids)
    hidden_in = forward_batch.spec_info.hidden_states
    emb_sh = jax.typeof(embed).sharding
Collaborator

Can we reuse replicate_to_mesh here?

Contributor Author

This needs reshard to the embed's sharding (e.g. P("data", "tensor")) so the subsequent concat is consistent — replicate_to_mesh would force P() and add an unnecessary all-gather. Different intent.

)
self.hot_token_ids = None

self.num_mtp_layers = self.speculative_num_steps
Collaborator

Shall we add an assert here?

Contributor Author

Added in 93f881e (checks hf_config.num_nextn_predict_layers == speculative_num_steps).

Test User added 3 commits May 15, 2026 07:05
…ory leak)

Spec decode allocates KV via EagleDraftInput.prepare_for_decode, not
ScheduleBatch.prepare_for_decode, so the per-step kv_committed_len bump
never happens. cache_finished_req then only frees the prefill-time
committed range, leaking every decode-allocated page. EAGLE3 CI (bs=16)
never hits idle so check_memory never fires; NEXTN bs=1 idles after each
request and crashes with 'token_to_kv_pool_allocator memory leak'.
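The bookkeeping bug can be sketched with a toy pool (all names hypothetical; the real allocator tracks KV pages, not integers):

```python
class ToyKVPool:
    """Toy model of the leak: freeing only [0, committed_len) leaves any
    pages allocated past committed_len stranded."""

    def __init__(self):
        self.allocated = set()

    def alloc(self, pages):
        self.allocated |= set(pages)

    def free_committed(self, committed_len):
        self.allocated -= set(range(committed_len))


pool = ToyKVPool()
pool.alloc(range(4))   # prefill pages; committed_len bumped to 4
pool.alloc([4, 5])     # decode pages via EagleDraftInput.prepare_for_decode
pool.free_committed(4) # kv_committed_len was never bumped past prefill
leaked = pool.allocated  # the decode-allocated pages remain
```

With bs=1 the pool idles between requests, so the leftover pages trip the leak check; at bs=16 the pool never drains and the same leak goes unnoticed.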
…l-side static)

NEXTN (V2.5-Pro, page=256) hits Mosaic tiling(8) proof in _fetch_mask;
EAGLE3 (page=64) does not. Derive mask_aligned_to_cu_kv inside the
kernel from kv_cache page_size (static shape) — passing it as a kwarg
from fa_backend gets traced through jit. Host pads mask rows under the
same page_size>=256 condition. Dense path unchanged from upstream.
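The page alignment itself is just rounding up to a multiple of the page size; a sketch of the helper logic (names hypothetical):

```python
def round_up(n, multiple):
    # Round n up to the next multiple, e.g. page-align a mask row
    # stride so DMA offset/size stay divisible by the tiling.
    return -(-n // multiple) * multiple


# e.g. a 13-row mask stride under 8-element tiling pads to 16
```

Deriving the aligned stride from the static kv_cache page_size keeps the shape known at trace time, whereas passing it as a kwarg from fa_backend gets traced through jit.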