[TRTLLM-11508][refactor] Merge Eagle3 and MTP-eagle one-model workers#12353
zhaoyangwang-nvidia wants to merge 7 commits into
98b1417 to 9f923d9
/bot run
PR_Github #39589 [ run ] triggered by Bot. Commit:
PR_Github #39589 [ run ] completed with state
9f923d9 to 01b1fe7
/bot run
PR_Github #39733 [ run ] triggered by Bot. Commit:
PR_Github #39733 [ run ] completed with state
01b1fe7 to 5a1ef63
/bot run
PR_Github #39814 [ run ] triggered by Bot. Commit:
PR_Github #39814 [ run ] completed with state
/bot run
PR_Github #39819 [ run ] triggered by Bot. Commit:
PR_Github #39819 [ run ] completed with state
/bot run
PR_Github #39828 [ run ] triggered by Bot. Commit:
PR_Github #39828 [ run ] completed with state
7e82c9c to f4f8ac6
/bot run
PR_Github #40431 [ run ] triggered by Bot. Commit:
PR_Github #40431 [ run ] completed with state
8ed37a4 to 67e4428
PR_Github #48373 [ run ] triggered by Bot. Commit:
PR_Github #48373 [ run ] completed with state
d2f8aee to d01e724
/bot run
PR_Github #48499 [ run ] triggered by Bot. Commit:
PR_Github #48499 [ run ] completed with state
/bot run
PR_Github #48545 [ run ] triggered by Bot. Commit:
PR_Github #48545 [ run ] completed with state
/bot run
PR_Github #48571 [ run ] triggered by Bot. Commit:
PR_Github #48571 [ run ] completed with state
Hi @mikeiovine @sunnyqgg @ziyixiong-nv, could you help review this PR? All CI has passed and it is ready for review.
self.model_nextn = 0
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model(
if model_config.spec_config is not None and (
        model_config.spec_config.spec_dec_mode.is_mtp_one_model() or
It seems https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L347 was not updated, and I think this would be easy to break in the future.
is_mtp_vanilla is what you want for SpeculativeDecodingMode.MTP alone, so you can use is_mtp_vanilla in those places and keep the previous is_mtp_one_model for SpeculativeDecodingMode.MTP or SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL.
Please also confirm with a local test that the acceptance rate (AR) does not drop when using MTP Eagle one-model.
Done. Restored is_mtp_one_model() to cover the union (MTP or MTP_EAGLE_ONE_MODEL) and used is_mtp_vanilla() where only vanilla MTP applies.
Test result:
Verified MTP Eagle one-model AR is unchanged after the unified worker refactor:
| Metric | Pre-refactor (ce788e0) | Post-refactor (this PR) |
|---|---|---|
| Acceptance rate | 81.38% | 80.66% |
| Avg. acceptance length | 1.814 / 2.0 | 1.805 / 2.0 |

Tested on Qwen3.5-9B with 16 real prompts × 128 output tokens on a single B200. The difference is within run-to-run noise.
d01e724 to eb4643f
Unify the Eagle3 one-model and MTP-eagle one-model speculative-decoding workers into a single Eagle3OneModelWorker in eagle3.py, branching on self.is_mtp_eagle. MTPEagleWorker becomes a thin backward-compatible subclass; mtp.py keeps a module-level __getattr__ shim so the historical import path continues to resolve.

Key changes:
- Eagle3OneModelSpecMetadata gains slot_ids and subseq_all_rank_num_tokens; prepare() skips num_tokens adjustment for MTP-eagle and populates slot_ids from the resource manager.
- Eagle3ResourceManager owns the relaxed-acceptance delta pool for both modes.
- New helpers _get_step_all_rank_num_tokens, _run_draft_forward, and _prepare_flash_mla_generation_layout encapsulate the per-step branching.
- sample_and_accept_draft_tokens takes input_ids and supports the relaxed-thinking path previously exclusive to MTPEagleWorker.
- EagleDecodingConfig grows the relaxed-acceptance fields mirrored from MTPDecodingConfig.
- SpeculativeDecodingMode.is_mtp_one_model() now means vanilla MTP only; predicates and per-model checks are extended to recognize MTP_EAGLE_ONE_MODEL as a first-class one-model mode.
- The Eagle3 _saved_kv_lens_cuda save/restore is dropped (relying on attn_metadata.update_for_spec_dec() instead); this needs verification under Eagle3 regression tests.
- Factory routing in utils.py routes MTP_EAGLE_ONE_MODEL to the unified Eagle3 metadata, sampler, resource manager, and worker.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
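To make the factory-routing change above concrete, here is a minimal Python sketch. The two-member enum, the `route_spec_components` helper, and the dict-of-names table are hypothetical illustrations, not the `utils.py` factory API; only the component class names come from this commit message. It shows both one-model modes resolving to the same components, with `is_mtp_eagle` as the only mode-specific output.

```python
from enum import Enum, auto
from typing import Dict


class SpeculativeDecodingMode(Enum):
    # Illustrative two-member subset; not the real enum definition.
    EAGLE3_ONE_MODEL = auto()
    MTP_EAGLE_ONE_MODEL = auto()


# Hypothetical routing table: both one-model modes resolve to the same
# unified component classes (names taken from the commit message).
_UNIFIED_COMPONENTS: Dict[str, str] = {
    "metadata": "Eagle3OneModelSpecMetadata",
    "sampler": "Eagle3OneModelSampler",
    "resource_manager": "Eagle3ResourceManager",
    "worker": "Eagle3OneModelWorker",
}


def route_spec_components(mode: SpeculativeDecodingMode) -> Dict[str, object]:
    """Return the shared component names plus the worker's branch flag."""
    components: Dict[str, object] = dict(_UNIFIED_COMPONENTS)
    # The unified worker branches on this flag instead of being a separate class.
    components["is_mtp_eagle"] = mode is SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL
    return components


eagle3 = route_spec_components(SpeculativeDecodingMode.EAGLE3_ONE_MODEL)
mtp_eagle = route_spec_components(SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL)
print(eagle3["worker"] == mtp_eagle["worker"])  # True
print(eagle3["is_mtp_eagle"], mtp_eagle["is_mtp_eagle"])  # False True
```

Keeping the mode-specific part down to a single boolean is what lets one worker class replace two.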
Move the per-step ``all_rank_num_tokens`` plumbing into the draft model itself so the unified Eagle3OneModelWorker no longer mutates ``attn_metadata`` on the way into the draft loop.

- Eagle3DraftModel.forward takes an optional ``all_rank_num_tokens`` kwarg and wraps its body in a try/finally that restores ``attn_metadata.all_rank_num_tokens`` on exit.
- _run_draft_forward in eagle3.py passes ``all_rank_num_tokens`` via ``inputs`` for Eagle3 (kwarg to Eagle3DraftModel) and as a direct kwarg to ``mtp_layers[0]`` for MTP Eagle; the worker no longer needs the old fallback parameter.
- _get_step_all_rank_num_tokens reads only from spec_metadata (all_rank_num_tokens at step 0, subseq_all_rank_num_tokens otherwise).
- model_engine.py populates ``spec_metadata.subseq_all_rank_num_tokens`` for both Eagle3 one-model and MTP-eagle one-model at all three sites that allgather per-rank token counts.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
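A minimal sketch of the save/restore pattern and the per-step selection described in this commit, assuming `all_rank_num_tokens` is a per-rank list of ints. `AttnMetadataSketch`, `SpecMetadataSketch`, and `DraftModelSketch` are hypothetical stand-ins, not TensorRT-LLM classes, and the token counts are made up. The `finally` clause is what guarantees `attn_metadata` is restored even if the draft forward raises.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AttnMetadataSketch:
    # Hypothetical stand-in for attn_metadata; only the field used here.
    all_rank_num_tokens: Optional[List[int]] = None


@dataclass
class SpecMetadataSketch:
    # Hypothetical stand-in for Eagle3OneModelSpecMetadata.
    all_rank_num_tokens: Optional[List[int]] = None
    subseq_all_rank_num_tokens: Optional[List[int]] = None


def get_step_all_rank_num_tokens(spec_metadata: SpecMetadataSketch,
                                 step: int) -> Optional[List[int]]:
    # Step 0 reads all_rank_num_tokens; later draft steps read
    # subseq_all_rank_num_tokens, as described in the commit message.
    if step == 0:
        return spec_metadata.all_rank_num_tokens
    return spec_metadata.subseq_all_rank_num_tokens


class DraftModelSketch:
    def __init__(self) -> None:
        self.seen: List[Optional[List[int]]] = []

    def forward(self, attn_metadata: AttnMetadataSketch,
                all_rank_num_tokens: Optional[List[int]] = None) -> None:
        saved = attn_metadata.all_rank_num_tokens
        try:
            if all_rank_num_tokens is not None:
                attn_metadata.all_rank_num_tokens = all_rank_num_tokens
            # Placeholder for the real layers: record the value they would see.
            self.seen.append(attn_metadata.all_rank_num_tokens)
        finally:
            # Runs on normal return and on exceptions, so the caller's
            # attn_metadata is never left holding a per-step value.
            attn_metadata.all_rank_num_tokens = saved


spec = SpecMetadataSketch(all_rank_num_tokens=[12, 9],
                          subseq_all_rank_num_tokens=[4, 3])
attn = AttnMetadataSketch(all_rank_num_tokens=[12, 9])
model = DraftModelSketch()
for step in range(3):
    model.forward(attn, get_step_all_rank_num_tokens(spec, step))
print(model.seen)                # [[12, 9], [4, 3], [4, 3]]
print(attn.all_rank_num_tokens)  # [12, 9]
```

Owning the restore inside the callee means the worker loop cannot forget it, which is the motivation the commit gives for moving the plumbing.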
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Address review feedback: keep is_mtp_one_model() covering both MTP and MTP_EAGLE_ONE_MODEL (matches main), and use is_mtp_vanilla() only where the call should match vanilla MTP exclusively. Drop the JIRA tag from the NOTE comment in eagle3.py and simplify the now-redundant "is_mtp_one_model() or is_mtp_eagle_one_model()" patterns introduced by the merge.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
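The predicate semantics settled in this review can be sketched as below. This is an illustrative enum with three members and hand-written predicate bodies, not the library's `SpeculativeDecodingMode` source. The loop checks the point behind dropping `is_mtp_one_model() or is_mtp_eagle_one_model()`: once `is_mtp_one_model()` is the union, the extra term is redundant.

```python
from enum import Enum, auto


class SpeculativeDecodingMode(Enum):
    # Illustrative subset; the predicate bodies below sketch the semantics
    # described in the review thread, not the library source.
    MTP = auto()
    MTP_EAGLE_ONE_MODEL = auto()
    EAGLE3_ONE_MODEL = auto()

    def is_mtp_vanilla(self) -> bool:
        # Vanilla MTP only.
        return self is SpeculativeDecodingMode.MTP

    def is_mtp_eagle_one_model(self) -> bool:
        return self is SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL

    def is_mtp_one_model(self) -> bool:
        # Union, matching main: MTP or MTP_EAGLE_ONE_MODEL.
        return self.is_mtp_vanilla() or self.is_mtp_eagle_one_model()


for mode in SpeculativeDecodingMode:
    redundant = mode.is_mtp_one_model() or mode.is_mtp_eagle_one_model()
    # The "or is_mtp_eagle_one_model()" term adds nothing once the union holds.
    assert redundant == mode.is_mtp_one_model()
    print(mode.name, mode.is_mtp_vanilla(), mode.is_mtp_one_model())
```

Call sites that must exclude MTP-eagle (for example, vanilla-MTP-only `model_nextn` handling) would use `is_mtp_vanilla()`; everything else keeps the broader `is_mtp_one_model()`.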
…rward

The one-model worker forward signature (Eagle3 / MTP-Eagle) takes ``resource_manager``, and modeling_speculative.py forwards it unconditionally to ``self.spec_worker(...)``. On non-last PP ranks, ``forward`` is replaced by ``skip_forward`` via modeling_utils.skip_forward(), which raised ``TypeError: SpecWorkerBase.skip_forward() got an unexpected keyword argument 'resource_manager'`` and silently terminated the executor worker.

Before the merge, MTP-Eagle used MTPWorker.skip_forward, which already accepted ``resource_manager``; the merged path now inherits SpecWorkerBase.skip_forward, which did not. Add the parameter (unused) to restore PP compatibility.

Validated on H200 x4 with TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] (was: silent crash during warmup; now: PASSED, GSM8K 63.72).

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
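A minimal Python reproduction of the failure mode and fix described above. Signatures are trimmed to two positional arguments, and `replace_forward_with_skip` and `call_like_modeling_speculative` are hypothetical stand-ins for `modeling_utils.skip_forward()` and the caller in `modeling_speculative.py`. The point is only that an instance whose `forward` has been swapped for `skip_forward` must accept every keyword the caller passes.

```python
from typing import Any, Optional


class OldSpecWorkerBase:
    # Sketch of the pre-fix shape: skip_forward lacks resource_manager.
    def skip_forward(self, input_ids: Any, position_ids: Any) -> None:
        return None


class SpecWorkerBase:
    # Sketch of the fix: accept resource_manager (unused) so the signature
    # matches what callers pass to forward.
    def skip_forward(self, input_ids: Any, position_ids: Any,
                     resource_manager: Optional[Any] = None) -> None:
        return None


def replace_forward_with_skip(worker: Any) -> None:
    # Hypothetical stand-in for the non-last-PP-rank patching:
    # the instance's forward becomes its skip_forward.
    worker.forward = worker.skip_forward


def call_like_modeling_speculative(worker: Any) -> str:
    # The caller always passes resource_manager, regardless of PP rank.
    try:
        worker.forward([1, 2], [0, 1], resource_manager=object())
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


for cls in (OldSpecWorkerBase, SpecWorkerBase):
    worker = cls()
    replace_forward_with_skip(worker)
    print(cls.__name__, call_like_modeling_speculative(worker))
```

Accepting `**kwargs` in `skip_forward` would also avoid the crash, but an explicit, unused parameter keeps the signature honest and matches the fix described in the commit.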
Pre-commit yapf reformatted the position_ids update in the unified linear draft loop. No functional change.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
79bbe75 to 9733b00
/bot run
PR_Github #50210 [ run ] triggered by Bot. Commit:
PR_Github #50210 [ run ] completed with state
/bot run
PR_Github #50222 [ run ] triggered by Bot. Commit:
PR_Github #50222 [ run ] completed with state
Summary by CodeRabbit
Release Notes
Description
Unify `Eagle3OneModelWorker` (Eagle3 one-model) and `MTPEagleWorker` (MTP-eagle one-model) into a single worker class in `tensorrt_llm/_torch/speculative/eagle3.py`, branching on `self.is_mtp_eagle = spec_dec_mode.is_mtp_eagle_one_model()`. The two code paths were roughly 85% duplicated; this PR collapses them into one implementation while preserving backward-compatible imports.
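One way to picture "one class, branching on `self.is_mtp_eagle`" is the draft-forward dispatch: per the commit messages, Eagle3 passes `all_rank_num_tokens` through `inputs` to the draft model, while MTP Eagle passes it as a direct kwarg to `mtp_layers[0]`. The sketch below assumes both targets are plain callables; `UnifiedDraftWorkerSketch` and the fake callables are hypothetical, not the real `_run_draft_forward`.

```python
from typing import Any, Callable, Dict, List, Optional


class UnifiedDraftWorkerSketch:
    """Hypothetical sketch of one worker class serving both modes."""

    def __init__(self, is_mtp_eagle: bool) -> None:
        # In the PR this flag comes from spec_dec_mode.is_mtp_eagle_one_model().
        self.is_mtp_eagle = is_mtp_eagle

    def run_draft_forward(self, draft_model: Callable[..., Any],
                          mtp_layers: List[Callable[..., Any]],
                          inputs: Dict[str, Any],
                          all_rank_num_tokens: Optional[List[int]]) -> Any:
        if self.is_mtp_eagle:
            # MTP Eagle: counts go to the first MTP layer as a direct kwarg.
            return mtp_layers[0](**inputs, all_rank_num_tokens=all_rank_num_tokens)
        # Eagle3: counts travel inside the inputs dict to the draft model.
        return draft_model(**dict(inputs, all_rank_num_tokens=all_rank_num_tokens))


def fake_eagle3_draft(**kwargs: Any) -> tuple:
    return ("eagle3", kwargs["all_rank_num_tokens"])


def fake_mtp_layer(**kwargs: Any) -> tuple:
    return ("mtp_eagle", kwargs["all_rank_num_tokens"])


inputs = {"input_ids": [101, 102]}
for flag in (False, True):
    worker = UnifiedDraftWorkerSketch(is_mtp_eagle=flag)
    print(worker.run_draft_forward(fake_eagle3_draft, [fake_mtp_layer], inputs, [4, 3]))
# ('eagle3', [4, 3])
# ('mtp_eagle', [4, 3])
```

The rest of the draft loop stays shared; only call sites like this one consult the flag.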
`MTPEagleWorker` becomes a thin backward-compatible subclass in `eagle3.py`. `mtp.py` retains a module-level `__getattr__` shim so `from tensorrt_llm._torch.speculative.mtp import MTPEagleWorker` continues to resolve.
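The `__getattr__` shim relies on module-level `__getattr__` (PEP 562, Python 3.7+). The runnable sketch below builds two in-memory modules with `types.ModuleType` so it is self-contained; the module names `eagle3_sketch` and `mtp_sketch` and the `_MOVED_NAMES` table are hypothetical. In a real `mtp.py` the shim would simply be a top-level `def __getattr__(name): ...`. Importing lazily inside `__getattr__` is one plausible way to avoid a circular import between `mtp.py` and `eagle3.py`; the PR does not state whether that was the motivation.

```python
import importlib
import sys
import types

# Hypothetical stand-in for eagle3.py: the unified worker and a thin
# backward-compatible subclass.
eagle3_mod = types.ModuleType("eagle3_sketch")


class Eagle3OneModelWorker:
    def __init__(self, is_mtp_eagle: bool = False) -> None:
        self.is_mtp_eagle = is_mtp_eagle


class MTPEagleWorker(Eagle3OneModelWorker):
    def __init__(self) -> None:
        super().__init__(is_mtp_eagle=True)


eagle3_mod.Eagle3OneModelWorker = Eagle3OneModelWorker
eagle3_mod.MTPEagleWorker = MTPEagleWorker
sys.modules["eagle3_sketch"] = eagle3_mod

# Hypothetical stand-in for mtp.py with a PEP 562 module-level __getattr__.
mtp_mod = types.ModuleType("mtp_sketch")
_MOVED_NAMES = {"MTPEagleWorker": "eagle3_sketch"}


def _mtp_getattr(name: str):
    # Called only when normal attribute lookup on the module fails.
    if name in _MOVED_NAMES:
        return getattr(importlib.import_module(_MOVED_NAMES[name]), name)
    raise AttributeError(f"module 'mtp_sketch' has no attribute {name!r}")


mtp_mod.__getattr__ = _mtp_getattr
sys.modules["mtp_sketch"] = mtp_mod

# The historical import path still works, resolved via __getattr__.
from mtp_sketch import MTPEagleWorker as LegacyImport

print(LegacyImport is MTPEagleWorker)  # True
print(LegacyImport().is_mtp_eagle)     # True
```

Because the old name resolves to the very same class object, `isinstance` checks and subclassing written against the old import path keep behaving identically.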
Key changes
- `eagle3.py`: new helpers `_get_step_all_rank_num_tokens`, `_run_draft_forward`, `_prepare_flash_mla_generation_layout`, and `draft_sampler` (TP-aware); `sample_and_accept_draft_tokens` gains an `input_ids` parameter and the relaxed-thinking acceptance path previously exclusive to `MTPEagleWorker`.
- `Eagle3OneModelSpecMetadata`: new fields `slot_ids` and `subseq_all_rank_num_tokens`. `prepare()` skips the `num_tokens` subtraction for MTP-eagle and populates `slot_ids` from the resource manager.
- `Eagle3ResourceManager` owns the relaxed-acceptance `relaxed_delta_pool`, so the Eagle3 path can also use thinking-phase relaxed acceptance.
- `Eagle3DraftModel.forward` takes an optional `all_rank_num_tokens` kwarg and wraps its body in a try/finally that restores `attn_metadata.all_rank_num_tokens` on exit; the worker loop no longer mutates `attn_metadata` for Eagle3.
- `EagleDecodingConfig` gains the five relaxed-acceptance fields mirrored from `MTPDecodingConfig`.
- `SpeculativeDecodingMode.is_mtp_one_model()` keeps covering both MTP and `MTP_EAGLE_ONE_MODEL` (matching main), and `is_mtp_vanilla()` is used where only vanilla MTP applies. `MTP_EAGLE_ONE_MODEL` becomes a first-class one-model mode in `use_one_engine`, `without_logits`, `needs_kv_cache_rewind`, `support_overlap_scheduler`, `support_capturable_guided_decoder`, `support_dynamic_draft_len`, and `has_spec_decoder`. Per-model checks in `modeling_deepseekv3`, `modeling_glm`, `modeling_exaone_moe`, `modeling_nemotron_h`, `modeling_qwen3_next`, `modeling_speculative`, and `model_config` are extended accordingly.
- Factory routing (`utils.py`) sends `MTP_EAGLE_ONE_MODEL` to `Eagle3OneModelSpecMetadata`, `Eagle3OneModelSampler`, `Eagle3ResourceManager`, and the unified worker.
- `model_engine.py` populates `spec_metadata.subseq_all_rank_num_tokens` for both Eagle3 and MTP-eagle one-model at all three attention-DP allgather sites.
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.