|
#!/usr/bin/env bash
# SLURM launcher for synchronous VCPO (fully-async-policy recipe) GRPO training
# on GSM8K with Qwen2-1.5B. One exclusive 8-GPU node is split between vLLM
# rollout and Megatron training (see "GPU Layout" below).
#
# Submit with: sbatch <this_script>
# NOTE(review): sbatch does not create directories for %A_%x.out/.err — the
# ./slurm directory must exist before submission or job output will fail.
#SBATCH --gpus-per-node=8
#SBATCH --cpus-per-task=128
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --output=./slurm/%A_%x.out
#SBATCH --error=./slurm/%A_%x.err
#SBATCH --job-name=vcpo
# Fail fast and loudly: trace every command (-x), abort on command failure
# (-e), on use of unset variables (-u), and on failures anywhere in a
# pipeline (pipefail).
set -x
set -euo pipefail

# Runtime knobs for CUDA, Ray and vLLM (semantics owned by those tools).
CUDA_DEVICE_MAX_CONNECTIONS=1
RAY_DISABLE_IMPORT_WARNING=1
VLLM_USE_V1=1
RAY_ADDRESS=local
export CUDA_DEVICE_MAX_CONNECTIONS RAY_DISABLE_IMPORT_WARNING VLLM_USE_V1 RAY_ADDRESS
# ================= Paths =================
# All three are overridable from the environment.
MODEL_PATH=${MODEL_PATH:-"models/Qwen2-1.5B"}
TRAIN_FILE=${TRAIN_FILE:-"data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"data/gsm8k/test.parquet"}

project_name='vcpo'

# ================= GPU Layout =================
# The node's GPUs are statically partitioned: n_gpus_rollout go to vLLM
# generation, the remainder to Megatron training.
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=6
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))

# Fail early with a clear message instead of handing a zero/negative GPU
# count to the trainer (happens whenever NGPUS_PER_NODE is overridden to a
# value <= n_gpus_rollout).
if (( n_gpus_training < 1 )); then
  echo "ERROR: n_gpus_training=${n_gpus_training}; NGPUS_PER_NODE (${NGPUS_PER_NODE}) must exceed n_gpus_rollout (${n_gpus_rollout})" >&2
  exit 1
fi
# ================= Rollout =================
rollout_mode="async"            # async server-style rollout (not the sync hybrid engine)
rollout_name="vllm"             # inference backend
return_raw_chat="True"
gen_tp=2                        # vLLM tensor-parallel size during generation
n_resp_per_prompt=8             # responses sampled per prompt (fed to rollout.n)
gpu_memory_utilization=0.9     # fraction of GPU memory vLLM may reserve
enable_chunked_prefill=False
calculate_log_probs=True        # ask the rollout engine to return token log-probs

# ================= Sequence Lengths =================
max_prompt_length=2048
max_response_length=2048
# Batched-token cap sized to exactly one full-length prompt + response.
max_num_batched_tokens=$((max_prompt_length + max_response_length))

# ================= Megatron Parallelism =================
train_tp=2                      # tensor-model parallel size
train_pp=1                      # pipeline-model parallel size
train_cp=1                      # context parallel size
sequence_parallel=True
use_remove_padding=True
precision_dtype="bfloat16"      # shared by actor, ref, critic and rollout below
# ================= Batch Sizes =================
# NOTE(review): train_batch_size=0 with gen_batch_size=1 is presumably
# intentional for the fully-async recipe (training batches accumulate from
# the rollout stream rather than fixed prompt batches) — confirm against
# fully_async_ppo_megatron_trainer.yaml semantics.
train_prompt_bsz=0
gen_prompt_bsz=1
train_prompt_mini_bsz=8         # prompts per PPO mini-batch
micro_bsz_per_gpu=1
use_dynamic_bsz=False
log_prob_micro_bsz_per_gpu=1

# ================= Algorithm =================
adv_estimator=grpo
loss_agg_mode="seq-mean-token-mean"
# Extreme clip thresholds (1e9) — looks like PPO clipping is effectively
# disabled on the high/dual side; confirm against the trainer's clip semantics.
clip_ratio_low=1.0
clip_ratio_high=1e9
clip_ratio_c=1e9
use_kl_loss=False               # no KL term in the actor loss ...
kl_loss_coef=0.0
use_kl_in_reward=False          # ... and none folded into the reward
kl_coef=0.0
entropy_coeff=0
grad_clip=1.0

# ================= Optimizer =================
lr=1e-6
lr_warmup_steps=0
weight_decay=0.1

# ================= IS / Rollout Correction =================
# Sequence-level importance sampling with ratio threshold 8.0; rejection
# sampling disabled (null is passed through to the Hydra config).
rollout_is="sequence"
rollout_is_threshold="8.0"
rollout_rs=null
rollout_rs_threshold=null

# ================= Synchronous Training =================
# Zero staleness + one update per parameter sync + one minibatch per update:
# the async pipeline is configured to behave synchronously.
staleness_threshold=0.0
updates_per_param_sync=1
num_minibatches_per_update=1
partial_rollout=False
use_rollout_log_probs=True

# Set to True to view per-trajectory gradient statistics
update_policy_per_traj=False

# ================= Training/Rollout Steps =================
# 500 parameter syncs' worth of prompts (mini-batch prompts per minibatch,
# minibatches per update, updates per sync).
total_rollout_steps=$((500 * num_minibatches_per_update * updates_per_param_sync * train_prompt_mini_bsz))
epochs=10000000                 # effectively unbounded; run length is governed by rollout steps
test_freq=10
save_freq=-1                    # -1: presumably disables periodic checkpoint saving — confirm
# ================= Logging =================
exp_name="Synchronous GSM8k Qwen2-1.5B ${n_gpus_rollout}-${n_gpus_training} ${loss_agg_mode} ${max_response_length}-len ${weight_decay}-wd"
# Strip '/' so the experiment name is safe as a directory component.
exp_name_safe=${exp_name//\//_}
log_dir="logs/${exp_name_safe}"
CKPTS_DIR="${log_dir}"          # checkpoints share the log directory

trainer_logger="['console','wandb']"
log_val_generations=0
wandb_entity=${wandb_entity:-""}
wandb_group=${wandb_group:-"vcpo-release"}
val_before_train=False

# ================= LR decay =================
# Constant LR schedule; decay_steps is still passed for completeness.
lr_decay_style="constant"
lr_decay_steps=${total_rollout_steps}
# ================= Run =================
# BUGFIX: run_log_file was referenced by the final `tee` but never defined,
# so under `set -u` the whole launch pipeline aborted with an "unbound
# variable" error. Define it (and create log_dir) before use.
mkdir -p "${log_dir}"
run_log_file="${log_dir}/run.log"

# Launch the fully-async PPO/GRPO trainer; all config values come from the
# variables defined above, passed as Hydra overrides. Output is both shown
# on the console and appended to the run log.
python -m recipe.fully_async_policy.fully_async_main \
    --config-name=fully_async_ppo_megatron_trainer.yaml \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.return_raw_chat=${return_raw_chat} \
    data.filter_overlong_prompts=True \
    data.filter_overlong_prompts_workers=8 \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    algorithm.rollout_correction.rollout_is=${rollout_is} \
    algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \
    algorithm.rollout_correction.rollout_rs=${rollout_rs} \
    algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \
    actor_rollout_ref.actor.strategy=megatron \
    critic.strategy=megatron \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=${clip_ratio_c} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_bsz_per_gpu} \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.actor.megatron.sequence_parallel=${sequence_parallel} \
    actor_rollout_ref.actor.megatron.dtype=${precision_dtype} \
    actor_rollout_ref.actor.megatron.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.actor.megatron.param_offload=False \
    actor_rollout_ref.actor.megatron.optimizer_offload=False \
    actor_rollout_ref.actor.megatron.grad_offload=False \
    actor_rollout_ref.actor.optim.lr=${lr} \
    actor_rollout_ref.actor.optim.lr_warmup_steps=${lr_warmup_steps} \
    actor_rollout_ref.actor.optim.lr_decay_style=${lr_decay_style} \
    actor_rollout_ref.actor.optim.lr_decay_steps=${lr_decay_steps} \
    actor_rollout_ref.actor.optim.weight_decay=${weight_decay} \
    actor_rollout_ref.actor.optim.clip_grad=${grad_clip} \
    actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.use_rollout_log_probs=${use_rollout_log_probs} \
    actor_rollout_ref.actor.update_policy_per_traj=${update_policy_per_traj} \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.ref.megatron.sequence_parallel=${sequence_parallel} \
    actor_rollout_ref.ref.megatron.dtype=${precision_dtype} \
    actor_rollout_ref.ref.megatron.use_remove_padding=${use_remove_padding} \
    actor_rollout_ref.ref.megatron.param_offload=True \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.name=${rollout_name} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.dtype=${precision_dtype} \
    actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \
    actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_p=1.0 \
    actor_rollout_ref.rollout.top_k=-1 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
    actor_rollout_ref.rollout.val_kwargs.top_k=-1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=3 \
    actor_rollout_ref.rollout.calculate_log_probs=${calculate_log_probs} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    critic.megatron.tensor_model_parallel_size=${train_tp} \
    critic.megatron.pipeline_model_parallel_size=${train_pp} \
    critic.megatron.context_parallel_size=${train_cp} \
    critic.megatron.sequence_parallel=${sequence_parallel} \
    critic.megatron.dtype=${precision_dtype} \
    trainer.logger=${trainer_logger} \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    +trainer.wandb_entity="${wandb_entity}" \
    +trainer.wandb_group="${wandb_group}" \
    trainer.val_before_train=${val_before_train} \
    trainer.save_freq=${save_freq} \
    trainer.rollout_data_dir="${log_dir}" \
    trainer.log_val_generations=${log_val_generations} \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.nnodes="${NNODES}" \
    trainer.n_gpus_per_node="${n_gpus_training}" \
    rollout.nnodes="${NNODES}" \
    rollout.n_gpus_per_node="${n_gpus_rollout}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.total_epochs="${epochs}" \
    rollout.test_freq="${test_freq}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${updates_per_param_sync}" \
    async_training.require_batches="${num_minibatches_per_update}" \
    async_training.partial_rollout="${partial_rollout}" \
    async_training.compute_prox_log_prob=True \
    async_training.use_rollout_log_probs="${use_rollout_log_probs}" \
    2>&1 | tee -a "${run_log_file}"
0 commit comments