Skip to content

Commit 84ffce6

Browse files
committed
readme
1 parent eb652bb commit 84ffce6

File tree

14 files changed

+2678
-215
lines changed

14 files changed

+2678
-215
lines changed

README.md

Lines changed: 125 additions & 215 deletions
Large diffs are not rendered by default.

figures/high_staleness.png

245 KB
Loading

figures/vcpo_multiturn.png

383 KB
Loading

figures/vcpo_results.png

1.11 MB
Loading
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
#!/usr/bin/env bash
#SBATCH --gpus-per-node=8
#SBATCH --cpus-per-task=128
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --output=./slurm/%A_%x.out
#SBATCH --error=./slurm/%A_%x.err
#SBATCH --job-name=vcpo

# Launcher for a synchronous VCPO (GRPO-style) run of fully_async_policy on
# GSM8k with Qwen2-1.5B. GPUs on each node are split between vLLM rollout
# workers and Megatron training workers. All knobs below can be overridden
# via the environment (MODEL_PATH, TRAIN_FILE, TEST_FILE, NNODES,
# NGPUS_PER_NODE, wandb_entity, wandb_group).

set -xeuo pipefail

export CUDA_DEVICE_MAX_CONNECTIONS=1
export RAY_DISABLE_IMPORT_WARNING=1
export VLLM_USE_V1=1
export RAY_ADDRESS="local"

# ================= Paths =================
MODEL_PATH=${MODEL_PATH:-"models/Qwen2-1.5B"}
TRAIN_FILE=${TRAIN_FILE:-"data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"data/gsm8k/test.parquet"}

project_name='vcpo'

# ================= GPU Layout =================
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=6
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))

# Fail fast if the rollout/training split leaves no GPUs for training
# (e.g. NGPUS_PER_NODE overridden below the hard-coded rollout count).
if (( n_gpus_training < 1 )); then
  printf 'error: NGPUS_PER_NODE=%s leaves %s GPUs for training (rollout uses %s)\n' \
    "${NGPUS_PER_NODE}" "${n_gpus_training}" "${n_gpus_rollout}" >&2
  exit 1
fi

# ================= Rollout =================
rollout_mode="async"
rollout_name="vllm"
return_raw_chat="True"
gen_tp=2
n_resp_per_prompt=8
gpu_memory_utilization=0.9
enable_chunked_prefill=False
calculate_log_probs=True

# ================= Sequence Lengths =================
max_prompt_length=2048
max_response_length=2048
max_num_batched_tokens=$((max_prompt_length + max_response_length))

# ================= Megatron Parallelism =================
train_tp=2
train_pp=1
train_cp=1
sequence_parallel=True
use_remove_padding=True
precision_dtype="bfloat16"

# ================= Batch Sizes =================
# NOTE(review): train_prompt_bsz=0 is passed straight to data.train_batch_size;
# presumably the fully-async trainer derives the effective batch size from the
# mini-batch settings below — confirm against the trainer config.
train_prompt_bsz=0
gen_prompt_bsz=1
train_prompt_mini_bsz=8
micro_bsz_per_gpu=1
use_dynamic_bsz=False
log_prob_micro_bsz_per_gpu=1

# ================= Algorithm =================
adv_estimator=grpo
loss_agg_mode="seq-mean-token-mean"
# clip_ratio_low=1.0 together with the huge high/c thresholds effectively
# disables PPO clipping on the upper side.
clip_ratio_low=1.0
clip_ratio_high=1e9
clip_ratio_c=1e9
use_kl_loss=False
kl_loss_coef=0.0
use_kl_in_reward=False
kl_coef=0.0
entropy_coeff=0
grad_clip=1.0

# ================= Optimizer =================
lr=1e-6
lr_warmup_steps=0
weight_decay=0.1

# ================= IS / Rollout Correction =================
rollout_is="sequence"
rollout_is_threshold="8.0"
rollout_rs=null
rollout_rs_threshold=null

# ================= Synchronous Training =================
staleness_threshold=0.0
updates_per_param_sync=1
num_minibatches_per_update=1
partial_rollout=False
use_rollout_log_probs=True

# Set to True to view per-trajectory gradient statistics
update_policy_per_traj=False

# ================= Training/Rollout Steps =================
total_rollout_steps=$((500 * num_minibatches_per_update * updates_per_param_sync * train_prompt_mini_bsz))
epochs=10000000
test_freq=10
save_freq=-1

# ================= Logging =================
exp_name="Synchronous GSM8k Qwen2-1.5B ${n_gpus_rollout}-${n_gpus_training} ${loss_agg_mode} ${max_response_length}-len ${weight_decay}-wd"
# Replace any '/' in the experiment name so it is safe as a directory name.
exp_name_safe=${exp_name//\//_}
log_dir="logs/${exp_name_safe}"
CKPTS_DIR="${log_dir}"
# Fix: run_log_file was referenced by the final `tee` but never defined,
# which aborts the script under `set -u`. Define it and make sure the
# directories it lives in exist before anything writes to them.
run_log_file="${log_dir}/run.log"
mkdir -p -- "${log_dir}" ./slurm

trainer_logger="['console','wandb']"
log_val_generations=0
wandb_entity=${wandb_entity:-""}
wandb_group=${wandb_group:-"vcpo-release"}
val_before_train=False

# ================= LR decay =================
lr_decay_style="constant"
lr_decay_steps=${total_rollout_steps}

# ================= Run =================
python -m recipe.fully_async_policy.fully_async_main \
    --config-name=fully_async_ppo_megatron_trainer.yaml \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.return_raw_chat=${return_raw_chat} \
    data.filter_overlong_prompts=True \
    data.filter_overlong_prompts_workers=8 \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    algorithm.rollout_correction.rollout_is=${rollout_is} \
    algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \
    algorithm.rollout_correction.rollout_rs=${rollout_rs} \
    algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \
    actor_rollout_ref.actor.strategy=megatron \
    critic.strategy=megatron \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=${clip_ratio_c} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_bsz_per_gpu} \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.actor.megatron.sequence_parallel=${sequence_parallel} \
    actor_rollout_ref.actor.megatron.dtype=${precision_dtype} \
    actor_rollout_ref.actor.megatron.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.actor.megatron.param_offload=False \
    actor_rollout_ref.actor.megatron.optimizer_offload=False \
    actor_rollout_ref.actor.megatron.grad_offload=False \
    actor_rollout_ref.actor.optim.lr=${lr} \
    actor_rollout_ref.actor.optim.lr_warmup_steps=${lr_warmup_steps} \
    actor_rollout_ref.actor.optim.lr_decay_style=${lr_decay_style} \
    actor_rollout_ref.actor.optim.lr_decay_steps=${lr_decay_steps} \
    actor_rollout_ref.actor.optim.weight_decay=${weight_decay} \
    actor_rollout_ref.actor.optim.clip_grad=${grad_clip} \
    actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.use_rollout_log_probs=${use_rollout_log_probs} \
    actor_rollout_ref.actor.update_policy_per_traj=${update_policy_per_traj} \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.ref.megatron.sequence_parallel=${sequence_parallel} \
    actor_rollout_ref.ref.megatron.dtype=${precision_dtype} \
    actor_rollout_ref.ref.megatron.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.ref.megatron.param_offload=True \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.name=${rollout_name} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.dtype=${precision_dtype} \
    actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \
    actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_p=1.0 \
    actor_rollout_ref.rollout.top_k=-1 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
    actor_rollout_ref.rollout.val_kwargs.top_k=-1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=3 \
    actor_rollout_ref.rollout.calculate_log_probs=${calculate_log_probs} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    critic.megatron.tensor_model_parallel_size=${train_tp} \
    critic.megatron.pipeline_model_parallel_size=${train_pp} \
    critic.megatron.context_parallel_size=${train_cp} \
    critic.megatron.sequence_parallel=${sequence_parallel} \
    critic.megatron.dtype=${precision_dtype} \
    trainer.logger=${trainer_logger} \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    +trainer.wandb_entity="${wandb_entity}" \
    +trainer.wandb_group="${wandb_group}" \
    trainer.val_before_train=${val_before_train} \
    trainer.save_freq=${save_freq} \
    trainer.rollout_data_dir="${log_dir}" \
    trainer.log_val_generations=${log_val_generations} \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.nnodes="${NNODES}" \
    trainer.n_gpus_per_node="${n_gpus_training}" \
    rollout.nnodes="${NNODES}" \
    rollout.n_gpus_per_node="${n_gpus_rollout}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.total_epochs="${epochs}" \
    rollout.test_freq="${test_freq}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${updates_per_param_sync}" \
    async_training.require_batches="${num_minibatches_per_update}" \
    async_training.partial_rollout="${partial_rollout}" \
    async_training.compute_prox_log_prob=True \
    async_training.use_rollout_log_probs="${use_rollout_log_probs}" \
    2>&1 | tee -a "${run_log_file}"

0 commit comments

Comments
 (0)