Commit 66b54ed

[OMNIML-5024] specdec_bench cell t0_d3 — google/gemma-4-E4B-it / MTP / vllm (#1663)
### What does this PR do?

Type of change: Bug fix + new example

Wires SPEED-bench's MTP path to support **Gemma 4** (and any future MTP variant that uses a separate assistant / draft model), and adds the SPEED-bench MTP/vLLM example for `google/gemma-4-E4B-it`.

**Key difference: Gemma 4 MTP vs. generic MTP.** vLLM's `speculative_config` accepts two different shapes for MTP:

| Variant | `speculative_config` shape | Models |
|---|---|---|
| **Generic MTP** | `{"method": "mtp", "num_speculative_tokens": N}` | Models that carry their own MTP layer in-tree (e.g. Qwen 3.5 MTP variants); no separate draft / assistant model. |
| **Assistant-model MTP** | `{"model": "<assistant>", "num_speculative_tokens": N}` (no `method` key; vLLM auto-detects from the assistant) | Gemma 4 family (E2B / E4B / 26B-A4B / 31B); each target model has a paired `<target>-assistant` checkpoint that acts as the MTP draft. Landed in vllm-project/vllm#41745 (2026-05-06). |

The specdec_bench vLLM wrapper at `examples/specdec_bench/specdec_bench/models/vllm.py` previously emitted only the generic shape for any `--speculative_algorithm MTP` invocation, which produced `NotImplementedError: Unsupported speculative method: 'mtp'` on Gemma 4 even with a container that has the support (`vllm/vllm-openai:v0.22.1`+). This PR teaches the wrapper to switch shapes based on whether `--draft_model_dir` is provided.

**Concrete changes:**

1. **`examples/specdec_bench/specdec_bench/models/vllm.py`**: when `speculative_algorithm == "MTP"` AND `draft_model_dir` is set, emit `{"model": draft_model_dir, "num_speculative_tokens": N}` (assistant-model shape). Otherwise emit the existing `{"method": "mtp", ...}` (generic shape). Backward-compatible: Qwen 3.5 MTP and other callers that omit `--draft_model_dir` get the same config they got before.
2. **`examples/specdec_bench/specdec_bench/utils.py`**: `get_tokenizer` reads `extra_special_tokens` from the model's `tokenizer_config.json` and passes them through to `AutoTokenizer.from_pretrained`. Gemma 4 tokenizers ship a list-shaped `extra_special_tokens` entry that the constructor would otherwise reject. Necessary for any Gemma 4 cell.
3. **`tools/launcher/examples/gemma-4/gemma-4-E4B-it/specdec_bench_mtp_vllm.yaml`**: SPEED-bench parent YAML for `google/gemma-4-E4B-it`. Uses `vllm/vllm-openai:v0.22.1` (has `gemma4_mtp.py` from #41745) and wires `--draft_model_dir /hf-local/google/gemma-4-E4B-it-assistant` on both task_0 (qualitative) and task_1 (throughput_32k).
4. **`tools/launcher/common/specdec_bench/_cells/gemma-4-E4B-it_mtp_vllm_t0_d3.yaml`**: runtime params for the `t0_d3` cell of OMNIML-5022 (`temperature=0`, `max_model_len=40960`).

### Usage

```python
# Wrapper-level: same CLI as before, just pass --draft_model_dir for
# Gemma 4 MTP. The wrapper auto-routes to the assistant-model shape.
#   python examples/specdec_bench/run.py \
#       --engine VLLM \
#       --speculative_algorithm MTP \
#       --draft_model_dir /hf-local/google/gemma-4-E4B-it-assistant \
#       --draft_length 3 \
#       --tp_size 1 \
#       ...other SPEED-bench knobs...

# Equivalent direct vLLM invocation (for reference, no wrapper):
from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-4-E4B-it",
    speculative_config={
        "model": "google/gemma-4-E4B-it-assistant",
        "num_speculative_tokens": 3,
    },
    trust_remote_code=True,
)
```

### Testing

- **Upstream existence checks**: verified the assistant models `google/gemma-4-{E2B,E4B,26B-A4B}-it-assistant` exist, public, ungated on HuggingFace; verified `vllm/model_executor/models/gemma4_mtp.py` is in vLLM `v0.22.0`, `v0.22.1`, and `main`.
- **Backward compat**: `MTP` callers that don't pass `--draft_model_dir` (e.g. the existing Qwen 3.5 MTP/vLLM cells under `tools/launcher/examples/Qwen/Qwen3.5-4B/`) take the unchanged `{"method": "mtp", ...}` branch. No diff for those.
- **End-to-end cluster validation**: pending. Will run via the OMNIML-5022 cells (OMNIML-5024 / 5025 / 5026 / 5027) once the nmm-sandbox submodule pin advances past this PR. Each cell exercises `task_0` (SPEED-Bench qualitative, 880 samples) + `task_1` (throughput_32k, 80 samples) on cw_dfw, single H100.

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅. The wrapper only takes the new branch when `--draft_model_dir` is provided alongside `--speculative_algorithm MTP`. Existing MTP callers (Qwen 3.5 etc.) keep the generic `method: "mtp"` config.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A, no new dependencies.
- Did you write any new necessary tests?: ❌. Relying on the SPEED-bench cluster cells (OMNIML-5024 …5027) for end-to-end validation; no unit test fixture for the vLLM wrapper exists in `tests/` for me to extend symmetrically. Happy to add one if reviewers want it.
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌. Small fix + example addition. Can add if requested.
- Did you get Claude approval on this PR?: ❌. Will run `/claude review` once the PR is marked Ready for review.

### Additional Information

- JIRA: [OMNIML-5024](https://jirasw.nvidia.com/browse/OMNIML-5024) (cell_t0_d3); siblings OMNIML-5025/5026/5027 (cell_{t0_d7, t1_d3, t1_d7}) of Epic OMNIML-5022 are blocked on this PR landing.
- Upstream reference: vllm-project/vllm#41745, "[Spec Decode] Add Gemma4 MTP speculative decoding support".
- Companion (pensieve-intern !91, internal): adds a "Model-family-specific MTP invocation" table to the specdec_bench cell SPEC so future agents pair `MTP` with the right `--draft_model_dir` from SPEC-read time.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added a SPEED-bench pipeline for Gemma 4 using vLLM speculative decoding (MTP) with qualitative and throughput tasks.
* **Improvements**
  * Speculative-decoding logic updated to handle assistant-model and generic MTP cases distinctly.
  * Tokenizer loading now reads and normalizes extra special tokens from tokenizer config when available.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Pensieve Intern <chenhany@nvidia.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent 48767a0 commit 66b54ed

3 files changed

Lines changed: 118 additions & 5 deletions

examples/specdec_bench/specdec_bench/models/vllm.py

Lines changed: 21 additions & 4 deletions
@@ -63,10 +63,27 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
             specdec["disable_padded_drafter_batch"] = True
             specdec["parallel_draft_block_sizes"] = kwargs.get("parallel_draft_block_sizes")
         elif kwargs.get("speculative_algorithm") == "MTP":
-            specdec = {
-                "method": "mtp",
-                "num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
-            }
+            draft_model_dir = kwargs.get("draft_model_dir")
+            if draft_model_dir:
+                # Assistant-model MTP (e.g. Gemma 4): vLLM's Gemma4 MTP
+                # support (vllm-project/vllm#41745) expects
+                # ``speculative_config={"model": <assistant>, ...}`` with
+                # no ``method`` key — vLLM auto-detects Gemma4 from the
+                # assistant model. Passing ``method: "mtp"`` here triggers
+                # ``NotImplementedError: Unsupported speculative method:
+                # 'mtp'`` on Gemma4 even on a container that has the
+                # support (e.g. ``vllm/vllm-openai:v0.22.1``+).
+                specdec = {
+                    "model": draft_model_dir,
+                    "num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
+                }
+            else:
+                # Generic MTP path (Qwen3.5 etc.) — model carries its
+                # own MTP layer; no separate draft / assistant model.
+                specdec = {
+                    "method": "mtp",
+                    "num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
+                }
         elif kwargs.get("speculative_algorithm") == "DFLASH":
             specdec = {
                 "method": "dflash",

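The MTP branch in the `vllm.py` hunk can be read in isolation as a small pure function. The following is a minimal sketch, not the wrapper's actual code: the helper name `build_mtp_speculative_config` and its parameter names are hypothetical (the wrapper builds the dict inline inside `__init__` from `kwargs`), and the sketch calls no vLLM API. It only reproduces the two dict shapes shown in the diff.

```python
from typing import Any, Optional


def build_mtp_speculative_config(
    draft_model_dir: Optional[str] = None,
    num_speculative_tokens: int = 3,
) -> dict[str, Any]:
    """Hypothetical helper mirroring the MTP branch of the wrapper diff."""
    if draft_model_dir:
        # Assistant-model MTP (e.g. Gemma 4): pass the draft checkpoint
        # and deliberately omit the "method" key.
        return {
            "model": draft_model_dir,
            "num_speculative_tokens": num_speculative_tokens,
        }
    # Generic MTP: the target model carries its own MTP layer.
    return {"method": "mtp", "num_speculative_tokens": num_speculative_tokens}


gemma_cfg = build_mtp_speculative_config("/hf-local/google/gemma-4-E4B-it-assistant")
generic_cfg = build_mtp_speculative_config()
```

Because the check is a plain truthiness test, both `None` and an empty string fall through to the generic shape, which is what keeps callers that omit `--draft_model_dir` on the previous config.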
examples/specdec_bench/specdec_bench/utils.py

Lines changed: 14 additions & 1 deletion
@@ -35,7 +35,20 @@
 
 
 def get_tokenizer(path, trust_remote_code=False):
-    return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code)
+    extra_special_tokens = None
+    tokenizer_config_path = os.path.join(path, "tokenizer_config.json")
+    if os.path.exists(tokenizer_config_path):
+        with open(tokenizer_config_path) as f:
+            tokenizer_config = json.load(f)
+        extra_special_tokens = tokenizer_config.get("extra_special_tokens")
+
+    kwargs = {"trust_remote_code": trust_remote_code}
+    if isinstance(extra_special_tokens, list):
+        kwargs["extra_special_tokens"] = {
+            token.strip("<|>").replace("|", "_") + "_token": token for token in extra_special_tokens
+        }
+
+    return AutoTokenizer.from_pretrained(path, **kwargs)
 
 
 def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
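To make the list-to-dict normalization in `get_tokenizer` concrete, here is a small sketch of the same comprehension applied to a few sample strings. The helper name `normalize_extra_special_tokens` is hypothetical, and the token strings are illustrative only; they are not claimed to be the entries in any real Gemma 4 `tokenizer_config.json`.

```python
def normalize_extra_special_tokens(tokens: list[str]) -> dict[str, str]:
    """Hypothetical helper using the same comprehension as the diff.

    Strips leading/trailing '<', '|', '>' characters, replaces any inner
    '|' with '_', and appends '_token' to form the dict key.
    """
    return {t.strip("<|>").replace("|", "_") + "_token": t for t in tokens}


# Illustrative inputs, not taken from a real tokenizer config.
sample = ["<|image|>", "<start_of_turn>", "<|tool|call|>"]
normalized = normalize_extra_special_tokens(sample)
# normalized == {
#     "image_token": "<|image|>",
#     "start_of_turn_token": "<start_of_turn>",
#     "tool_call_token": "<|tool|call|>",
# }
```

One property worth noting: two tokens that differ only in their delimiter characters (for example `<image>` and `<|image|>`) map to the same key, so the later entry silently overwrites the earlier one.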
tools/launcher/examples/gemma-4/gemma-4-E4B-it/specdec_bench_mtp_vllm.yaml

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# SPEED-bench MTP speculative-decoding run for gemma-4-E4B-it via vLLM.
+#
+# Gemma 4 MTP support landed in vLLM PR vllm-project/vllm#41745 (2026-05-06)
+# and is in ``vllm/vllm-openai:v0.22.1`` (and later). Gemma 4 MTP uses a
+# separate assistant model passed via ``--draft_model_dir``; vLLM
+# auto-detects Gemma 4 from the assistant and does NOT take a ``method``
+# key in ``speculative_config``. The wrapper at
+# ``examples/specdec_bench/specdec_bench/models/vllm.py`` routes to the
+# assistant-model config shape when ``--speculative_algorithm MTP`` is
+# paired with ``--draft_model_dir``.
+#
+# Assistant model: ``google/gemma-4-E4B-it-assistant`` (public, ungated).
+#
+# Slurm run on cw_dfw — cells override per-cell knobs via
+# pipeline.task_N.args+=[...]:
+#
+#   uv run slurm.py \
+#     --yaml modules/Model-Optimizer/tools/launcher/examples/gemma-4/gemma-4-E4B-it/specdec_bench_mtp_vllm.yaml \
+#     --yes detach=true \
+#     pipeline.task_0.args+=["--temperature 0","--max_seq_len 65536","--save_dir /scratchspace/<sweep>/qualitative","--draft_length 3"] \
+#     pipeline.task_1.args+=["--temperature 0","--max_seq_len 65536","--save_dir /scratchspace/<sweep>/throughput_32k","--num_requests 80","--draft_length 3"]
+
+job_name: gemma-4-E4B-it_specdec_bench_mtp_vllm
+
+pipeline:
+  global_vars:
+    hf_model: /hf-local/google/gemma-4-E4B-it
+    draft_model: /hf-local/google/gemma-4-E4B-it-assistant
+
+  # task_0: SPEED qualitative split
+  task_0:
+    script: common/specdec_bench/run.sh
+    args:
+      - --dataset speed
+      - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/qualitative
+      - --engine VLLM
+      - --speculative_algorithm MTP
+      - --draft_model_dir <<global_vars.draft_model>>
+      - --draft_length 3
+      - --tp_size 1
+      - --ep_size 1
+      - --concurrency 32
+      - --output_length 4096
+      - --aa_timing
+      - --show_progress
+      - --save_dir /scratchspace/{sweep_name_default}/qualitative
+    environment:
+      - HF_MODEL_CKPT: <<global_vars.hf_model>>
+      - HF_LOCAL: /hf-local
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 1
+      container: vllm/vllm-openai:v0.22.1
+
+  # task_1: SPEED throughput_32k split
+  task_1:
+    script: common/specdec_bench/run.sh
+    args:
+      - --dataset speed
+      - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k
+      - --engine VLLM
+      - --speculative_algorithm MTP
+      - --draft_model_dir <<global_vars.draft_model>>
+      - --draft_length 3
+      - --tp_size 1
+      - --ep_size 1
+      - --concurrency 8
+      - --num_requests 80
+      - --output_length 4096
+      - --aa_timing
+      - --show_progress
+      - --save_dir /scratchspace/{sweep_name_default}/throughput_32k
+    environment:
+      - HF_MODEL_CKPT: <<global_vars.hf_model>>
+      - HF_LOCAL: /hf-local
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 1
+      container: vllm/vllm-openai:v0.22.1
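The YAML refers to shared values through `<<global_vars.NAME>>` placeholders. The launcher's real substitution logic is not part of this commit, so the following is only an illustrative sketch of what such a placeholder presumably expands to. The function `resolve_global_vars` and the regular expression are assumptions for illustration, not the launcher's API.

```python
import re

# Assumed placeholder syntax, inferred from the YAML: <<global_vars.NAME>>
_PLACEHOLDER = re.compile(r"<<global_vars\.(\w+)>>")


def resolve_global_vars(arg: str, global_vars: dict[str, str]) -> str:
    """Hypothetical resolver: replace each placeholder with its value.

    Raises KeyError if a placeholder names a variable that is not defined.
    """
    return _PLACEHOLDER.sub(lambda m: global_vars[m.group(1)], arg)


global_vars = {
    "hf_model": "/hf-local/google/gemma-4-E4B-it",
    "draft_model": "/hf-local/google/gemma-4-E4B-it-assistant",
}
resolved = resolve_global_vars("--draft_model_dir <<global_vars.draft_model>>", global_vars)
```

Under this reading, both tasks end up passing the same assistant checkpoint path to `--draft_model_dir`, which is the condition that routes the wrapper to the assistant-model config shape.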
