#!/usr/bin/env bash
set -xeuo pipefail


# ################################################## document for qwen3next ###################################################

# ###################### running environment: #######################

# option 1: use the pre-built docker images verlai/verl:vll012.exp or verlai/verl:sgl056.exp

# option 2: build your own environment with TE>=2.8, Megatron-LM on the dev branch, and Megatron-Bridge on the main branch
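# #   for example (not part of the original script; adjust to your CUDA/PyTorch stack), TE can usually be installed with
# #     pip install --no-build-isolation "transformer_engine[pytorch]>=2.8"
# #   the Megatron-LM and Megatron-Bridge pip commands are listed under "quick config" below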

# ###################### how do we support qwen3next? #######################
# we support qwen3next via megatron-bridge, which is enabled by setting `vanilla_mbridge=False`
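# # concretely, the relevant overrides (set in the ACTOR section of this script) are:
# #   actor_rollout_ref.actor.megatron.use_mbridge=True
# #   actor_rollout_ref.actor.megatron.vanilla_mbridge=False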

# ###################### limitations: #######################
# 1. context parallel (CP) is not supported until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2614
# 2. sequence packing (aka thd) is not supported, so we must set `actor_rollout_ref.actor.megatron.use_remove_padding=False` until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2644

# # if sequence packing is disabled, we recommend setting `use_dynamic_bsz=False` and the micro batch size to 1,
# # otherwise the data will be padded to the max length of the batch, which is inefficient; this is not mandatory, though (see the sketch below)
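# # a sketch of the corresponding overrides (note: this script uses ppo_micro_batch_size_per_gpu=2 below; 1 is the conservative choice when padding to the batch max length):
# #   actor_rollout_ref.actor.megatron.use_remove_padding=False
# #   actor_rollout_ref.actor.use_dynamic_bsz=False
# #   actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1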



# ################################################## quick config ###################################################

# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@dev # install megatron from dev branch
# pip install --no-deps git+https://github.com/NVIDIA-Nemo/Megatron-Bridge.git # install megatron-bridge from main branch


rollout_mode="async"
return_raw_chat="True"
export VLLM_USE_V1=1
rollout_name="vllm" # sglang or vllm
dtype="bfloat16"


project_name='DAPO-test'
exp_name='qwen3next'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$(( 1024 * 2 ))
max_response_length=$(( 1024 * 8 ))
enable_overlong_buffer=True
overlong_buffer_len=$(( 1024 * 4 ))
overlong_penalty_factor=1.0
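# note (a description of DAPO's overlong-buffer reward shaping, not a setting): with these values, responses longer
# than max_response_length - overlong_buffer_len (8K - 4K = 4K tokens) receive a linearly increasing penalty,
# reaching overlong_penalty_factor at the full 8K length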

loss_agg_mode="token-mean"

train_prompt_bsz=32
n_resp_per_prompt=16
train_prompt_mini_bsz=32
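# 32 prompts x 16 responses per prompt = 512 rollouts per training step; since the mini batch equals the train
# batch, each step amounts to a single optimizer update over the full batch (accumulated across micro batches)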

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
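# NOTE: RAY_DATA_HOME is expected to be exported beforehand; with `set -u`, the script aborts if the defaults below are used while it is unset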
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-Next-80B-A3B-Instruct"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
use_dynamic_bsz=False
actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 10 / 10 ))
infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 ))
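# with the lengths above, max_prompt_length + max_response_length = 2048 + 8192 = 10240, so both per-GPU token
# budgets evaluate to 10240; the `* 10 / 10` and `* 1` factors are presumably left as knobs for scaling them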
offload=True
gen_tp=16
train_tp=2
EP=32
ETP=1
train_pp=1
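# a reading of the layout implied above (not stated in the original): with NNODES=4 x 8 GPUs = 32 GPUs, rollout
# uses TP=16 (32/16 = 2 rollout engine replicas), while training uses TP=2, PP=1, EP=32, ETP=1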

# ################################################## start of config ###################################################

FP8=(
    # # train
    # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8="e4m3" # e4m3 or hybrid
    # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8_recipe="blockwise"
    # +actor_rollout_ref.actor.optim.override_optimizer_config.fp8_recipe="blockwise"
    # # rollout
    # +actor_rollout_ref.rollout.quantization="fp8"
)

DATA=(
    data.train_files="${TRAIN_FILE}"
    data.val_files="${TEST_FILE}"
    data.prompt_key=prompt
    data.return_raw_chat=${return_raw_chat}
    data.truncation='left'
    data.max_prompt_length=${max_prompt_length}
    data.max_response_length=${max_response_length}
    data.train_batch_size=${train_prompt_bsz}
)

REWARD_MODEL=(
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False
    +reward_model.reward_kwargs.max_resp_len=${max_response_length}
    reward_model.reward_manager=dapo
)

PERF_OPT=(
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
    actor_rollout_ref.actor.megatron.use_remove_padding=False
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
)

ACTOR=(
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}
    actor_rollout_ref.actor.clip_ratio_c=10.0
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.optim.lr_warmup_steps=10
    actor_rollout_ref.actor.optim.weight_decay=0.1
    actor_rollout_ref.actor.optim.clip_grad=1.0
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
    actor_rollout_ref.actor.megatron.param_offload=${offload}
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload}
    actor_rollout_ref.actor.megatron.grad_offload=${offload}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=False
    actor_rollout_ref.model.use_remove_padding=False
)

ROLLOUT=(
    actor_rollout_ref.rollout.name=${rollout_name}
    actor_rollout_ref.rollout.mode=${rollout_mode}
    actor_rollout_ref.rollout.dtype=${dtype}
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.max_num_batched_tokens=$(( max_prompt_length + max_response_length ))
    actor_rollout_ref.rollout.temperature=${temperature}
    actor_rollout_ref.rollout.top_p=${top_p}
    actor_rollout_ref.rollout.top_k=${top_k}
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}
    actor_rollout_ref.rollout.val_kwargs.do_sample=True
    actor_rollout_ref.rollout.val_kwargs.n=1
    actor_rollout_ref.rollout.calculate_log_probs=True
    actor_rollout_ref.rollout.n=${n_resp_per_prompt}
)

TRAINER=(
    trainer.logger=['console','wandb']
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.n_gpus_per_node=8
    trainer.nnodes="${NNODES}"
    trainer.val_before_train=False
    trainer.test_freq=5
    trainer.save_freq=-1
    trainer.total_epochs=10
    trainer.default_local_dir="${CKPTS_DIR}"
    trainer.resume_mode=auto
    trainer.log_val_generations=10
)

FORWARD_ONLY_SETS=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
)

MODEL=(
    actor_rollout_ref.model.path="${MODEL_PATH}"
)

ALGORITHM=(
    algorithm.adv_estimator=${adv_estimator}
    algorithm.use_kl_in_reward=${use_kl_in_reward}
    algorithm.kl_ctrl.kl_coef=${kl_coef}
)
# ################################################## start script ###################################################

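# the Ray variables above (RAY_ADDRESS / WORKING_DIR / RUNTIME_ENV) are not consumed by the direct launch below;
# one possible way to use them instead (assuming a running Ray cluster; not part of the original) would be:
#   ray job submit --address="${RAY_ADDRESS}" \
#       --runtime-env="${RUNTIME_ENV}" \
#       --working-dir="${WORKING_DIR}" \
#       -- python3 -m verl.trainer.main_ppo ... # same arguments as below
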
python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REWARD_MODEL[@]}" \
    "${FP8[@]}" \
    "${PERF_OPT[@]}" \
    "${TRAINER[@]}" \
    "${FORWARD_ONLY_SETS[@]}"