Skip to content

Commit 46eddab

Browse files
authored
[Feat]: Specdec Streaming: RDMA + Multinode (#1611)
### What does this PR do? Type of change: New feature Multi-node **streaming** training for speculative decoding (EAGLE3 / DFlash): a live `vllm serve` captures the target model's hidden states and moves them straight to the trainer over **NIXL RDMA** — no disk round-trip. The streaming dataset is map-style — each rank fetches only its own `DistributedSampler` shard (concurrency from `dataloader_num_workers`), round-robins across multiple serve replicas (`server_urls`), and scales to multi-node DDP. Serve-side tensor parallelism (TP>1) is supported: hidden states are replicated across TP ranks, so rank 0 alone owns the pool + transfer. ### How - `RdmaHiddenStatesConnector` — out-of-tree vLLM connector (no vLLM source edits): one pre-registered pinned NIXL pool per serve, a ring slot per request, and a small HTTP sidecar serving transfer metadata. The trainer RDMA-READs the slot into a per-worker buffer. RDMA is the **only** transport (the earlier disk/safetensors path is removed). - Map-style dataset + multi-node accelerate launch (`--machine_rank`, optional Slurm `--segment` to keep nodes in one NVLink domain). ### Usage ```yaml data: mode: streaming streaming_server_url: "http://node0:8000,http://node1:8000" # round-robin ``` ### Validation (Qwen3-8B, oci-nrt H100) sandbox CI: https://gitlab-master.nvidia.com/omniml/integration/nmm-sandbox/-/jobs/337489812 **1. End-to-end convergence — EAGLE3 & DFlash, 5000 steps.** Both algorithms converge and export a deployable draft; the DFlash drafts also serve under vLLM speculative decoding (8/8 smoke prompts pass). | algorithm | topology (nodes) | train loss (step 0 → 5000) | vLLM draft acc-len | |---|---|---|---| | EAGLE3 | 2 serve TP=2 + 2 trainer DDP (4) | 37.1 → 8.20 | — | | DFlash | 1 serve TP=1 + 1 trainer (2) | 11.7 → 5.56 | 1.11 | | DFlash | 2 serve TP=2 + 2 trainer DDP (4) | 10.9 → 5.26 | 1.19 | <!-- Drag these PNGs in here (GitHub turns them into asset URLs): eagle3_streaming_loss.png, dflash_streaming_loss_singlenode.png, dflash_streaming_loss_multinode.png --> **2. Scalability — 1 → 12 nodes (EAGLE3, 200 steps).** Throughput scales ~23× across the sweep below. The step-time growth is cross-node DDP all-reduce, not the streaming path — RDMA (~0.33 ms/req @ 2 MB, ~47 GB/s host-pinned READ) is never the bottleneck. Scale serve + trainer nodes together for near-linear speedup. | serve / trainer | nodes | step time | samples / step | samples / sec (global) | acc @ step 200 | |---|---|---|---|---|---| | 1 serve / 1 rank (co-located, 1 node 2 GPU) | 1 | 0.23 s | 1 | 4.4 | [0.141, 0.094, 0.072] | | 1 serve / 1 rank (cross-node) | 2 | 0.23 s | 1 | 4.3 | [0.137, 0.105, 0.074] | | 2 serve / 8 ranks | 3 | 0.26 s | 8 | 31.1 | [0.215, 0.126, 0.097] | | 4 serve / 16 ranks (2 trainer nodes) | 6 | 0.28 s | 16 | 56.5 | [0.217, 0.148, 0.110] | | 8 serve / 32 ranks (4 trainer nodes) | 12 | 0.31 s | 32 | 101.9 | [0.235, 0.165, 0.137] | **3. Serve-side TP correctness.** TP=1 vs TP=2 draft top-1 accuracy track step-for-step (hidden states are replicated across TP ranks). <img width="910" height="546" alt="serve-tp-acc" src="https://github.com/user-attachments/assets/73df9214-7ff0-4ab4-bf2f-95842b12cd5f" /> ### Before your PR is "*Ready for review*" - Backward compatible?: ❌ — streaming is now RDMA-only; `server_url` → `server_urls`; the disk transport (`HS_TRANSPORT`, `streaming_shared_storage_path`) is removed. - New tests?: ✅ `tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py` (map-style dataset + mocked RDMA fetch). - Updated Changelog?: ❌ - Claude approval?: ❌ --------- Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 13e540f commit 46eddab

32 files changed

Lines changed: 2356 additions & 905 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Changelog
77
**New Features**
88

99
- Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred.
10+
- Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``.
1011

1112
0.45 (2026-06-xx)
1213
^^^^^^^^^^^^^^^^^

examples/dataset/example_data_config.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ outputs:
66
splits:
77
all: 0
88
- name: "ultrachat"
9+
# UltraChat's loader yields prompt-only turns (no assistant completions),
10+
# which makes answer_only_loss=true mask nothing. Use daring-anteater below.
911
splits:
10-
train_gen: 25000
11-
train_sft: 25000
12+
train_gen: 0
13+
train_sft: 0
1214
- name: "mtbench"
1315
splits:
1416
all: 0
1517
- name: "daring-anteater"
18+
# Multi-turn SFT conversations WITH assistant completions (real split: train).
1619
splits:
17-
all: 0
20+
train: 50000
1821
- name: "magpie"
1922
splits:
2023
300k: 0

examples/specdec_bench/specdec_bench/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def _checkpoint_provenance(model_dir):
209209

210210

211211
def _is_sensitive_key(key):
212+
# Engine configs can carry non-string dict keys (e.g. int layer ids in a
213+
# serving_config); those are never sensitive field *names*, so skip them.
212214
if not isinstance(key, str):
213215
return False
214216
klow = key.lower()

examples/speculative_decoding/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
1818
| Simplified Workflow | Train, evaluate, and export EAGLE model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
1919
| Online Training | Train draft model alongside base model in GPU memory | \[[Link](#training-draft-model-with-online-base-model)\] |
2020
| Offline Training | Train draft model using pre-computed hidden states | \[[Link](#training-draft-model-with-offline-base-model)\] |
21+
| Streaming Training | Train draft on hidden states streamed from a live vLLM serve (no disk dump) | \[[Link](#training-draft-model-with-streaming-base-model)\] |
2122
| After Training | Evaluation, export and deployment | \[[Link](#model-validation)\] |
2223
| Advanced Usage | Data synthesis, vocab compression, and configuration | \[[Link](#advanced-usage)\] |
2324
| Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] |
@@ -127,6 +128,10 @@ Once we finish dumping hidden states, launch offline training pointing to the hi
127128
training.output_dir=ckpts/llama-3.2-1b-offline
128129
```
129130

131+
## Training Draft Model with Streaming Base Model
132+
133+
For large base models, you can stream hidden states from a live `vllm serve` instead of dumping them to disk: a co-located server produces the base-model hidden states on the fly and sends them to the trainer over NIXL RDMA, scaling to multiple nodes (dedicated serve replicas + DDP trainers). See the launcher examples, e.g. [Kimi-K2.5 streaming EAGLE3](../../tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yaml) and [streaming DFlash](../../tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml).
134+
130135
## Model Validation
131136

132137
For online training checkpoints, we can run in-framework evaluation on MT-bench:
@@ -334,6 +339,7 @@ See `main.py` for the full example including tokenizer setup, dataset loading, a
334339
| Mistral ||||
335340
| Phi 3 ||||
336341
| QWen 1.5,2,2.5,3 ||||
342+
| Kimi-K2.5, K2.6 | | ||
337343

338344
## Speculation Module Checkpoints
339345

examples/speculative_decoding/eagle_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def make_speculative_data_module(
5959
train_len=None,
6060
answer_only_loss=False,
6161
shift_labels=True,
62-
seed: int = 0,
6362
) -> dict:
6463
"""Create data module for speculative decoding training.
6564
@@ -88,14 +87,15 @@ def make_speculative_data_module(
8887
ds = load_dataset("json", data_files=data_args.data_path, split="train")
8988
if data_args.sample_size > 0:
9089
ds = ds.select(range(data_args.sample_size))
90+
# Map-style dataset: each rank fetches its own DistributedSampler shard.
91+
# Fetch concurrency comes from the DataLoader's num_workers, not a config knob;
92+
# shuffling/order is the sampler's job (seeded by training_args.seed).
93+
# ``server_urls`` accepts a comma-separated string for multi-server fan-out.
9194
streaming_cfg = EagleVllmStreamingConfig(
92-
server_url=data_args.streaming_server_url,
95+
server_urls=data_args.streaming_server_url,
9396
model=data_args.streaming_model_name,
94-
shared_storage_root=data_args.streaming_shared_storage_path,
9597
max_seq_len=train_len,
9698
answer_only_loss=answer_only_loss,
97-
prefetch=data_args.streaming_prefetch,
98-
seed=seed,
9999
)
100100
train_dataset = EagleVllmStreamingDataset(
101101
entries=ds,
@@ -138,7 +138,9 @@ def make_speculative_data_module(
138138
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
139139
if data_args.sample_size > 0:
140140
dumped_files = dumped_files[: data_args.sample_size]
141-
train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
141+
train_dataset = OfflineSupervisedDataset(
142+
dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
143+
)
142144
data_collator = EagleOfflineDataCollator(train_len=train_len)
143145

144146
return {

examples/speculative_decoding/launch_train.sh

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip <IP>
2020
# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy
2121
#
22-
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py.
23-
# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file.
24-
# Only multi-node routing args are passed here; mixed_precision is fixed to bf16.
22+
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py; all
23+
# training config lives in the YAML. mixed_precision is fixed to bf16.
2524

2625
set -eo pipefail
2726

@@ -30,12 +29,14 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
3029
CONFIG_FILE=""
3130
NUM_NODES=1
3231
HEAD_NODE_IP=""
32+
MACHINE_RANK=""
3333
EXTRA_ARGS=()
3434
while [ $# -gt 0 ]; do
3535
case "$1" in
3636
--config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;;
3737
--num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;;
3838
--head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;;
39+
--machine_rank*) if [[ "$1" != *=* ]]; then shift; fi; MACHINE_RANK="${1#*=}" ;;
3940
*) EXTRA_ARGS+=("$1") ;;
4041
esac
4142
shift
@@ -46,7 +47,6 @@ if [ -z "$CONFIG_FILE" ]; then
4647
exit 1
4748
fi
4849

49-
# GPU count detection
5050
if [[ "$NUM_NODES" != "1" ]]; then
5151
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
5252
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
@@ -56,20 +56,28 @@ else
5656
echo "Total GPUs: $TOTAL_GPU (single node)"
5757
fi
5858

59-
# Multi-node routing args (accelerate only; training config comes from the YAML)
60-
MULTI_NODE_ARGS=""
59+
MULTI_NODE_ARGS=()
6160
if [[ "$NUM_NODES" != "1" ]]; then
62-
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
63-
--num_machines $NUM_NODES \
64-
--machine_rank $SLURM_PROCID \
65-
--rdzv_backend c10d \
66-
--main_process_ip $HEAD_NODE_IP \
67-
--main_process_port 29500"
61+
# --multi_gpu is required even at 1 GPU/node, else accelerate won't form the DDP group.
62+
# machine_rank defaults to $SLURM_PROCID; override --machine_rank if node 0 isn't a trainer.
63+
MULTI_NODE_ARGS=(
64+
--multi_gpu
65+
--num_processes "$TOTAL_GPU"
66+
--num_machines "$NUM_NODES"
67+
--machine_rank "${MACHINE_RANK:-$SLURM_PROCID}"
68+
--main_process_ip "$HEAD_NODE_IP"
69+
--main_process_port 29500
70+
)
6871
fi
6972

7073
export TOKENIZERS_PARALLELISM=False
7174

75+
# argv array, not `sh -c` (which would word-split overrides and run embedded substitutions).
76+
CMD=(accelerate launch --mixed_precision bf16
77+
"${MULTI_NODE_ARGS[@]}"
78+
"${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}")
79+
7280
set -x
7381
start_time=$(date +%s)
74-
sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}"
82+
"${CMD[@]}"
7583
echo "Total time: $(( $(date +%s) - $start_time )) seconds"

examples/speculative_decoding/main.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def train():
267267
train_len=training_args.training_seq_len,
268268
answer_only_loss=training_args.answer_only_loss,
269269
shift_labels=not is_dflash,
270-
seed=training_args.seed,
271270
)
272271

273272
callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
@@ -277,13 +276,10 @@ def train():
277276
and recipe.eagle.eagle_base_lora_warmup_steps > 0
278277
):
279278
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
280-
if recipe.data.mode == "streaming":
281-
# Skip-on-resume happens inside the dataset (no re-fetch from server);
282-
# disable HF Trainer's own data skip so the offset isn't applied twice.
283-
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback
284-
285-
training_args.ignore_data_skip = True
286-
callbacks.append(StreamingResumeCallback())
279+
# Leave training_args.ignore_data_skip at its default (False). The dataset is
280+
# map-style, so HF Trainer's resume skips consumed indices at the batch-sampler
281+
# level (accelerate.skip_first_batches) without re-fetching them, landing at the
282+
# exact data position. Setting it True would restart the data order from the top.
287283

288284
trainer = EagleTrainerWithAccLog(
289285
model=model,

modelopt/recipe/config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131
TrainingArguments as SpecTrainingArgs,
3232
)
3333

34+
__all__ = [
35+
"RECIPE_TYPE_TO_CLASS",
36+
"ModelOptDFlashRecipe",
37+
"ModelOptEagleRecipe",
38+
"ModelOptMedusaRecipe",
39+
"ModelOptPTQRecipe",
40+
"ModelOptRecipeBase",
41+
"ModelOptSpeculativeRecipeBase",
42+
"RecipeMetadataConfig",
43+
"RecipeType",
44+
]
45+
3446

3547
class RecipeType(str, Enum):
3648
"""List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping."""
@@ -178,7 +190,11 @@ class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase):
178190

179191
@model_validator(mode="after")
180192
def _derive_dflash_offline(self) -> ModelOptDFlashRecipe:
181-
self.dflash.dflash_offline = self.data.offline_data_path is not None
193+
# offline (dumped .pt) and streaming (hidden states via NIXL RDMA from a vLLM
194+
# serve) both feed pre-computed base hidden states to the DFlash module, so
195+
# both set dflash_offline. Only fully-online training runs the base model.
196+
# Mirrors ModelOptEagleRecipe._derive_eagle_offline.
197+
self.dflash.dflash_offline = self.data.mode != "online"
182198
return self
183199

184200

modelopt/torch/speculative/config.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323

2424
from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config
2525

26+
__all__ = [
27+
"DFLASH_DEFAULT_CFG",
28+
"EAGLE3_DEFAULT_CFG",
29+
"EAGLE_MTP_DEFAULT_CFG",
30+
"DFlashConfig",
31+
"EagleConfig",
32+
"MedusaConfig",
33+
"eagle3_default_config",
34+
"eagle_mtp_default_config",
35+
"kimik2_eagle_default_config",
36+
]
37+
2638
kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)
2739

2840
eagle3_default_config = deepcopy(default_eagle_config)
@@ -68,8 +80,10 @@ class DFlashConfig(ModeloptBaseConfig):
6880
dflash_offline: bool = ModeloptField(
6981
default=False,
7082
description=(
71-
"Whether to use detached DFlash (offline training from pre-computed hidden states). "
72-
"Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable."
83+
"Whether the DFlash module consumes pre-computed hidden states (offline from "
84+
"dumped .pt files, or streaming via NIXL RDMA from a vLLM serve) instead of running "
85+
"the base model. Derived by ModelOptDFlashRecipe from data.mode (True unless "
86+
"online); not user-configurable."
7387
),
7488
)
7589

modelopt/torch/speculative/eagle/utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from torch.utils.data import Dataset
4242
from transformers.trainer_pt_utils import LabelSmoother
4343

44+
from modelopt.torch.utils.loss_mask import get_loss_mask_recovery
45+
4446
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
4547

4648

@@ -96,20 +98,27 @@ class OfflineSupervisedDataset(Dataset):
9698
dumped_files (list): A list of file paths to the dumped .pt files.
9799
answer_only_loss (bool): If True, use the ``loss_mask`` stored in each .pt
98100
file so that only assistant-produced tokens contribute to the loss.
99-
Raises ``ValueError`` on ``__getitem__`` if the file lacks ``loss_mask``.
101+
If a file lacks ``loss_mask`` and ``tokenizer`` has a registered
102+
model-specific recovery (see ``modelopt.torch.utils.loss_mask``), the
103+
mask is rebuilt from ``input_ids``; otherwise ``__getitem__`` raises
104+
``ValueError``.
100105
If False (default), a uniform all-ones mask is used regardless of what
101106
is stored in the file (backward compatible).
107+
tokenizer: Optional tokenizer used to recover the assistant mask for dumps
108+
that lack a stored ``loss_mask``.
102109
"""
103110

104111
def __init__(
105112
self,
106113
dumped_files,
107114
answer_only_loss: bool = False,
115+
tokenizer=None,
108116
):
109117
"""Initialize with a list of .pt file paths."""
110118
super().__init__()
111119
self.dumped_files = dumped_files
112120
self.answer_only_loss = answer_only_loss
121+
self.tokenizer = tokenizer
113122

114123
def __len__(self):
115124
return len(self.dumped_files)
@@ -121,13 +130,22 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
121130
labels[..., :-1] = offline_data["input_ids"][..., 1:]
122131

123132
if self.answer_only_loss:
124-
if "loss_mask" not in offline_data:
133+
recovery = get_loss_mask_recovery(self.tokenizer) if self.tokenizer else None
134+
if "loss_mask" in offline_data:
135+
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
136+
elif recovery is not None:
137+
# Dumps from tokenizers that cannot emit assistant masks carry no
138+
# loss_mask; rebuild it from the token ids.
139+
loss_mask = recovery.compute(self.tokenizer, offline_data["input_ids"]).to(
140+
offline_data["input_ids"].dtype
141+
)
142+
else:
125143
raise ValueError(
126144
f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
127145
f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
128-
f"with --answer-only-loss in compute_hidden_states_*.py."
146+
f"with --answer-only-loss in compute_hidden_states_*.py, or pass a "
147+
f"tokenizer with a registered loss-mask recovery."
129148
)
130-
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
131149
else:
132150
loss_mask = torch.ones_like(offline_data["input_ids"])
133151

0 commit comments

Comments
 (0)