feat(spec): MultiLayerEAGLEWorker/MultiLayerDraftWorker (#1053 P1-4) #1089

zorrofox wants to merge 9 commits into
Conversation
…1053 P1-4)

3-layer MTP for MiMo-V2.5-Pro: N draft model runners (one per mtp_layer_idx), hidden states chained layer→layer in draft_extend_*; draft_forward step i uses runner(i). Verify/orchestration reused from EAGLEWorker via draft_worker injection.

Includes:
- NEXTN SpeculativeAlgorithm + scheduler dispatch
- model_config: draft arch=MiMoV2MTPForCausalLM, num_hidden_layers=1, quant ignore eh_proj/o_proj (bf16 in checkpoint)
- mimo_v2_nextn: 4-tuple __call__ + concat reshard
- mimo_v2_flash: get_embed_and_head
- fa_backend: TARGET_VERIFY custom_mask page-align pad + swa_page_indices
- rpa_v3: _fetch_mask aligned mask stride + pl.multiple_of hints
- spec prefill via padded get_model_worker_batch (epmoe padding-sensitivity)
- EAGLEWorker.__init__ accepts draft_worker (fold sgl-project#1080 review (4))

E2E v6e-64: bs=1 topk=1 greedy, accept-len ~2.5 (5-prompt mix); output matches nospec for the first N>=13 tokens (later divergence is upstream epmoe bf16 padding-sensitivity, tracked separately).
…ject#1080 review (1))

verify(), forward_target_extend(), forward_batch_speculative_generation() only touch BaseSpecWorker state (target_worker, draft_worker, mesh, speculative_num_*); moving them up makes EAGLEWorker and MultiLayerEAGLEWorker thin draft_worker-injection wrappers. BaseDraftWorker gains explicit abstract draft_extend_for_{prefill,decode} + draft_model_runner so the contract is visible.

(3) precompile multi-layer warm-up is left as a TODO (currently --disable-precompile in E2E; only layer 0 would be warmed via _worker delegation).
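A hedged sketch of the class split this commit describes (method bodies elided; the constructor signature and module layout are assumptions, only the class and method names come from the commit message):

```python
from abc import ABC, abstractmethod

class BaseDraftWorker(ABC):
    # Explicit contract, per the commit message.
    @abstractmethod
    def draft_extend_for_prefill(self, batch): ...

    @abstractmethod
    def draft_extend_for_decode(self, batch): ...

    @property
    @abstractmethod
    def draft_model_runner(self): ...

class BaseSpecWorker:
    # verify(), forward_target_extend(), forward_batch_speculative_generation()
    # live here; they only touch target_worker, draft_worker, mesh, and
    # speculative_num_* state.
    def __init__(self, target_worker, draft_worker: BaseDraftWorker):
        self.target_worker = target_worker
        self.draft_worker = draft_worker

class EAGLEWorker(BaseSpecWorker):
    # Reduced to draft_worker injection (+ logprob/precompile).
    pass

class MultiLayerEAGLEWorker(BaseSpecWorker):
    # Same shell; injects a MultiLayerDraftWorker instead.
    pass
```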
EAGLE/EAGLE3 dense targets don't hit sgl-project#1090 (no MoE-EP) and EagleDraftWorker.draft_extend_for_prefill doesn't yet handle padded mwb (sharded [:real_bs] slice → ShardingTypeError). Restores CI test_speculative_decoding (Qwen3-32B EAGLE3).
Kernel _fetch_mask now uses page-aligned mask row stride (cu_kv_lens delta) so DMA offset/size are 8-divisible. Update ref impl + test helper to construct masks with the same padded layout (host-side fa_backend already does). Restores test_flashattention custom_mask tests.
…ract" This reverts commit 9f11539.
The host-side mask pad + kernel pl.multiple_of hints were added for the NEXTN hybrid-SWA verify DMA 8-align crash, but they regress EAGLE3 (dense Qwen3-32B): accept-len 1.5→1.06 and per-round JIT cache_miss (padded mask shape varies with seq_len → recompile). Restoring the original unpadded mask path so EAGLE3 CI recovers; the NEXTN hybrid case will be re-fixed via a hybrid-only path that doesn't touch the dense kernel contract. swa_page_indices for TARGET_VERIFY is kept (independent hybrid fix).
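A toy jit example (an assumption, not the rpa_v3 kernel) of the per-round cache-miss mechanism this commit describes: jit specializes on static shapes, so a padded mask length that varies with seq_len retraces every round.

```python
import jax
import jax.numpy as jnp

@jax.jit
def apply_mask(scores, mask):
    # jit specializes on static shapes: each new padded mask shape is a new
    # compilation-cache entry.
    return jnp.where(mask, scores, -jnp.inf)

for seq_len in (130, 260, 390):
    padded = -(-seq_len // 256) * 256       # pad rows to a 256-page multiple
    scores = jnp.zeros((1, padded))
    mask = jnp.ones((1, padded), dtype=bool)
    apply_mask(scores, mask)                # distinct `padded` → retrace/recompile
```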
    # launch draft worker
    if self.spec_algorithm is not None and self.spec_algorithm.is_eagle():
        from sgl_jax.srt.speculative.eagle_worker import EAGLEWorker
    if self.spec_algorithm.is_nextn():
NEXTN's logic == EAGLE's logic; the only difference is that NEXTN's draft model params come from the target model's layers, so this check is not correct. We should adjust it to test whether there are multi-layer draft model params for the different steps.
is_nextn() is the user's explicit CLI choice (--speculative-algorithm NEXTN). Agree the model-side check belongs at construction — added an assert in MultiLayerDraftWorker.__init__ against hf_config.num_nextn_predict_layers (93f881e). Keeping the dispatch on the algorithm enum so a future single-layer NEXTN can route to EagleDraftWorker by checking num_nextn_predict_layers==1 there if needed.
    # nospec so target forward sees identical input shapes (#1090).
    # Dense EAGLE/EAGLE3 targets don't need this and EagleDraftWorker
    # doesn't yet handle padded prefill mwb, so keep the unpadded path.
    _get = (
This function-name alias is not easy to understand.
Expanded to explicit if/else in 93f881e.
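For reference, a minimal sketch of the before/after shape of that change (the batch object and getter wiring are assumptions based on the diff context, not the actual sgl-jax code):

```python
def pick_worker_batch(batch, target_is_moe_ep: bool):
    # Before: a conditional function alias, hard to scan.
    #   _get = (batch.get_model_worker_batch if target_is_moe_ep
    #           else batch.get_spec_model_worker_batch)
    #   return _get()
    # After: explicit if/else, as adopted in 93f881e.
    if target_is_moe_ep:
        # Padded path: target forward sees the same shapes as nospec (#1090).
        return batch.get_model_worker_batch()
    # Dense EAGLE/EAGLE3 targets keep the unpadded spec path.
    return batch.get_spec_model_worker_batch()
```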
    ) -> tuple[jax.Array, list[jax.Array]]:
        embed = self.embed_tokens(forward_batch.input_ids)
        hidden_in = forward_batch.spec_info.hidden_states
        emb_sh = jax.typeof(embed).sharding
Can we reuse replicate_to_mesh here?
This needs reshard to the embed's sharding (e.g. P("data", "tensor")) so the subsequent concat is consistent — replicate_to_mesh would force P() and add an unnecessary all-gather. Different intent.
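A standalone sketch of that distinction (toy mesh and shapes; the real code reads the sharding via jax.typeof(embed).sharding inside the traced function):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

ndev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("data",))
emb_sh = NamedSharding(mesh, P("data", None))   # stand-in for embed's sharding

@jax.jit
def concat_hidden(embed, hidden_in):
    # Reshard hidden_in to match embed's sharding so the concat stays
    # consistently partitioned; forcing P() (what replicate_to_mesh would do)
    # would insert an unnecessary all-gather here.
    hidden_in = jax.lax.with_sharding_constraint(hidden_in, emb_sh)
    return jnp.concatenate([embed, hidden_in], axis=-1)

embed = jax.device_put(jnp.ones((ndev * 2, 16)), emb_sh)
hidden_in = jnp.ones((ndev * 2, 16))            # arrives with some other layout
out = concat_hidden(embed, hidden_in)
```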
    )
    self.hot_token_ids = None

    self.num_mtp_layers = self.speculative_num_steps
Shall we add an assert here?
Added in 93f881e (checks hf_config.num_nextn_predict_layers == speculative_num_steps).
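A hedged sketch of that construction-time check (the helper name and wiring are assumptions; only the compared fields come from this thread):

```python
from types import SimpleNamespace

def check_mtp_layers(hf_config, speculative_num_steps: int) -> int:
    num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
    # One MTP layer per draft step: catches e.g. a single-layer NEXTN
    # checkpoint launched with --speculative-num-steps 3 (and vice versa).
    assert num_mtp_layers == speculative_num_steps, (
        f"num_nextn_predict_layers={num_mtp_layers} != "
        f"speculative_num_steps={speculative_num_steps}"
    )
    return num_mtp_layers

hf_config = SimpleNamespace(num_nextn_predict_layers=3)
num_mtp_layers = check_mtp_layers(hf_config, speculative_num_steps=3)
```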
…ory leak)

Spec decode allocates KV via EagleDraftInput.prepare_for_decode, not ScheduleBatch.prepare_for_decode, so the per-step kv_committed_len bump never happens. cache_finished_req then only frees the prefill-time committed range, leaking every decode-allocated page.

EAGLE3 CI (bs=16) never hits idle, so check_memory never fires; NEXTN bs=1 idles after each request and crashes with 'token_to_kv_pool_allocator memory leak'.
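A toy model of the leak mechanism described here (hypothetical bookkeeping, not the real allocator API):

```python
class ToyReqKV:
    def __init__(self, prefill_pages: int):
        self.allocated = prefill_pages          # pages handed out so far
        self.kv_committed_len = prefill_pages   # bumped by ScheduleBatch.prepare_for_decode

    def spec_decode_step(self, new_pages: int, bump_committed: bool):
        self.allocated += new_pages             # EagleDraftInput.prepare_for_decode path
        if bump_committed:                      # the per-step bump that never happened
            self.kv_committed_len += new_pages

    def cache_finished_req(self) -> int:
        # Only the committed range is freed; anything beyond it leaks.
        return self.allocated - self.kv_committed_len   # leaked pages

req = ToyReqKV(prefill_pages=4)
req.spec_decode_step(2, bump_committed=False)
assert req.cache_finished_req() == 2   # 2 decode-allocated pages leaked
```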
…l-side static)

NEXTN (V2.5-Pro, page=256) hits the Mosaic tiling(8) proof in _fetch_mask; EAGLE3 (page=64) does not. Derive mask_aligned_to_cu_kv inside the kernel from the kv_cache page_size (a static shape); passing it as a kwarg from fa_backend gets traced through jit. Host pads mask rows under the same page_size>=256 condition. Dense path unchanged from upstream.

Addresses sgl-project#1089 review (2)/(4).
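A toy jit example of the static-vs-traced distinction this commit relies on (not the real rpa_v3 kernel):

```python
import jax
import jax.numpy as jnp

@jax.jit
def fetch_mask(mask):
    page_size = mask.shape[-1]          # static: a plain Python int at trace time
    aligned = page_size >= 256          # resolved at compile time, never traced
    row_stride = page_size if aligned else 8
    return mask[:, :row_stride]

mask = jnp.ones((4, 256), dtype=bool)
out = fetch_mask(mask)
# Passing page_size as a runtime argument instead would make `page_size >= 256`
# a tracer-valued bool, and the Python `if` would raise under jit.
```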
Part of #1053 Phase 1 (P1-4). Stacked on `epic/mtp-refactor-phase1` (#1080 merged).

What

3-layer MTP for MiMo-V2.5-Pro via `MultiLayerEAGLEWorker` + `MultiLayerDraftWorker`:

- `ModelWorker`s (one per `mtp_layer_idx`); hidden states chained layer→layer in `draft_extend_*` (see the sketch after the Includes list)
- `draft_forward` step i uses runner i (per-layer KV pool, per-layer `forward_metadata`)
- verify/orchestration reused from `BaseSpecWorker` via `draft_worker` injection

Two commits:
- `5db87f1f` — functional: `MultiLayer*` + NEXTN wire-up + model/config/kernel adaptations
- `b731d0c6` — refactor: lift `verify`/`forward_target_extend`/`forward_batch_speculative_generation` to `BaseSpecWorker` (fold "refactor(spec): split EAGLEWorker into BaseSpecWorker/BaseDraftWorker…" #1080 review item (1)/(4); `EAGLEWorker` reduced to `draft_worker` injection + logprob/precompile)

Includes

- `SpeculativeAlgorithm.NEXTN` + scheduler dispatch
- model_config: draft arch → `MiMoV2MTPForCausalLM`, `num_hidden_layers=1`, quant ignore `eh_proj`/`o_proj` (BF16 in checkpoint)
- mimo_v2_nextn: 4-tuple `__call__` + concat reshard; mimo_v2_flash: `get_embed_and_head`
- flashattention_backend: TARGET_VERIFY `custom_mask` page-align pad + `swa_page_indices` for hybrid
- rpa_v3: `_fetch_mask` aligned mask stride + `pl.multiple_of` hints (kernel side of mask pad)
- spec prefill via padded `get_model_worker_batch` (epmoe padding-sensitivity, see below)
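As referenced in the What list, a hedged sketch of the layer→layer chaining (runner objects and their call signature are assumptions; the real `MultiLayerDraftWorker` wires `ModelWorker`s with per-layer KV pools and `forward_metadata`):

```python
import jax.numpy as jnp

def draft_forward(runners, input_ids, target_hidden):
    """Step i uses runner i; layer i consumes the hidden states produced by
    layer i-1 (layer 0 consumes the target model's hidden states)."""
    hidden = target_hidden
    draft_tokens = []
    for runner in runners:                      # one runner per mtp_layer_idx
        logits, hidden = runner(input_ids, hidden)
        next_ids = jnp.argmax(logits, axis=-1)  # greedy, matching the E2E run
        draft_tokens.append(next_ids)
        input_ids = next_ids                    # chain into the next MTP layer
    return draft_tokens
```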
E2E (v6e-64, V2.5-Pro, `--ep-size 64 --moe-backend epmoe`, bs=1 topk=1 greedy)

Accept-len ~2.5 on a 5-prompt mix; output matches nospec for the first N>=13 tokens (per the first commit message above).

Known limitation: spec ≠ nospec on MoE+EP after ~13–30 tokens

Verify-stage divergence is not a verify-logic bug — it's upstream `EPMoE` bf16 padding-sensitivity: `get_spec_model_worker_batch` doesn't pad tokens/bs (vs nospec's `get_model_worker_batch`), so the same real tokens land at different row indices inside per-expert GMM groups after `_permute`'s `argsort(topk_ids)` → bf16 accumulation order differs → occasional argmax flip. Padded prefill (this PR) makes the first token match; verify (4-token) vs nospec decode (1-token) are inherently different shapes. Dense targets (P1-0 Llama) unaffected. Tracked in #1090 — likely needs a router-side padding mask.
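A toy illustration of the bf16 order-sensitivity claimed above (not the EPMoE kernel): the same values summed in a different order can round differently, which is enough to flip a near-tie argmax in the logits.

```python
import jax.numpy as jnp

def seq_sum(v):
    acc = jnp.array(0.0, dtype=jnp.bfloat16)
    for e in v:                 # strictly sequential bf16 accumulation
        acc = acc + e
    return acc

x = jnp.array([1.0, 1e-3, -1.0, 1e-3], dtype=jnp.bfloat16)
perm = jnp.array([2, 0, 3, 1])             # same elements, permuted row order
print(seq_sum(x), seq_sum(x[perm]))        # may print ~0.001 vs ~0.002:
                                           # the small terms vanish next to 1.0
                                           # in one order but not the other
```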
TODO (follow-up, not blocking)

- `run_spec_decode_precompile` only warms layer 0 via `_worker` delegation (currently `--disable-precompile` in E2E)
- `draft_forward` per-layer semantics: tested `runner(i+1)` @ step-0 position — accept-len dropped (3.5→2.0), reverted; current `runner(i)` @ step-i is empirically best, but the layer/token-shift contract vs MiMo training needs confirming
- extract `_init_common()` from `EagleDraftWorker.__init__` (review (4) full version) — `MultiLayerDraftWorker.__init__` still copies ~20 lines

Test plan
- `test_speculative_decoding.py` (P1-8 will add NEXTN case)