Skip to content

Commit 81f0c3d

Browse files
committed
squash: rdma + multinode
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 2c52e7b commit 81f0c3d

27 files changed

Lines changed: 1997 additions & 897 deletions

examples/specdec_bench/specdec_bench/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def _checkpoint_provenance(model_dir):
196196

197197

198198
def _is_sensitive_key(key):
199+
# Engine configs can carry non-string dict keys (e.g. int layer ids in a
200+
# serving_config); those are never sensitive field *names*, so skip them.
199201
if not isinstance(key, str):
200202
return False
201203
klow = key.lower()

examples/speculative_decoding/eagle_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def make_speculative_data_module(
5959
train_len=None,
6060
answer_only_loss=False,
6161
shift_labels=True,
62-
seed: int = 0,
6362
) -> dict:
6463
"""Create data module for speculative decoding training.
6564
@@ -88,14 +87,15 @@ def make_speculative_data_module(
8887
ds = load_dataset("json", data_files=data_args.data_path, split="train")
8988
if data_args.sample_size > 0:
9089
ds = ds.select(range(data_args.sample_size))
90+
# Map-style dataset: each rank fetches its own DistributedSampler shard.
91+
# Fetch concurrency comes from the DataLoader's num_workers, not a config knob;
92+
# shuffling/order is the sampler's job (seeded by training_args.seed).
93+
# ``server_urls`` accepts a comma-separated string for multi-server fan-out.
9194
streaming_cfg = EagleVllmStreamingConfig(
92-
server_url=data_args.streaming_server_url,
95+
server_urls=data_args.streaming_server_url,
9396
model=data_args.streaming_model_name,
94-
shared_storage_root=data_args.streaming_shared_storage_path,
9597
max_seq_len=train_len,
9698
answer_only_loss=answer_only_loss,
97-
prefetch=data_args.streaming_prefetch,
98-
seed=seed,
9999
)
100100
train_dataset = EagleVllmStreamingDataset(
101101
entries=ds,
@@ -138,7 +138,9 @@ def make_speculative_data_module(
138138
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
139139
if data_args.sample_size > 0:
140140
dumped_files = dumped_files[: data_args.sample_size]
141-
train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
141+
train_dataset = OfflineSupervisedDataset(
142+
dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
143+
)
142144
data_collator = EagleOfflineDataCollator(train_len=train_len)
143145

144146
return {

examples/speculative_decoding/launch_train.sh

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip <IP>
2020
# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy
2121
#
22-
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py.
23-
# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file.
24-
# Only multi-node routing args are passed here; mixed_precision is fixed to bf16.
22+
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py; all
23+
# training config lives in the YAML. mixed_precision is fixed to bf16.
2524

2625
set -eo pipefail
2726

@@ -30,12 +29,14 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
3029
CONFIG_FILE=""
3130
NUM_NODES=1
3231
HEAD_NODE_IP=""
32+
MACHINE_RANK=""
3333
EXTRA_ARGS=()
3434
while [ $# -gt 0 ]; do
3535
case "$1" in
3636
--config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;;
3737
--num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;;
3838
--head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;;
39+
--machine_rank*) if [[ "$1" != *=* ]]; then shift; fi; MACHINE_RANK="${1#*=}" ;;
3940
*) EXTRA_ARGS+=("$1") ;;
4041
esac
4142
shift
@@ -46,7 +47,6 @@ if [ -z "$CONFIG_FILE" ]; then
4647
exit 1
4748
fi
4849

49-
# GPU count detection
5050
if [[ "$NUM_NODES" != "1" ]]; then
5151
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
5252
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
@@ -56,20 +56,28 @@ else
5656
echo "Total GPUs: $TOTAL_GPU (single node)"
5757
fi
5858

59-
# Multi-node routing args (accelerate only; training config comes from the YAML)
60-
MULTI_NODE_ARGS=""
59+
MULTI_NODE_ARGS=()
6160
if [[ "$NUM_NODES" != "1" ]]; then
62-
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
63-
--num_machines $NUM_NODES \
64-
--machine_rank $SLURM_PROCID \
65-
--rdzv_backend c10d \
66-
--main_process_ip $HEAD_NODE_IP \
67-
--main_process_port 29500"
61+
# --multi_gpu is required even at 1 GPU/node, else accelerate won't form the DDP group.
62+
# machine_rank defaults to $SLURM_PROCID; pass --machine_rank to set it explicitly (e.g. when node 0 isn't a trainer, or when $SLURM_PROCID is unset).
63+
MULTI_NODE_ARGS=(
64+
--multi_gpu
65+
--num_processes "$TOTAL_GPU"
66+
--num_machines "$NUM_NODES"
67+
--machine_rank "${MACHINE_RANK:-$SLURM_PROCID}"
68+
--main_process_ip "$HEAD_NODE_IP"
69+
--main_process_port 29500
70+
)
6871
fi
6972

7073
export TOKENIZERS_PARALLELISM=False
7174

75+
# Build the command as an argv array rather than a `sh -c` string, which would word-split overrides and run embedded substitutions.
76+
CMD=(accelerate launch --mixed_precision bf16
77+
"${MULTI_NODE_ARGS[@]}"
78+
"${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}")
79+
7280
set -x
7381
start_time=$(date +%s)
74-
sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}"
82+
"${CMD[@]}"
7583
echo "Total time: $(( $(date +%s) - $start_time )) seconds"

examples/speculative_decoding/main.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def train():
267267
train_len=training_args.training_seq_len,
268268
answer_only_loss=training_args.answer_only_loss,
269269
shift_labels=not is_dflash,
270-
seed=training_args.seed,
271270
)
272271

273272
callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
@@ -277,13 +276,10 @@ def train():
277276
and recipe.eagle.eagle_base_lora_warmup_steps > 0
278277
):
279278
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
280-
if recipe.data.mode == "streaming":
281-
# Skip-on-resume happens inside the dataset (no re-fetch from server);
282-
# disable HF Trainer's own data skip so the offset isn't applied twice.
283-
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback
284-
285-
training_args.ignore_data_skip = True
286-
callbacks.append(StreamingResumeCallback())
279+
# Leave training_args.ignore_data_skip at its default (False). The dataset is
280+
# map-style, so HF Trainer's resume skips consumed indices at the batch-sampler
281+
# level (accelerate.skip_first_batches) without re-fetching them, landing at the
282+
# exact data position. Setting it True would restart the data order from the top.
287283

288284
trainer = EagleTrainerWithAccLog(
289285
model=model,

modelopt/recipe/config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131
TrainingArguments as SpecTrainingArgs,
3232
)
3333

34+
__all__ = [
35+
"RECIPE_TYPE_TO_CLASS",
36+
"ModelOptDFlashRecipe",
37+
"ModelOptEagleRecipe",
38+
"ModelOptMedusaRecipe",
39+
"ModelOptPTQRecipe",
40+
"ModelOptRecipeBase",
41+
"ModelOptSpeculativeRecipeBase",
42+
"RecipeMetadataConfig",
43+
"RecipeType",
44+
]
45+
3446

3547
class RecipeType(str, Enum):
3648
"""List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping."""
@@ -178,7 +190,11 @@ class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase):
178190

179191
@model_validator(mode="after")
180192
def _derive_dflash_offline(self) -> ModelOptDFlashRecipe:
181-
self.dflash.dflash_offline = self.data.offline_data_path is not None
193+
# Offline (dumped .pt) and streaming (hidden states over HTTP from a vLLM
194+
# serve) both feed pre-computed base hidden states to the DFlash module, so
195+
# both set dflash_offline. Only fully-online training runs the base model.
196+
# Mirrors ModelOptEagleRecipe._derive_eagle_offline.
197+
self.dflash.dflash_offline = self.data.mode != "online"
182198
return self
183199

184200

modelopt/torch/speculative/config.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323

2424
from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config
2525

26+
__all__ = [
27+
"DFLASH_DEFAULT_CFG",
28+
"EAGLE3_DEFAULT_CFG",
29+
"EAGLE_MTP_DEFAULT_CFG",
30+
"DFlashConfig",
31+
"EagleConfig",
32+
"MedusaConfig",
33+
"eagle3_default_config",
34+
"eagle_mtp_default_config",
35+
"kimik2_eagle_default_config",
36+
]
37+
2638
kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)
2739

2840
eagle3_default_config = deepcopy(default_eagle_config)
@@ -68,8 +80,10 @@ class DFlashConfig(ModeloptBaseConfig):
6880
dflash_offline: bool = ModeloptField(
6981
default=False,
7082
description=(
71-
"Whether to use detached DFlash (offline training from pre-computed hidden states). "
72-
"Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable."
83+
"Whether the DFlash module consumes pre-computed hidden states (offline from "
84+
"dumped .pt files, or streaming over HTTP from a vLLM serve) instead of running "
85+
"the base model. Derived by ModelOptDFlashRecipe from data.mode (True unless "
86+
"online); not user-configurable."
7387
),
7488
)
7589

modelopt/torch/speculative/eagle/utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from torch.utils.data import Dataset
4242
from transformers.trainer_pt_utils import LabelSmoother
4343

44+
from modelopt.torch.utils.loss_mask import get_loss_mask_recovery
45+
4446
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
4547

4648

@@ -96,20 +98,27 @@ class OfflineSupervisedDataset(Dataset):
9698
dumped_files (list): A list of file paths to the dumped .pt files.
9799
answer_only_loss (bool): If True, use the ``loss_mask`` stored in each .pt
98100
file so that only assistant-produced tokens contribute to the loss.
99-
Raises ``ValueError`` on ``__getitem__`` if the file lacks ``loss_mask``.
101+
If a file lacks ``loss_mask`` and ``tokenizer`` has a registered
102+
model-specific recovery (see ``modelopt.torch.utils.loss_mask``), the
103+
mask is rebuilt from ``input_ids``; otherwise ``__getitem__`` raises
104+
``ValueError``.
100105
If False (default), a uniform all-ones mask is used regardless of what
101106
is stored in the file (backward compatible).
107+
tokenizer: Optional tokenizer used to recover the assistant mask for dumps
108+
that lack a stored ``loss_mask``.
102109
"""
103110

104111
def __init__(
105112
self,
106113
dumped_files,
107114
answer_only_loss: bool = False,
115+
tokenizer=None,
108116
):
109117
"""Initialize with a list of .pt file paths."""
110118
super().__init__()
111119
self.dumped_files = dumped_files
112120
self.answer_only_loss = answer_only_loss
121+
self.tokenizer = tokenizer
113122

114123
def __len__(self):
115124
return len(self.dumped_files)
@@ -121,13 +130,22 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
121130
labels[..., :-1] = offline_data["input_ids"][..., 1:]
122131

123132
if self.answer_only_loss:
124-
if "loss_mask" not in offline_data:
133+
recovery = get_loss_mask_recovery(self.tokenizer) if self.tokenizer else None
134+
if "loss_mask" in offline_data:
135+
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
136+
elif recovery is not None:
137+
# Dumps from tokenizers that cannot emit assistant masks carry no
138+
# loss_mask; rebuild it from the token ids.
139+
loss_mask = recovery.compute(self.tokenizer, offline_data["input_ids"]).to(
140+
offline_data["input_ids"].dtype
141+
)
142+
else:
125143
raise ValueError(
126144
f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
127145
f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
128-
f"with --answer-only-loss in compute_hidden_states_*.py."
146+
f"with --answer-only-loss in compute_hidden_states_*.py, or pass a "
147+
f"tokenizer with a registered loss-mask recovery."
129148
)
130-
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
131149
else:
132150
loss_mask = torch.ones_like(offline_data["input_ids"])
133151

0 commit comments

Comments
 (0)