
Commit a7b0a92

yeyu-nvidia and claude authored
EAGLE3 new model support: pipeline configs, triage docs, and Ministral-3 fixes (#1417)
## Summary

EAGLE3 automation triage work (OKR-30): testing the 4-step EAGLE3 offline pipeline against 12 new model architectures, documenting failure modes, and fixing issues found.

### Code fixes (modelopt)

| File | Change |
|------|--------|
| `modelopt/torch/speculative/utils.py` | Extend VLM detection in `load_vlm_or_llm` to check `text_config`/`llm_config` attrs (catches `mistral3` models) |
| `modelopt/torch/speculative/plugins/modeling_fakebase.py` | Add `consolidated.safetensors` fallback for checkpoints with incomplete HF shards |
| `modelopt/torch/export/plugins/hf_spec_configs.py` | Set `use_cache=True` in EAGLE export templates (fixes strict `huggingface_hub` validation) |

### Pipeline infrastructure

- `examples/speculative_decoding/pipeline/eagle3/`: pipeline scripts and configs:
  - `offline_training.sh`: training + export with runtime patches for older container modelopt
  - `dump_offline_data_vllm.sh`: vLLM-based hidden state extraction (with speculators compat patches)
  - `dump_offline_data.sh`, `dump_offline_data_hf.sh`: alternative dump paths
  - 18 quick-fail-check YAMLs for 12 models
  - 4 standalone task1 YAMLs

### Documentation

- `eagle3_triage_chart.md`: model test matrix, triage decision tree, per-model results, failure catalog
- `eagle3_new_model_triage_guide.md`: step-by-step guide for triaging new models

### Model test results (as of 2026-05-27)

| Model | task_0 | task_1 | task_2 | task_3 | Blocker |
|-------|--------|--------|--------|--------|---------|
| Qwen3-8B | - | - | - | - | Reference (existing) |
| Kimi-K2.5 | - | - | - | - | Existing (GB200) |
| **Ministral-3-8B** | SKIP | PASS | PASS | FAIL | `use_cache=null` in export (fixed) |
| Ministral-3-14B | FAIL | - | - | FAIL | vLLM engine init fails |
| Qwen3.5-35B-A3B | TIMEOUT | - | - | - | Data synth too slow |
| gpt-oss-20b | FAIL | - | - | - | Tokenizer `HarmonyError` |
| Step-3.5-Flash | TIMEOUT | - | - | - | Data synth time limit |
| MiniMax-M2.5 | TIMEOUT | - | - | - | `trust_remote_code` needed |
| DeepSeek-V3.2 | no log | - | - | - | May not be mirrored |
| Qwen3.5-9B | - | - | - | - | Not yet run |
| Qwen3.5-27B | - | - | - | - | Not yet run |
| GLM-5 | - | - | - | - | Not yet run |

### Issues found and fixed

| # | Issue | Fix |
|---|-------|-----|
| 1 | `mistral3` model type not detected as VLM | Check `text_config`/`llm_config` attrs in `load_vlm_or_llm` |
| 2 | Missing HF shard file (Ministral-3-8B) | Fallback to `consolidated.safetensors` with Mistral native key aliases |
| 3 | `use_cache=null` in exported EAGLE config | Set `use_cache=True` in export template configs |
| 4 | speculators incompatible with vLLM container | Runtime patches in `dump_offline_data_vllm.sh` |
| 5 | `offline_training.sh` infra issues | Rewritten with runtime patches for container modelopt |

## Test plan

- [x] Ministral-3-8B training passes (`cicd_1779829129`)
- [x] Ministral-3-8B export succeeds
- [ ] Ministral-3-8B benchmark passes (`cicd_1779901409`, pending with all fixes)
- [ ] Dry-run remaining model configs

## Note

GitHub secret scanning alert #6 is a **false positive**: `Mistral3ForConditionalGeneration` (a HuggingFace model class name in a YAML comment) was flagged as a "Mistral AI API Key".

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 88fd7ff commit a7b0a92

35 files changed

Lines changed: 3756 additions & 17 deletions
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an LLM using vLLM's native hidden-state extractor.

This uses vLLM's built-in ``extract_hidden_states`` speculative method together with
the ``ExampleHiddenStatesConnector`` KV connector, so no third-party data-generation
dependency (e.g. ``speculators``) is required. Because the same ``eagle_aux_hidden_state_layer_ids``
convention is used at EAGLE3 deployment time in vLLM, the captured aux layers match
deployment by construction.

See https://docs.vllm.ai/en/stable/features/speculative_decoding/extract_hidden_states/
"""

import argparse
from pathlib import Path

import torch
from common import add_aux_layers_args, resolve_aux_layers
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations using vLLM's native extractor."""
    )

    parser.add_argument("--model", type=str, required=True, help="HF model path.")
    parser.add_argument(
        "--max-seq-len", type=int, default=3072, help="Max tokens per conversation."
    )
    parser.add_argument(
        "--input-data", type=Path, required=True, help="Path to jsonl file or directory."
    )
    parser.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to save hidden states."
    )
    parser.add_argument("--dp-rank", type=int, default=0, help="Data parallel rank.")
    parser.add_argument("--dp-world-size", type=int, default=1, help="Data parallel world size.")
    parser.add_argument(
        "--trust_remote_code", action="store_true", help="Trust remote code for HF models."
    )
    parser.add_argument("--tp", type=int, default=None, help="Tensor parallel size.")
    parser.add_argument(
        "--debug-max-num-conversations", type=int, default=None, help="Limit conversations."
    )
    add_aux_layers_args(parser)

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Import lazily so --help and arg parsing work without vLLM installed.
    from vllm import LLM, SamplingParams
    from vllm.config.kv_transfer import KVTransferConfig
    from vllm.distributed.kv_transfer.kv_connector.v1 import example_hidden_states_connector
    from vllm.inputs import TokensPrompt

    # Load conversations
    if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
        dataset = load_dataset("json", data_files=str(args.input_data), split="train")
    elif args.input_data.is_dir():
        dataset = load_dataset(
            "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
        )
    else:
        raise ValueError(f"input_data must be a .jsonl file or directory, got: {args.input_data}")
    print(f"Loaded {len(dataset)} conversations from {args.input_data}")

    # Shard data
    if args.dp_world_size > 1:
        dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
    print(f"Sharded to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}")

    # Remove already dumped conversations
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    def keep_conversation(entry):
        conversation_id = entry.get("conversation_id", entry.get("uuid", None))
        assert conversation_id is not None, "conversation_id is required"
        return not (output_dir / f"{conversation_id}.pt").exists()

    original_num = len(dataset)
    dataset = dataset.filter(keep_conversation)
    print(f"Removed {original_num - len(dataset)} conversations due to existing output files")

    if args.debug_max_num_conversations is not None:
        dataset = dataset.select(range(args.debug_max_num_conversations))

    # Resolve the aux-layer indices and append the final-layer output. vLLM saves the
    # final (un-normed) hidden state when ``num_hidden_layers`` is passed as a layer id.
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    if num_hidden_layers is None:
        raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}")
    aux_layer_ids = resolve_aux_layers(args, num_hidden_layers)
    # The trailing entry is the final output hidden state; the rest are aux layers.
    extract_layer_ids = [*aux_layer_ids, num_hidden_layers]
    print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)")

    # Tokenize conversations
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is not None:
        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Prepare prompts for vLLM
    prompts = []
    conversation_ids = []
    num_skipped_too_long = 0
    num_invalid = 0

    for entry in dataset:
        conversation_id = entry.get("conversation_id", entry.get("uuid"))
        conversations = entry["conversations"]
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        tokenized = tokenizer.apply_chat_template(
            conversations, return_tensors="pt", add_generation_prompt=False
        )
        # transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids
        if hasattr(tokenized, "input_ids"):
            input_ids = tokenized.input_ids
        elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized:
            input_ids = tokenized["input_ids"]
        else:
            input_ids = tokenized
        if not hasattr(input_ids, "shape"):
            input_ids = torch.tensor(input_ids)
        input_ids = input_ids.squeeze(0)
        num_tokens = input_ids.shape[0]
        if num_tokens <= 10 or num_tokens > args.max_seq_len:
            num_skipped_too_long += 1
            continue

        prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist()))
        conversation_ids.append(conversation_id)

    print(
        f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)"
    )

    if len(prompts) == 0:
        print("No prompts to process.")
        return

    # Initialize vLLM with the native hidden-state extractor.
    tp = args.tp if args.tp is not None else torch.cuda.device_count()
    storage_path = output_dir / ".vllm_hidden_states"
    storage_path.mkdir(parents=True, exist_ok=True)

    llm = LLM(
        model=args.model,
        tensor_parallel_size=tp,
        max_model_len=args.max_seq_len,
        trust_remote_code=args.trust_remote_code,
        enable_chunked_prefill=False,  # required by extract_hidden_states
        speculative_config={
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {"eagle_aux_hidden_state_layer_ids": extract_layer_ids},
            },
        },
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleHiddenStatesConnector",
            kv_role="kv_producer",
            kv_connector_extra_config={
                "shared_storage_path": str(storage_path),
                "use_synchronization_lock": False,  # batch generation, no concurrent readers
            },
        ),
    )

    # max_tokens=1: we only need a single forward pass over the prompt tokens.
    outputs = llm.generate(prompts, SamplingParams(max_tokens=1))

    # Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the
    # vLLM path does not compute).
    num_success = 0
    for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"):
        hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
        if hidden_states_path is None:
            print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping")
            continue

        obj = example_hidden_states_connector.load_hidden_states(hidden_states_path)
        token_ids = obj["token_ids"]
        # hidden_states: [num_tokens, num_extracted_layers, hidden_size], ordered to match
        # extract_layer_ids. Last layer = final output; the rest = aux layers.
        hidden_states = obj["hidden_states"]

        output_hidden_states = hidden_states[:, -1, :].cpu()
        if hidden_states.shape[1] > 1:
            # Concatenate aux layers along the hidden dim, matching the HF dump format.
            aux = hidden_states[:, :-1, :].cpu()
            aux_hidden_states = aux.reshape(aux.shape[0], -1)
        else:
            aux_hidden_states = torch.empty(0)

        output_file = output_dir / f"{conv_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(
                {
                    "input_ids": token_ids.cpu(),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "conversation_id": conv_id,
                },
                f,
            )
        example_hidden_states_connector.cleanup_hidden_states(hidden_states_path)
        num_success += 1

    print(f"Successfully processed {num_success} out of {len(prompts)} conversations.")


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
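The layer bookkeeping in the script above can be hard to follow from the tensor slicing alone. The sketch below restates it with plain Python lists instead of tensors. It is illustrative only: `build_extract_layer_ids` and `split_extracted_layers` are hypothetical helpers that do not exist in the script or in vLLM, and the toy sizes are made up. It mirrors the `[num_tokens, num_extracted_layers, hidden_size]` layout described in the script's comments, where the last extracted layer is the final output and the preceding ones are aux layers concatenated along the hidden dimension.

```python
def build_extract_layer_ids(aux_layer_ids, num_hidden_layers):
    """Append the final-layer id after the aux ids, as the script does."""
    return [*aux_layer_ids, num_hidden_layers]


def split_extracted_layers(hidden_states):
    """Split per-token layer stacks into (final_output, concatenated_aux).

    hidden_states is a nested list shaped [num_tokens][num_layers][hidden_size].
    """
    # Last layer of every token is the final output hidden state.
    final_output = [token_layers[-1] for token_layers in hidden_states]
    # Remaining layers are flattened in layer order, like reshape(num_tokens, -1).
    aux = [
        [value for layer in token_layers[:-1] for value in layer]
        for token_layers in hidden_states
    ]
    return final_output, aux


# Toy example: 2 tokens, 2 aux layers + 1 final layer, hidden_size = 2.
layer_ids = build_extract_layer_ids([1, 3], num_hidden_layers=4)
toy = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
final_output, aux = split_extracted_layers(toy)
print(layer_ids)     # [1, 3, 4]
print(final_output)  # [[5, 6], [11, 12]]
print(aux)           # [[1, 2, 3, 4], [7, 8, 9, 10]]
```

With two aux layers of width 2, each token's aux vector has width 4, which corresponds to the `aux.reshape(aux.shape[0], -1)` step in the script.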
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# EAGLE3 New Model Support: Triage Guide for Claude Code

This document describes how to triage EAGLE3 pipeline failures when adding a new model.
Follow these steps in order. Stop at the first failure, diagnose, and document findings.

## Pipeline Overview

The EAGLE3 pipeline has 4 stages (mapped to task_0 through task_3 in the YAML):

| Task | Stage | Container | Script | What it does |
|------|-------|-----------|--------|-------------|
| task_0 | Data synthesis | vllm/vllm-openai | `common/vllm/query.sh` | Serve model with vLLM, generate synthetic conversations |
| task_1 | Hidden state dump | vllm/vllm-openai | `common/eagle3/dump_offline_data*.sh` | Dump hidden states from generated conversations |
| task_2 | Training + Export | tensorrt-llm/release | `common/eagle3/train_eagle.sh` | Train EAGLE3 draft model, export HF checkpoint |
| task_3 | Benchmark | vllm/vllm-openai | `common/specdec_bench/quick_check.sh` | Run speculative decoding benchmark |

Some configs combine task_0+task_1 into a single vLLM dump step, or skip task_0 if data already exists.

## Step 1: Locate the pipeline config

```text
tools/launcher/examples/<Org>/<Model>/eagle3_quick_check.yaml
```

If it doesn't exist, create one by copying an existing `eagle3_quick_check.yaml` and adjusting:
- `HF_MODEL_CKPT`: the HF model path on `/hf-local/`
- GPU/node counts based on model size
- `--trust_remote_code` / `--trust-remote-code` if needed
- Container images

## Step 2: Submit the pipeline

```bash
cd tools/launcher
uv run launch.py --yaml examples/<Org>/<Model>/eagle3_quick_check.yaml --yes -v
```

The rsync can take several minutes. Experiment ID is printed as `cicd_<timestamp>`.

## Step 3: Check experiment output

Experiment directory:

```text
experiments/cicd/cicd_<id>/
```

Each task has a directory `<JobName>_<N>/` containing:
- `sbatch_<JobName>_<N>_<SlurmJobID>.out`: the main log
- `code/`: snapshot of the code at submission time

Check logs:

```bash
tail -100 experiments/cicd/cicd_<id>/<JobName>_<N>/sbatch_*.out
```

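When an experiment has several task directories, it can help to collect the tail of every log in one pass. The following is a minimal sketch, assuming the `<JobName>_<N>/sbatch_*.out` layout described above; `tail_task_logs` is a hypothetical helper, not part of the launcher, and the demo builds a throwaway directory rather than reading a real experiment.

```python
import tempfile
from collections import deque
from pathlib import Path


def tail_task_logs(experiment_dir, num_lines=100):
    """Return {log_path: last num_lines lines} for every task log found."""
    tails = {}
    # Assumed layout from the guide: <JobName>_<N>/sbatch_<...>.out
    for log_path in sorted(Path(experiment_dir).glob("*/sbatch_*.out")):
        with open(log_path, errors="replace") as f:
            # deque with maxlen keeps only the trailing lines while streaming.
            tails[str(log_path)] = [line.rstrip("\n") for line in deque(f, maxlen=num_lines)]
    return tails


# Demo on a temporary directory that imitates one task directory.
with tempfile.TemporaryDirectory() as tmp:
    task_dir = Path(tmp) / "demo_job_0"
    task_dir.mkdir()
    log_file = task_dir / "sbatch_demo_job_0_123.out"
    log_file.write_text("\n".join(f"line {i}" for i in range(250)))
    demo_tails = tail_task_logs(tmp, num_lines=100)
    demo_lines = next(iter(demo_tails.values()))

print(len(demo_lines), demo_lines[0], demo_lines[-1])  # 100 line 150 line 249
```

Pointing `tail_task_logs` at `experiments/cicd/cicd_<id>/` would be the equivalent of running the `tail -100` command once per task.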
## Step 4: Diagnose failures by stage

### task_0/task_1 failures (vLLM data generation / hidden state dump)

Common issues:
- **Server never starts** → Check for OOM, unsupported architecture, or missing `--trust_remote_code`
- **`HarmonyError: vocab file`** → gated model, tokenizer not available offline
- **`TypeError: 'NoneType' object is not iterable`** → vLLM doesn't support this model architecture yet
- **`CANCELLED DUE TO TIME LIMIT`** → Model too slow for the time limit; increase wall time or reduce data
- **Server starts but queries fail** → Check prompt format, connection errors

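The log signatures in the list above lend themselves to a simple first-pass scan. The sketch below is one possible way to do that, not an existing tool: `FAILURE_SIGNATURES` and `classify_log` are hypothetical names, the substrings are taken from the list above, and real log lines may be worded differently (for example, Slurm usually places a timestamp between `CANCELLED` and `DUE TO TIME LIMIT`, so only the latter part is matched here).

```python
# (substring to look for, hint copied from the triage list above)
FAILURE_SIGNATURES = [
    ("HarmonyError", "gated model, tokenizer not available offline"),
    ("'NoneType' object is not iterable", "vLLM may not support this model architecture yet"),
    ("DUE TO TIME LIMIT", "model too slow for the time limit; increase wall time or reduce data"),
]


def classify_log(log_text):
    """Return the hints whose signature appears in the log text, in catalog order."""
    return [hint for signature, hint in FAILURE_SIGNATURES if signature in log_text]


sample_log = (
    "TypeError: 'NoneType' object is not iterable\n"
    "slurmstepd: error: *** JOB 1 CANCELLED AT 2026-01-01T00:00:00 DUE TO TIME LIMIT ***\n"
)
hints = classify_log(sample_log)
for hint in hints:
    print(hint)
```

A scan like this only narrows the search. An empty result means none of the known signatures matched, not that the task succeeded.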
### task_2 failures (training + export)

Common issues:
- **`No such file or directory: service_utils.sh`** → pipeline infra issue (older experiment)
- **`ValueError: Unrecognized configuration class ... for AutoModelForCausalLM`** → VLM model not detected as VLM. Check if `load_vlm_or_llm` in `modelopt/torch/speculative/utils.py` handles this model type. Look for `text_config`/`llm_config` attributes.
- **`FileNotFoundError` on shard files** → Checkpoint has unusual format (e.g., missing HF shards, has consolidated.safetensors instead). Check `FakeBaseModel._load_weights`.
- **OOM during training** → Reduce `--train_bs` or `--training_seq_len`
- **NaN loss** → Reduce `--lr`, check data quality

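To make the VLM-detection item concrete, here is a minimal sketch of the kind of check it describes. It is not the actual `load_vlm_or_llm` code: `looks_like_vlm` is a hypothetical function, the `"vl"` substring test is inferred from the Known Model-Specific Issues table in this guide, and the configs are stand-in `SimpleNamespace` objects rather than real HF configs.

```python
from types import SimpleNamespace


def looks_like_vlm(config):
    """Heuristic: VLM if the model type contains 'vl' or a nested text config exists."""
    model_type = str(getattr(config, "model_type", "")).lower()
    if "vl" in model_type:
        return True
    # Wrapper configs such as mistral3 carry the language model under a nested attribute.
    return any(getattr(config, attr, None) is not None for attr in ("text_config", "llm_config"))


# Stand-in configs (assumed shapes, for illustration only).
mistral3_like = SimpleNamespace(
    model_type="mistral3", text_config=SimpleNamespace(model_type="ministral3")
)
plain_llm = SimpleNamespace(model_type="qwen3")

print(looks_like_vlm(mistral3_like))  # True
print(looks_like_vlm(plain_llm))      # False
```

The name-based test alone misses `mistral3`, which is why the nested-attribute check matters for the Ministral-3 family.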
### task_3 failures (benchmark)

Common issues:
- **`/scratchspace/export` doesn't exist** → task_2 failed; fix training first
- **`StrictDataclassFieldValidationError`** → exported `config.json` has `null` where a typed field is expected (e.g., `use_cache`). Fix the export template in `modelopt/torch/export/plugins/hf_spec_configs.py`.
- **`KeyError: '<model_type>'`** → transformers version in container doesn't recognize the model type
- **`trust_remote_code=True` required** → add to benchmark config
- **vLLM resolves model as wrong architecture** → VLM wrapper model needs special handling

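For the `StrictDataclassFieldValidationError` case, a quick scan of the exported `config.json` for `null` values can point at the offending field before rerunning the benchmark. The sketch below is an assumption-laden illustration: `find_null_fields` is a hypothetical helper, and the inline JSON (including the `eagle_config` and `rope_scaling` keys) is made up for the demo rather than copied from a real export.

```python
import json


def find_null_fields(config, prefix=""):
    """Return dotted paths of all keys whose value is None (JSON null)."""
    nulls = []
    for key, value in config.items():
        path = f"{prefix}{key}"
        if value is None:
            nulls.append(path)
        elif isinstance(value, dict):
            # Recurse into nested sections of the config.
            nulls.extend(find_null_fields(value, prefix=f"{path}."))
    return nulls


# Made-up config content; a real check would json.load() the exported config.json.
exported = json.loads(
    '{"use_cache": null, "eagle_config": {"use_cache": true, "rope_scaling": null}}'
)
null_paths = find_null_fields(exported)
print(null_paths)  # ['use_cache', 'eagle_config.rope_scaling']
```

Not every `null` is a bug, since some fields are legitimately optional. The scan only lists candidates to compare against the field named in the validation error.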
## Step 5: Applying fixes

### Repo fixes (for merged modelopt)

Edit files in `/home/yeyu/Documents/TensorRT-Model-Optimizer/modelopt/torch/speculative/`.
The key files:
- `utils.py`: `load_vlm_or_llm()` for model loading
- `plugins/modeling_fakebase.py`: `FakeBaseModel` for offline training weight loading
- `plugins/hf_eagle.py`: EAGLE model definition
- `../export/plugins/hf_spec_configs.py`: export config templates
- `../export/plugins/hf_spec_export.py`: export logic

### Container patches (for pipeline)

A container may ship a pre-installed modelopt that can't be easily upgraded (CUDA build issues).
If a fix is needed against such an installed library, apply a runtime patch in the relevant
task script (e.g. the training script `common/eagle3/train_eagle.sh`) using a Python heredoc
that find-and-replaces the exact code pattern in the installed file.

> Note: the vLLM dump path previously relied on source-patching the `speculators` library.
> That dependency was removed in favor of vLLM's native `extract_hidden_states` extractor, so
> no speculators patches are applied anymore.

When adding a new patch:
1. Find the exact `old` string in the installed file (must be unique)
2. Write the `new` replacement string
3. Add a `python3 << 'PYEOF' || true` block in the task script before `set -eo pipefail`

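The body of such a heredoc could look roughly like the sketch below. It is a minimal illustration of the find-and-replace idea with the uniqueness requirement from step 1, not the patch code used in the pipeline: `apply_unique_patch` is a hypothetical helper, and the demo patches a temporary file instead of an installed library.

```python
import tempfile
from pathlib import Path


def apply_unique_patch(path, old, new):
    """Replace exactly one occurrence of old with new; return True if patched."""
    path = Path(path)
    text = path.read_text()
    count = text.count(old)
    if count != 1:
        # Zero matches: already patched or a different library version.
        # Several matches: the pattern is ambiguous, so leave the file alone.
        print(f"skip {path.name}: expected 1 match for the patch target, found {count}")
        return False
    path.write_text(text.replace(old, new))
    return True


# Demo on a temporary file standing in for an installed module.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "installed_module.py"
    target.write_text("config = {'use_cache': None}\n")
    first = apply_unique_patch(target, "'use_cache': None", "'use_cache': True")
    second = apply_unique_patch(target, "'use_cache': None", "'use_cache': True")
    patched_text = target.read_text()

print(first, second, patched_text.strip())
```

Refusing to patch unless there is exactly one match keeps the step idempotent: running the task script a second time leaves an already-patched file unchanged.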
## Step 6: Document results

Update `examples/speculative_decoding/pipeline/eagle3/eagle3_triage_chart.md`:
1. Update the model row in the **Model Test Matrix** (status + per-task results)
2. Add a **Per-Model Test Results** entry with experiment IDs, errors, and fixes
3. Add new failure patterns to the **Observed Failure Catalog**

## Known Model-Specific Issues

| Model Type | Issue | Where | Fix |
|-----------|-------|-------|-----|
| `mistral3` (Ministral-3-*) | Not detected as VLM by `"vl"` check | `utils.py` | Check `text_config`/`llm_config` attrs |
| `mistral3` (Ministral-3-8B) | Missing HF shard 1, has `consolidated.safetensors` | `modeling_fakebase.py` | Fallback to consolidated with key aliases |
| All models via FakeBaseModel | `use_cache=null` in exported config | `hf_spec_configs.py` | Set `use_cache: True` in templates |
| `gpt-oss-20b` | Tokenizer requires `openai_harmony` | task_0 | Gated/special tokenizer setup |
| `MiniMax-M2.5` | Custom model code | task_3 | `--trust_remote_code` |
| `ministral3` | `KeyError: 'ministral3'` in older transformers | task_3 | Needs transformers >= 5.3.0 |

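
The missing-shard row in the table can be illustrated with a small file-resolution sketch. This is an assumption about how such a fallback might be structured, not the `FakeBaseModel` implementation: `resolve_weight_files` is a hypothetical helper, it relies only on the standard `model.safetensors.index.json` `weight_map` layout, and it does not cover the Mistral native key aliasing mentioned in the fix.

```python
import json
import tempfile
from pathlib import Path


def resolve_weight_files(ckpt_dir):
    """Prefer the full HF shard set; fall back to consolidated.safetensors."""
    ckpt_dir = Path(ckpt_dir)
    index_path = ckpt_dir / "model.safetensors.index.json"
    if index_path.exists():
        weight_map = json.loads(index_path.read_text())["weight_map"]
        shards = sorted({ckpt_dir / name for name in weight_map.values()})
        if all(shard.exists() for shard in shards):
            return shards
    consolidated = ckpt_dir / "consolidated.safetensors"
    if consolidated.exists():
        return [consolidated]
    raise FileNotFoundError(f"no complete shard set or consolidated.safetensors in {ckpt_dir}")


# Demo: the index names two shards, but shard 1 is missing on disk.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    index = {
        "weight_map": {
            "a.weight": "model-00001-of-00002.safetensors",
            "b.weight": "model-00002-of-00002.safetensors",
        }
    }
    (root / "model.safetensors.index.json").write_text(json.dumps(index))
    (root / "model-00002-of-00002.safetensors").touch()
    (root / "consolidated.safetensors").touch()
    chosen = [p.name for p in resolve_weight_files(root)]

print(chosen)  # ['consolidated.safetensors']
```

Checking that every shard named in the index exists before loading avoids a late `FileNotFoundError` partway through weight loading.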
