Commit adff795

Authored by arvyanh, ArronHZG, and gemini-code-assist[bot]
[megatron] feat: Support MTP training in SFT (verl-project#4981)
### What does this PR do?

SFT training with MTP is supported, using the same MTP training configuration as RL training. An example configuration for running SFT can be found in `examples/sft/gsm8k/run_mimo_megatron_mtp.sh`.

**SFT result**

The experiment was conducted with the following setup:

- model = mimo-7B-math
- dataset = gsm8k

Result: [wandb link](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)

The presence of the MTP layer has a limited effect on the main loss. However, when the MTP layer is detached, the `mtp_loss` converges to a higher value.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### Design & Code Changes

Support SFT training with MTP, with the core change based on verl-project#4936.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: uses non-standard mbridge/mcore
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [X] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Co-authored-by: ArronHZG <hou.zg@foxmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent: 016572f

File tree: 11 files changed (+303, -34 lines)

**docs/advance/mtp.md** (21 additions, 2 deletions)

```diff
@@ -1,8 +1,8 @@
-# Guide to Using MTP in RL Training and Inference
+# Guide to Using MTP in SFT/RL Training and Inference
 
 **Author**: `https://github.com/meituan-search`
 
-Last updated: 01/16/2026
+Last updated: 01/21/2026
 
 # 1. Scope of Support
 
@@ -41,6 +41,8 @@ Experiment chart:
 ![fully_async_policy_revenue](
 https://github.com/ArronHZG/verl-community/blob/main/docs/mimo-7b-mtp.png?raw=true)
 
+The wandb link for the graph: [wandb](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)
+
 **Scenarios with No Significant Effect**
 
 The following configurations will not have a noticeable impact on training results:
@@ -82,3 +84,20 @@ Taking the mimo-7B model deployed separately on H20 hardware using SGLang as an
 - Current priority recommendation: Do not enable MTP acceleration during the inference phase for now;
 
 - Future planning: Further optimization of the speculative logic in the Rollout phase will be conducted to improve throughput performance.
+
+# 5. SFT training
+
+The SFT training with MTP is supported, using the same MTP training configuration as RL training.
+
+An example configuration for running SFT can be found in `examples/sft/gsm8k/run_mimo_megatron_mtp.sh`
+
+**SFT result**
+
+The experiment was conducted using following data:
+- model = mimo-7B-math
+- dataset = gsm8k
+
+The result: [wandb link](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)
+
+The presence of mtp layer has limited effect on main loss. However, when MTP layer is detached, the mtp_loss converges to a higher value.
```
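For context, multi-token prediction trains one or more auxiliary heads to predict tokens further ahead and folds their loss into the main objective. A common formulation (an assumption for illustration here, not necessarily verl's exact implementation) is:

$$\mathcal{L} = \mathcal{L}_{\text{main}} + \lambda \sum_{k=1}^{K} \mathcal{L}_{\text{MTP}}^{(k)},$$

where $K$ is the number of MTP heads and $\lambda$ weights the auxiliary term. Detaching the MTP branch stops the $\mathcal{L}_{\text{MTP}}$ gradients from flowing into the trunk, so the shared representations no longer adapt to the MTP objective, which is consistent with the higher converged `mtp_loss` the doc reports.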
**examples/sft/gsm8k/run_mimo_megatron_mtp.sh** (new file, 102 additions)

```bash
#!/usr/bin/env bash
set -xeuo pipefail

NUM_GPUS=${NUM_GPUS:-8}
SP_SIZE=${SP_SIZE:-1}
TP_SIZE=${TP_SIZE:-1}
PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}
PAD_MODE=${PAD_MODE:-no_padding}
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-False}
LR="1e-5"
MINLR="1e-6"

export VERL_SFT_LOGGING_LEVEL=INFO

backend=${BACKEND:-megatron}

TENSORBOARD_DIR=~/tensorboard

MASTER_ADDR=${MASTER_ADDR:-localhost}
MASTER_PORT=${MASTER_PORT:-29500}
NNODES=${NNODES:-1}
RANK=${RANK:-0}

ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}

# Note the default MultiturnSFT Dataset requires all the sys/user/assistant in 'data.message_key'
DATASET_DIR=${DATASET_DIR:-~/dataset/rl/gsm8k}
TRAIN_FILES=${DATASET_DIR}/train.parquet
VAL_FILES=${DATASET_DIR}/eval.parquet

project_name=verl_sft_test

RESUME_MODE=disable

MODEL_PATH="XiaomiMiMo/MiMo-7B-RL"
ckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}}

# currently relies on these two commits that is not on master
PYPATH=$HOME/pythonpath
mkdir -p $PYPATH && cd $PYPATH
[ -d Megatron-LM ] || git clone https://github.com/NVIDIA/Megatron-LM -b dev && (cd Megatron-LM; git checkout 23e092f41ec8bc659020e401ddac9576c1cfed7e)
[ -d mbridge ] || git clone https://github.com/ArronHZG/mbridge -b feature/verl_mtp && (cd mbridge; git checkout 6bf2d45a15dc4fb52d2f0c38ff546bee33447d10)
cd -
export PYTHONPATH=$PYTHONPATH:$PYPATH/mbridge:$PYPATH/Megatron-LM


MEGATRON_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=${LR} \
    optim.min_lr=${MINLR} \
    optim.lr_warmup_steps=10 \
    optim.weight_decay=0.1 \
    optim.betas='[0.9,0.95]' \
    optim.clip_grad=1.0 \
    optim.lr_warmup_init=0 \
    optim.lr_decay_style=cosine \
    engine.override_transformer_config.recompute_method=uniform \
    engine.override_transformer_config.recompute_granularity=full \
    engine.override_transformer_config.recompute_num_layers=1 \
    engine.use_dist_checkpointing=False \
    engine.tensor_model_parallel_size=${TP_SIZE} \
    engine.pipeline_model_parallel_size=${PP_SIZE} \
    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
    engine.context_parallel_size=${CP_SIZE} \
    engine.use_mbridge=True \
"

ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-lr-${MINLR}-${LR}

mkdir -p "${ckpts_home}"

$COMMAND \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${TRAIN_FILES}" \
    data.train_batch_size=64 \
    data.micro_batch_size_per_gpu=2 \
    data.pad_mode=${PAD_MODE} \
    data.truncation=error \
    data.max_length=1024 \
    data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=2048 \
    data.messages_key=prompt \
    data.num_workers=0 \
    model.path=$MODEL_PATH \
    model.use_remove_padding=${USE_REMOVE_PADDING} \
    model.trust_remote_code=True \
    model.mtp.enable=True \
    ${ENGINE_CONFIG} \
    trainer.test_freq=after_each_epoch \
    trainer.save_freq=-1 \
    trainer.logger="['console']" \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${ckpts_home}" \
    trainer.resume_mode=${RESUME_MODE}
```

**verl/models/mcore/model_forward.py** (20 additions, 0 deletions)

```diff
@@ -168,6 +168,7 @@ def gptmodel_forward_no_padding(
     vision_model=False,
     pad_token_id=None,
     data_format: str = "thd",
+    enable_mtp: bool = False,
 ):
     """Default forward pass for GPT models with optional sequence packing."""
 
@@ -190,6 +191,15 @@ def gptmodel_forward_no_padding(
     input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding(input_ids, pre_process=pre_process)
     input_ids_rmpad = input_ids_rmpad.contiguous()
 
+    if enable_mtp and post_process:
+        args = {
+            k: preprocess_thd_no_padding(v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"))[0]
+            for k, v in logits_processor_args.items()
+        }
+        model_kwargs["labels"] = args["label"].contiguous()
+        model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
+        logits_processor_args.pop("loss_mask")
+
     # For VLM model, need to pass bshd format `input_ids` and `attention_mask`.
     attention_mask = None
     if vision_model:
@@ -233,6 +243,16 @@ def gptmodel_forward_no_padding(
     input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding(
         input_ids, pre_process=pre_process
     )
+
+    if enable_mtp and post_process:
+        args = {
+            k: preprocess_bshd_no_padding(v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"))[0]
+            for k, v in logits_processor_args.items()
+        }
+        model_kwargs["labels"] = args["label"].contiguous()
+        model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
+        logits_processor_args.pop("loss_mask")
+
     output_orig = model(
         input_ids=input_ids_bshd,
         attention_mask=attention_mask_bshd,
```
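The `need_roll` handling above shifts labels and loss masks so that, when the model computes the loss internally on the MTP path, position *t* already holds the target for token *t+1*. A minimal sketch of that shift (illustrative names only, not verl's API):

```python
def roll_left(seq, pad=0):
    """Shift a sequence one position left and pad the tail.

    Mimics what need_roll=True accomplishes for `label` and `loss_mask`:
    after the shift, index t holds the supervision target for the token
    the model predicts at position t.
    """
    return seq[1:] + [pad]


# token ids of a packed sample; the final 0 in the mask marks padding
input_ids = [101, 7, 42, 9, 102]
labels = roll_left(input_ids)            # token t is supervised by token t+1
loss_mask = roll_left([1, 1, 1, 1, 0])   # mask shifts in lockstep with labels
```

The mask must be rolled with the labels, otherwise padding positions would be scored against shifted targets.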

**verl/trainer/config/sft_trainer_engine.yaml** (2 additions, 0 deletions)

```diff
@@ -26,6 +26,7 @@ data:
   messages_key: messages # Key for messages list in multi-turn mode
   tools_key: tools # Key for tools list in multi-turn mode
   enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode
+  enable_thinking_default: none # The default value when enable_thinking_key is not present in the dataset
   pad_mode: no_padding
   # for right padding
   max_length: 1024
@@ -36,6 +37,7 @@ data:
   name: null
   use_shm: False
   apply_chat_template_kwargs: {}
+  num_workers: 8
 
 # MultiTurnSFTDataset apply_chat_template to each turn separately and concat `input_ids`
 # as a whole sequence, which may not equal to apply_chat_template to whole messages at once.
```
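The two new keys are plain fallbacks: `num_workers` is forwarded to the dataloader, and `enable_thinking_default` applies only when a row carries no `enable_thinking` value. A small sketch of the fallback (a simplified stand-in for the dataset logic, not the actual class):

```python
def resolve_enable_thinking(row, default=None, key="enable_thinking"):
    """A per-row value wins; otherwise the configured default applies."""
    return row[key] if key in row else default


# a row-level value overrides the default; rows without the key fall back
explicit = resolve_enable_thinking({"enable_thinking": False}, default=True)
fallback = resolve_enable_thinking({}, default=True)
unset = resolve_enable_thinking({})  # no default configured either
```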

**verl/trainer/sft_trainer.py** (17 additions, 6 deletions)

```diff
@@ -38,6 +38,8 @@
 from verl.utils.device import auto_set_device, get_device_name
 from verl.utils.distributed import destroy_global_process_group
 from verl.utils.logger import log_with_rank
+from verl.utils.memory_utils import aggressive_empty_cache
+from verl.utils.profiler import log_gpu_memory_usage
 from verl.utils.tracking import Tracking
 from verl.workers.engine_workers import TrainingWorker
 
@@ -52,6 +54,8 @@ def __init__(
     ):
         self.config = config
 
+        log_gpu_memory_usage(f"rank {torch.distributed.get_rank()}: Before SFTTrainer init", logger=logger)
+
         self.rank = torch.distributed.get_rank()
 
         self._build_config()
@@ -73,6 +77,8 @@ def __init__(
         if self.rank == 0:
             print(self.config)
 
+        log_gpu_memory_usage(f"rank {self.rank}: After SFTTrainer init", logger=logger)
+
     def _build_ckpt_handler(self):
         resume_mode = getattr(self.config.trainer, "resume_mode", "auto")
         resume_from_path = getattr(self.config.trainer, "resume_from_path", None)
@@ -200,7 +206,7 @@ def _build_dataloader(self):
             batch_size=self.train_batch_size_per_dp,
             sampler=self.train_sampler,
             collate_fn=self.collate_fn,
-            num_workers=8,
+            num_workers=self.config.data.num_workers,
             pin_memory=False,
             drop_last=True,
             pin_memory_device=device_name,
@@ -215,7 +221,7 @@ def _build_dataloader(self):
             batch_size=self.train_batch_size_per_dp,
             sampler=self.val_sampler,
             collate_fn=self.collate_fn,
-            num_workers=8,
+            num_workers=self.config.data.num_workers,
             pin_memory=False,
             drop_last=True,
             pin_memory_device=device_name,
@@ -298,6 +304,9 @@ def fit(self):
         for epoch in range(start_epoch, self.config.trainer.total_epochs):
             self.train_sampler.set_epoch(epoch=epoch)
 
+            aggressive_empty_cache(force_sync=True)
+            log_gpu_memory_usage(f"rank {self.rank}: At start of epoch {epoch}", logger=logger)
+
             for step_in_epoch, data in enumerate(
                 tqdm(
                     self.train_dataloader,
@@ -330,10 +339,11 @@ def fit(self):
                 metrics = tu.get(output, "metrics")
 
                 # TODO: we can actual accumulate metrics for N steps and perform aggregate metrics
-                metrics["train/loss"] = metrics.pop("loss")
-                metrics["train/grad_norm"] = metrics.pop("grad_norm")
-                metrics["train/lr"] = metrics.pop("lr")
-                metrics["train/mfu"] = metrics.pop("mfu")
+                for k in ["loss", "grad_norm", "lr", "mfu"]:
+                    if k in metrics.keys():
+                        value = metrics.pop(k)
+                        metrics[f"train/{k}"] = value
+
                 metrics["train/global_tokens"] = torch.sum(
                     torch.tensor(batch_seqlens, device=self.device_name)
                 ).item()
@@ -373,6 +383,7 @@ def fit(self):
                 torch.distributed.barrier()
 
                 if is_last_step or (self.save_freq > 0 and is_save_step):
+                    aggressive_empty_cache(force_sync=True)
                     self.ckpt_handler.save_checkpoint(step=global_step)
 
                 if is_last_step:
```
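The loop rewrite makes the metric renaming tolerant of engines that do not report every key, instead of raising on a missing `pop`. The pattern in isolation:

```python
def namespace_metrics(metrics, keys=("loss", "grad_norm", "lr", "mfu"), prefix="train/"):
    """Move each present key under the `train/` namespace, skipping absent ones."""
    for k in keys:
        if k in metrics:
            metrics[f"{prefix}{k}"] = metrics.pop(k)
    return metrics


# "grad_norm" and "mfu" are absent and are silently skipped;
# unrelated keys pass through untouched
m = namespace_metrics({"loss": 0.5, "lr": 1e-5, "extra": 1})
```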

**verl/utils/checkpoint/megatron_checkpoint_manager.py** (4 additions, 1 deletion)

```diff
@@ -259,7 +259,10 @@ def generate_state_dict(
         key = "model"
         if hasattr(model, "module"):
             model = model.module
-        state_dict[key] = model.sharded_state_dict()
+
+        # GPTModel's sharded_state_dict function when having mtp requires metadata['dp_cp_group']
+        kwargs = {"metadata": {"dp_cp_group": mpu.get_data_parallel_group(with_context_parallel=True)}}
+        state_dict[key] = model.sharded_state_dict(**kwargs)
 
         # Optimizer State Dict
         if generate_optimizer:
```
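The same pattern in isolation, with a stub standing in for an MTP-enabled `GPTModel` (the stub and its failure mode are assumptions for illustration; the real requirement is that Megatron's `sharded_state_dict` needs `metadata['dp_cp_group']` when MTP layers are present, as the diff comment states):

```python
class FakeMTPModel:
    """Stub: fails like an MTP-enabled model would without the metadata."""

    def sharded_state_dict(self, metadata=None):
        if not metadata or "dp_cp_group" not in metadata:
            raise KeyError("dp_cp_group")
        return {"embedding.weight": "sharded-tensor"}


def generate_model_state(model, dp_cp_group):
    # unwrap DDP-style wrappers first, as the checkpoint manager does
    if hasattr(model, "module"):
        model = model.module
    kwargs = {"metadata": {"dp_cp_group": dp_cp_group}}
    return model.sharded_state_dict(**kwargs)


state = generate_model_state(FakeMTPModel(), dp_cp_group="dp_cp_group_handle")
```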

**verl/utils/dataset/multiturn_sft_dataset.py** (38 additions, 2 deletions)

```diff
@@ -19,6 +19,7 @@
 import logging
 import os
 import re
+from functools import wraps
 from typing import Any, Optional
 
 import numpy as np
@@ -40,6 +41,33 @@
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
+def once(func):
+    """Decorator to ensure a function runs only once. Subsequent calls do nothing."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not hasattr(wrapper, "called"):
+            wrapper.called = True
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+@once
+def print_assembled_message(tokenizer, message_list, input_ids, loss_mask, attn_mask, tools):
+    """
+    Print the message after applying the chat template
+    """
+
+    tokenized = tokenizer.apply_chat_template(message_list, add_generation_prompt=False, tokenize=False, tools=tools)
+    sep = "\n\n"
+    str = f"tokenized entire message:\n{tokenized}"
+    str += sep
+    str += f"tokenized seperately :\n{tokenizer.decode(input_ids)}"
+
+    logger.debug(str)
+
+
 def convert_nested_value_to_list_recursive(data_item):
     if isinstance(data_item, dict):
         return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()}
@@ -91,6 +119,7 @@ def __init__(
         )
         self.tools_key = config.get("tools_key", "tools")
         self.enable_thinking_key = config.get("enable_thinking_key", "enable_thinking")
+        self.enable_thinking_default = config.get("enable_thinking_default", None)
         self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
         self.shuffle = config.get("shuffle", False)
         self.seed = config.get("seed")
@@ -125,7 +154,8 @@ def series_to_item(ls):
 
         dataframes = []
         for parquet_file in self.parquet_files:
-            dataframe = pd.read_parquet(parquet_file)
+            # default loader loads some list as np.ndarray, which fails the tokenizer
+            dataframe = pd.read_parquet(parquet_file, dtype_backend="pyarrow")
             dataframes.append(dataframe)
         self.dataframe = pd.concat(dataframes)
 
@@ -167,6 +197,7 @@ def _process_single_message(
         self,
         index: int,
         message: dict[str, Any],
+        full_message: list,
         tools: Optional[list[dict[str, Any]]] = None,
         enable_thinking: Optional[bool] = None,
     ) -> tuple[list[int], list[int], list[int]]:
@@ -267,14 +298,17 @@ def __getitem__(self, item):
         row_dict: dict = self.dataframe.iloc[item].to_dict()
         messages = self._build_messages(row_dict)
         tools = self.tools[item] if self.tools is not None else None
-        enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None
+        enable_thinking = (
+            self.enable_thinking[item] if self.enable_thinking is not None else self.enable_thinking_default
+        )
 
         # 1. tokenize each message
         input_ids, loss_mask, attention_mask, multi_modal_inputs = [], [], [], {}
         for i, message in enumerate(messages):
             _input_ids, _loss_mask, _attention_mask, _inputs = self._process_single_message(
                 index=i,
                 message=message,
+                full_message=messages,
                 tools=tools if i == 0 else None,
                 enable_thinking=enable_thinking,
             )
@@ -290,6 +324,8 @@ def __getitem__(self, item):
         assert input_ids.shape == loss_mask.shape == attention_mask.shape, (
             f"Shape mismatch: {input_ids.shape}, {loss_mask.shape}, {attention_mask.shape}"
         )
+
+        print_assembled_message(self.tokenizer, messages, input_ids, loss_mask, attention_mask, tools)
         self.sanity_check(input_ids, messages, tools, enable_thinking)
 
         # Since the tokenizer may return user-customized results, we need to filter out inconsistent tensor shapes
```
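The `once` decorator keys on an attribute of its own wrapper, so only the first call in a process actually runs; later calls return `None`. A self-contained sketch of the same pattern:

```python
from functools import wraps


def once(func):
    """Run func on the first call only; subsequent calls are no-ops."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not hasattr(wrapper, "called"):
            wrapper.called = True
            return func(*args, **kwargs)

    return wrapper


calls = []


@once
def log_sample(msg):
    calls.append(msg)


log_sample("first sample")   # runs
log_sample("second sample")  # silently skipped
```

Note the flag lives on each wrapper, not globally: every decorated function gets its own one-shot state, which is why `print_assembled_message` fires for only the first dataset item per worker process.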
