Commit b85ea89

[megatron] feat: add script for qwen3next training (#4582)
### What does this PR do?

1. Add Dockerfiles for the experimental images.
2. Add an example script to run Qwen3-Next Megatron training.
1 parent 9f7e070 commit b85ea89
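
To try the result of this commit, one option is to pull one of the experimental images and start a container. A hedged sketch: the image tags below are taken from the Dockerfiles in this commit, while the assumption that the tags are already published, plus the mount path and shm size, are illustrative only.

```bash
# Hedged sketch: tags come from the Dockerfiles below; publication of the tags,
# the mount path, and the shm size are assumptions for illustration.
docker pull verlai/verl:sgl056.exp            # SGLang-based experimental image
# docker pull verlai/verl:vll012.exp          # or the vLLM-based experimental image
docker run --gpus all -it --shm-size=32g \
    -v "$PWD":/workspace/verl -w /workspace/verl \
    verlai/verl:sgl056.exp bash
```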

File tree: 4 files changed (+366, -2 lines changed)
Lines changed: 63 additions & 0 deletions (new Dockerfile for the verlai/verl:sgl056.exp image)
@@ -0,0 +1,63 @@
# Dockerfile for verlai/verl:sgl056.exp
FROM lmsysorg/sglang:v0.5.6.post1

RUN pip install pybind11

RUN pip install nvidia-mathdx

RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git

RUN export NVTE_FRAMEWORK=pytorch && MAX_JOBS=128 NVTE_BUILD_THREADS_PER_JOB=4 pip3 install --resume-retries 999 --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.11

RUN pip install --upgrade --no-cache-dir transformers tokenizers

RUN pip install codetiming tensordict mathruler pylatexenc qwen_vl_utils

RUN pip install --no-cache-dir --no-build-isolation flash_attn==2.8.1

RUN wget https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_6/nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb && \
    apt-get update && apt-get install -y libxcb-cursor0

RUN apt-get install -y ./nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb && \
    rm -rf /usr/local/cuda/bin/nsys && \
    ln -s /opt/nvidia/nsight-systems/2025.6.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \
    rm -rf /usr/local/cuda/bin/nsys-ui && \
    ln -s /opt/nvidia/nsight-systems/2025.6.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \
    rm nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb


# =========================
# Install HybridEP
# =========================
WORKDIR /home/
RUN git clone --branch hybrid-ep https://github.com/deepseek-ai/DeepEP.git && \
    cd DeepEP && git checkout 3f601f7ac1c062c46502646ff04c535013bfca00 && \
    TORCH_CUDA_ARCH_LIST="9.0;10.0" pip install --no-build-isolation .

# =========================
# Install Qwen3-Next dependencies
# =========================
WORKDIR /home/
# Install causal-conv1d and flash-linear-attention
RUN cd /tmp && \
    git clone https://github.com/Dao-AILab/causal-conv1d.git && \
    cd causal-conv1d && \
    unset PIP_CONSTRAINT && \
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --no-build-isolation . && \
    cd .. && \
    rm -rf causal-conv1d && \
    pip install flash-linear-attention

RUN pip install --no-cache-dir torch-memory-saver

RUN pip3 install --no-cache-dir --no-deps trl

RUN pip3 install nvtx matplotlib liger_kernel

RUN pip install -U git+https://github.com/ISEEKYAN/mbridge.git

RUN pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@1d462bd37dac21cfa14177405d4921eedb987052  # latest dev branch on 20251209

RUN pip install git+https://github.com/volcengine/[email protected]

RUN pip uninstall -y verl
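
If you would rather build this image than pull it, a minimal sketch follows; the local Dockerfile name is hypothetical (the file path is not shown in this view), and the sanity check simply imports the Qwen3-Next-related packages installed above.

```bash
# Hypothetical file name; the contents are the Dockerfile shown above.
docker build -f Dockerfile.sgl056.exp -t verlai/verl:sgl056.exp .

# Quick sanity check of the Qwen3-Next dependencies inside the built image.
docker run --rm --gpus all verlai/verl:sgl056.exp python3 -c \
    "import causal_conv1d, fla, megatron.core, transformer_engine; print('qwen3next deps importable')"
```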
Lines changed: 69 additions & 0 deletions (new Dockerfile for the verlai/verl:vll012.exp image)
@@ -0,0 +1,69 @@
# Dockerfile for verlai/verl:vll012.exp
FROM nvcr.io/nvidia/pytorch:25.11-py3

RUN git clone -b v0.12.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm

RUN pip install setuptools_scm

RUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .

RUN pip install -r /opt/vllm/requirements/common.txt


RUN pip install pybind11

RUN export NVTE_FRAMEWORK=pytorch && MAX_JOBS=128 NVTE_BUILD_THREADS_PER_JOB=4 pip3 install --resume-retries 999 --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.11

RUN pip install --upgrade --no-cache-dir transformers tokenizers

RUN pip install codetiming tensordict mathruler pylatexenc qwen_vl_utils

RUN pip install flash_attn
#==2.8.1

RUN apt update && apt install -y numactl

RUN wget https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_6/nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb && \
    apt-get update && apt-get install -y libxcb-cursor0

RUN apt-get install -y ./nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb && \
    rm -rf /usr/local/cuda/bin/nsys && \
    ln -s /opt/nvidia/nsight-systems/2025.6.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \
    rm -rf /usr/local/cuda/bin/nsys-ui && \
    ln -s /opt/nvidia/nsight-systems/2025.6.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \
    rm nsight-systems-2025.6.1_2025.6.1.190-1_amd64.deb


# =========================
# Install HybridEP
# =========================
WORKDIR /home/
RUN git clone --branch hybrid-ep https://github.com/deepseek-ai/DeepEP.git && \
    cd DeepEP && git checkout 3f601f7ac1c062c46502646ff04c535013bfca00 && \
    TORCH_CUDA_ARCH_LIST="9.0;10.0" pip install --no-build-isolation .

# =========================
# Install Qwen3-Next dependencies
# =========================
WORKDIR /home/
# Install causal-conv1d and flash-linear-attention
RUN cd /tmp && \
    git clone https://github.com/Dao-AILab/causal-conv1d.git && \
    cd causal-conv1d && \
    unset PIP_CONSTRAINT && \
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --no-build-isolation . && \
    cd .. && \
    rm -rf causal-conv1d && \
    pip install flash-linear-attention

RUN pip3 install --no-cache-dir --no-deps trl

RUN pip3 install nvtx matplotlib liger_kernel

RUN pip install -U git+https://github.com/ISEEKYAN/mbridge.git

RUN pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@1d462bd37dac21cfa14177405d4921eedb987052  # latest dev branch on 20251209

RUN pip install git+https://github.com/volcengine/[email protected]

RUN pip uninstall -y verl
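
This image builds vLLM v0.12.0 from source on top of the NGC PyTorch 25.11 base. A hedged sketch of building it and confirming the editable vLLM install is picked up; the local Dockerfile name is again hypothetical.

```bash
docker build -f Dockerfile.vll012.exp -t verlai/verl:vll012.exp .   # hypothetical file name
docker run --rm --gpus all verlai/verl:vll012.exp \
    python3 -c "import vllm, torch; print(vllm.__version__, torch.__version__, torch.version.cuda)"
```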

recipe/dapo/test_dapo_gptoss_20b_megatron.sh

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@ set -xeuo pipefail
 ################################################### document for gptoss ###################################################
 
 ####################### running environment: #######################
-# option 1: use a pre-built docker image dedicated for gptoss: `docker://iseekyan/verl:nemo.gptoss_vllm0.11.0`, which is
-# built upon nemo's dedicated image, see Dockerfile at https://github.com/volcengine/verl/blob/main/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.vllm011.mcore_gpt-oss
+# option 1: use pre-built images verlai/verl:vll012.exp or verlai/verl:sgl056.exp
 #
 # option 2: self build TE>=2.8 with CUDNN>=9.13.1, megatron with branch `core_dev_r0.15.0`, latest vllm or sglang
 # you can modify the dockerfile to build the image, see Dockerfile at https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.vllm or https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.sglang
 
+
 ####################### before training: #######################
 # # install matched mbridge version
 # pip uninstall -y mbridge && pip install git+https://github.com/ISEEKYAN/mbridge@gpt-oss
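
For option 2 (self-build), the updated comments name version constraints rather than commands. One possible reading, hedged because the exact branches and tags may have moved and the CUDNN>=9.13.1 requirement still has to be satisfied by the base environment:

```bash
# Hedged sketch of option 2: versions are taken from the comments above, the commands are assumptions.
NVTE_FRAMEWORK=pytorch pip install --no-cache-dir --no-build-isolation \
    git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.11             # satisfies TE>=2.8
pip install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_dev_r0.15.0
pip install -U vllm                                                               # or a recent sglang
pip uninstall -y mbridge && pip install git+https://github.com/ISEEKYAN/mbridge@gpt-oss   # "before training" step above
```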
Lines changed: 232 additions & 0 deletions (new example script for Qwen3-Next Megatron training)
@@ -0,0 +1,232 @@
#!/usr/bin/env bash
set -xeuo pipefail


################################################### document for qwen3next ###################################################

####################### running environment: #######################

# option 1: use the pre-built docker images verlai/verl:vll012.exp or verlai/verl:sgl056.exp

# option 2: self-build TE>=2.8, Megatron from the dev branch, and megatron-bridge from the main branch

####################### how is qwen3next supported? #######################
# qwen3next is supported through megatron-bridge, which is enabled by setting `vanilla_mbridge=False`

####################### limitations: #######################
# 1. context parallel (CP) is not supported until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2614
# 2. sequence packing (aka thd) is not supported; `actor_rollout_ref.actor.megatron.use_remove_padding=False` must be set until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2644

## with sequence packing disabled, we recommend setting `use_dynamic_bsz=False` and a micro batch size of 1;
## otherwise the data is padded to the max length of the batch, which is inefficient. This is not mandatory, though.



################################################### quick config ###################################################

# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@dev  # install megatron from the dev branch
# pip install --no-deps git+https://github.com/NVIDIA-Nemo/Megatron-Bridge.git  # install megatron-bridge from the main branch

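# note: option 2 above also needs TE>=2.8; a hedged way to install it, mirroring the experimental
# Dockerfiles in this PR (the release tag below is simply the one those Dockerfiles pin):
# NVTE_FRAMEWORK=pytorch pip3 install --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.11
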
rollout_mode="async"
return_raw_chat="True"
export VLLM_USE_V1=1
rollout_name="vllm" # sglang or vllm
dtype="bfloat16"


project_name='DAPO-test'
exp_name='qwen3next'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

train_prompt_bsz=32
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-Next-80B-A3B-Instruct"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=False
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
offload=True
gen_tp=16
train_tp=2
EP=32
ETP=1
train_pp=1

################################################### start of config ###################################################

FP8=(
    # # train
    # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8="e4m3" # e4m3 or hybrid
    # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8_recipe="blockwise"
    # +actor_rollout_ref.actor.optim.override_optimizer_config.fp8_recipe="blockwise"
    # # rollout
    # +actor_rollout_ref.rollout.quantization="fp8"
)

DATA=(
    data.train_files="${TRAIN_FILE}"
    data.val_files="${TEST_FILE}"
    data.prompt_key=prompt
    data.return_raw_chat=$return_raw_chat
    data.truncation='left'
    data.max_prompt_length=${max_prompt_length}
    data.max_response_length=${max_response_length}
    data.train_batch_size=${train_prompt_bsz}
)

REWARD_MODEL=(
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False
    +reward_model.reward_kwargs.max_resp_len=${max_response_length}
    reward_model.reward_manager=dapo
)

PERF_OPT=(
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
    actor_rollout_ref.actor.megatron.use_remove_padding=False
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
)

ACTOR=(
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}
    actor_rollout_ref.actor.clip_ratio_c=10.0
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.optim.lr_warmup_steps=10
    actor_rollout_ref.actor.optim.weight_decay=0.1
    actor_rollout_ref.actor.optim.clip_grad=1.0
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
    actor_rollout_ref.actor.megatron.param_offload=${offload}
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload}
    actor_rollout_ref.actor.megatron.grad_offload=${offload}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=False
    actor_rollout_ref.model.use_remove_padding=False
)

ROLLOUT=(
    actor_rollout_ref.rollout.name=${rollout_name}
    actor_rollout_ref.rollout.mode=${rollout_mode}
    actor_rollout_ref.rollout.dtype=${dtype}
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length))
    actor_rollout_ref.rollout.temperature=${temperature}
    actor_rollout_ref.rollout.top_p=${top_p}
    actor_rollout_ref.rollout.top_k=${top_k}
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}
    actor_rollout_ref.rollout.val_kwargs.do_sample=True
    actor_rollout_ref.rollout.val_kwargs.n=1
    actor_rollout_ref.rollout.calculate_log_probs=True
    actor_rollout_ref.rollout.n=${n_resp_per_prompt}
)

TRAINER=(
    trainer.logger=['console','wandb']
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.n_gpus_per_node=8
    trainer.nnodes="${NNODES}"
    trainer.val_before_train=False
    trainer.test_freq=5
    trainer.save_freq=-1
    trainer.total_epochs=10
    trainer.default_local_dir="${CKPTS_DIR}"
    trainer.resume_mode=auto
    trainer.log_val_generations=10
)

FORWARD_ONLY_SETS=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
)

MODEL=(
    actor_rollout_ref.model.path="${MODEL_PATH}"
)

ALGORITHM=(
    algorithm.adv_estimator=${adv_estimator}
    algorithm.use_kl_in_reward=${use_kl_in_reward}
    algorithm.kl_ctrl.kl_coef=${kl_coef}
)
################################################### start script ###################################################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REWARD_MODEL[@]}" \
    "${FP8[@]}" \
    "${PERF_OPT[@]}" \
    "${TRAINER[@]}" \
    "${FORWARD_ONLY_SETS[@]}"
