Skip to content

Commit 43b67a8

Browse files
authored
specdec_bench: keep method=mtp when adding model=<assistant> for Gemma 4 MTP (#1677)
### What does this PR do? Type of change: Bug fix Fixes the specdec_bench vLLM wrapper's MTP `speculative_config` emission so Gemma 4 MTP no longer hits the wrong code path inside vLLM. ### Bug vLLM's `SpeculativeConfig.__post_init__` (`vllm/config/speculative.py:529-602`) auto-detects `method` ONLY when it's unset. When `model` is provided and `method` is `None`, the default branch sets `method = "draft_model"` — the generic same-architecture draft path, NOT MTP. That path enforces equal num_heads between target and draft and raises: ``` AssertionError: All layers in one attention group must share num_heads; got {8, 4} ``` on heterogeneous-head models. Gemma 4 has 8 target heads and 4 draft heads by design. ### Where the previous fix went wrong PR #1663 changed the MTP branch in the wrapper to emit `{model: <assistant>, num_speculative_tokens: N}` WITHOUT `method` when `draft_model_dir` was provided, based on a misread of vLLM PR #41745's test plan that only showed the `{model, num_speculative_tokens}` shape. That test plan was the direct `LLM(...)` constructor invocation; vLLM had already defaulted method internally. Going through specdec_bench's `AsyncEngineArgs(speculative_config=...)` path, the explicit `method` key is required to avoid the auto-detect → draft_model fallback. ### Reference vLLM's own test at [`tests/v1/e2e/spec_decode/test_spec_decode.py:818-823`](https://github.com/vllm-project/vllm/blob/main/tests/v1/e2e/spec_decode/test_spec_decode.py#L818) does exactly this for the gemma4-e4b parametrization: ```python speculative_config = { "method": method, # "mtp" "num_speculative_tokens": ..., } if draft_model is not None: # Gemma 4 case speculative_config["model"] = draft_model ``` ### Fix Restore `method="mtp"` as the unconditional MTP path. ADD `model` only when `draft_model_dir` is set. Backward-compatible for Qwen 3.5 MTP / DeepSeek MTP / other inline-MTP families (they keep the bare `{method: "mtp"}` config). ### Validation Field-tested via vLLM PR #41745's correctness test on `gemma-4-E4B-it` + `gemma-4-E4B-it-assistant`: produced 304.7 output TPS at γ=4 vs 171.0 baseline (178% speedup) on H100. The same `speculative_config` shape this fix emits. ### Surfaced on [OMNIML-5024](https://jirasw.nvidia.com/browse/OMNIML-5024) pipeline #54356795: - Wrapper emitted `{model: assistant, num_speculative_tokens: 3}` - vLLM auto-detected `method = "draft_model"` - Loaded gemma-4-E4B-it-assistant (4 heads) as a generic draft for gemma-4-E4B-it (8 heads) - Attention-group num_heads check tripped → AssertionError, task_0 FAILED, task_1 CANCELLED ### Before your PR is "*Ready for review*" - Backward compatible: ✅ (Qwen 3.5 / DeepSeek MTP unchanged; only the MTP+`draft_model_dir` case changes). - New tests: ❌ — the test exercising this codepath would need a GPU + gemma-4 model checkout, which is cluster work, not unit-test scope. JIRA-tracked validation via OMNIML-5024 dispatch after this lands. - Changelog: ❌ ### Additional Information - vLLM PR #41745 (Gemma4 MTP support) - Companion: NVIDIA/Model-Optimizer PR #1675 (launcher `GlobalVariables.draft_model` schema fix) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fixed speculative decoding configuration handling in the benchmark example to ensure consistent method assignment and proper draft model configuration. * **Documentation** * Updated configuration comments to reflect corrected behavior and improved clarity. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent 46eddab commit 43b67a8

1 file changed

Lines changed: 43 additions & 19 deletions

File tree

  • examples/specdec_bench/specdec_bench/models

examples/specdec_bench/specdec_bench/models/vllm.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,27 +63,51 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
6363
specdec["disable_padded_drafter_batch"] = True
6464
specdec["parallel_draft_block_sizes"] = kwargs.get("parallel_draft_block_sizes")
6565
elif kwargs.get("speculative_algorithm") == "MTP":
66+
# vLLM's ``SpeculativeConfig.__post_init__`` (vllm/config/
67+
# speculative.py:529-602) does method auto-detection ONLY
68+
# when ``method`` is unset — when ``model`` is provided and
69+
# ``method`` is None, the default branch sets
70+
# ``method = "draft_model"`` (the generic same-architecture
71+
# draft path), NOT MTP. That path enforces equal num_heads
72+
# between target and draft and raises
73+
# ``AssertionError: All layers in one attention group must
74+
# share num_heads`` on heterogeneous-head models like
75+
# Gemma 4 (target=8 heads, assistant=4).
76+
#
77+
# The canonical config for ALL MTP variants is to ALWAYS
78+
# pass ``method="mtp"`` AND ADD ``model=<assistant>`` only
79+
# when the family uses a separate assistant model. vLLM's
80+
# own test at ``tests/v1/e2e/spec_decode/test_spec_decode.py``
81+
# (lines 818-823) does exactly this for the gemma4-e4b
82+
# parametrization:
83+
#
84+
# speculative_config = {
85+
# "method": "mtp",
86+
# "num_speculative_tokens": ...,
87+
# }
88+
# if draft_model is not None: # Gemma 4 case
89+
# speculative_config["model"] = draft_model
90+
#
91+
# Surfaced on OMNIML-5024 pipeline #54356795: dropping the
92+
# ``method`` key when ``draft_model_dir`` was provided sent
93+
# the call into the generic draft_model path, hitting the
94+
# num_heads assertion. Restored both keys.
95+
specdec = {
96+
"method": "mtp",
97+
"num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
98+
}
6699
draft_model_dir = kwargs.get("draft_model_dir")
67100
if draft_model_dir:
68-
# Assistant-model MTP (e.g. Gemma 4): vLLM's Gemma4 MTP
69-
# support (vllm-project/vllm#41745) expects
70-
# ``speculative_config={"model": <assistant>, ...}`` with
71-
# no ``method`` key — vLLM auto-detects Gemma4 from the
72-
# assistant model. Passing ``method: "mtp"`` here triggers
73-
# ``NotImplementedError: Unsupported speculative method:
74-
# 'mtp'`` on Gemma4 even on a container that has the
75-
# support (e.g. ``vllm/vllm-openai:v0.22.1``+).
76-
specdec = {
77-
"model": draft_model_dir,
78-
"num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
79-
}
80-
else:
81-
# Generic MTP path (Qwen3.5 etc.) — model carries its
82-
# own MTP layer; no separate draft / assistant model.
83-
specdec = {
84-
"method": "mtp",
85-
"num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
86-
}
101+
# Gemma 4 family (E2B / E4B / 26B-A4B / 31B) uses a
102+
# separate assistant checkpoint as the MTP draft.
103+
# vLLM auto-detects Gemma4 MTP from the assistant
104+
# ``model_type=gemma4_assistant`` and rewrites it to
105+
# ``gemma4_mtp`` (speculative.py:511-522). For
106+
# families where the MTP layer ships inside the
107+
# target (Qwen 3.5 etc.), omit ``--draft_model_dir``
108+
# and let vLLM use the target model as its own draft
109+
# (handled in speculative.py:562-573).
110+
specdec["model"] = draft_model_dir
87111
elif kwargs.get("speculative_algorithm") == "DFLASH":
88112
specdec = {
89113
"method": "dflash",

0 commit comments

Comments
 (0)