# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)

> Open-Source Algorithm Implementation & Experiment Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)

🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)

> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. A Qwen2.5-32B base model trained with DAPO outperforms the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** fewer training steps.
>
> 

## Quickstart

1. Prepare the datasets **on the Ray cluster**:

```bash
bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default
```

2. Submit the job to the Ray cluster **from any machine**:

```bash
cd verl # Repo root
export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to
export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster
# The YAML file that sets the runtime environment (e.g., env vars and pip packages) for the Ray cluster
export RUNTIME_ENV="./verl/trainer/runtime_env.yaml"
bash recipe/dapo/run_dapo_qwen2.5_32b.sh
```

## Reproduction Runs

| Setup | AIME 2024 Acc. | Training Script | Training Record |
| --- | --- | --- | --- |
| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |
| DAPO w/o Dynamic Sampling | 50% | [run_dapo_wo_ds_qwen2.5_32b.sh](./run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |
| DAPO | 52% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |

## Configuration

> [!NOTE]
> Most experiments in the paper, including the best-performing one, were run without Overlong Filtering, because its purpose largely overlaps with Overlong Reward Shaping (both control how the longest outputs are learned from). So we don't implement it here.

### Separated Clip Epsilons (-> Clip-Higher)

An example configuration:

```yaml
actor_rollout_ref:
  actor:
    clip_ratio_low: 0.2
    clip_ratio_high: 0.28
```

`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text{low}}$ and $\varepsilon_{\text{high}}$ in the DAPO objective.

Core relevant code:

```python
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
```
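
For reference, since `pg_losses` above is the negative of the clipped surrogate, the per-token objective being maximized is (writing `ratio` as $r_{i,t}$ and `advantages` as $\hat{A}_{i,t}$; the notation here is ours):

$$
\min\Big(r_{i,t}\,\hat{A}_{i,t},\ \operatorname{clip}\big(r_{i,t},\, 1-\varepsilon_{\text{low}},\, 1+\varepsilon_{\text{high}}\big)\,\hat{A}_{i,t}\Big)
$$

Decoupling the two epsilons and raising $\varepsilon_{\text{high}}$ above $\varepsilon_{\text{low}}$ (Clip-Higher) loosens only the upper clipping bound, giving low-probability tokens more room for their probabilities to increase, which the paper argues helps mitigate entropy collapse.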

### Dynamic Sampling (with Group Filtering)

An example configuration:

```yaml
data:
  gen_batch_size: 1536
  train_batch_size: 512
algorithm:
  filter_groups:
    enable: True
    metric: acc # score / seq_reward / seq_final_reward / ...
    max_num_gen_batches: 10 # Non-positive values mean no upper limit
```

Setting `filter_groups.enable` to `True` will filter out the groups in which all outputs have the same `metric` value, e.g., for `acc`, groups whose outputs are all correct (accuracy 1) or all wrong (accuracy 0).

The trainer keeps sampling additional batches of `gen_batch_size` prompts until enough qualifying groups have accumulated to fill `train_batch_size`, or until the upper limit specified by `max_num_gen_batches` is reached.

Core relevant code:

```python
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')
    num_gen_batches += 1
    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')
        continue
    else:
        raise ValueError(
            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'
        )
else:
    # Align the batch
    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
    batch = batch[:traj_bsz]
```
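
For intuition, here is a minimal, self-contained sketch of the filtering criterion itself (the function name and flat-list inputs are hypothetical; in the recipe this logic runs inside the trainer over the rollout batch, grouped by prompt):

```python
from collections import defaultdict

def keep_informative_groups(prompt_ids, metric_values):
    """Return the prompt IDs whose rollout group has a non-constant metric.

    prompt_ids:    per-rollout prompt identifier, e.g. ["p0", "p0", "p1", "p1"]
    metric_values: per-rollout metric (e.g. `acc` in {0, 1}), same length
    """
    groups = defaultdict(list)
    for pid, val in zip(prompt_ids, metric_values):
        groups[pid].append(val)
    # A group where every rollout gets the same value (all correct or all wrong
    # for `acc`) yields zero advantage under a group-relative baseline, so drop it.
    return {pid for pid, vals in groups.items() if len(set(vals)) > 1}

# "p0" is solved by every rollout and gets filtered out; "p1" is kept.
assert keep_informative_groups(["p0", "p0", "p1", "p1"], [1, 1, 1, 0]) == {"p1"}
```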

### Flexible Loss Aggregation Mode (-> Token-level Loss)

An example configuration:

```yaml
actor_rollout_ref:
  actor:
    loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
    # NOTE: "token-mean" is the default behavior
```

Setting `loss_agg_mode` to `token-mean` will average the (policy gradient) loss over all tokens of all sequences in a mini-batch.

Core relevant code:

```python
if loss_agg_mode == "token-mean":
    loss = verl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
    loss = torch.mean(seq_losses)  # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean
    loss = torch.mean(seq_losses)  # seq-mean
else:
    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
```
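
To make the difference concrete, here is a small self-contained comparison on a toy batch (it re-implements the aggregation with plain tensor ops instead of `verl_F.masked_mean`, purely for illustration):

```python
import torch

# Two sequences: one with 4 valid tokens, one with a single valid token.
loss_mat = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                         [4.0, 0.0, 0.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 0.0, 0.0, 0.0]])

# "token-mean": every valid token is weighted equally, so longer sequences
# contribute proportionally more to the mini-batch loss.
token_mean = (loss_mat * loss_mask).sum() / loss_mask.sum()      # (4 + 4) / 5 = 1.6

# "seq-mean-token-mean": every sequence is weighted equally regardless of length.
seq_means = (loss_mat * loss_mask).sum(-1) / loss_mask.sum(-1)   # [1.0, 4.0]
seq_mean_token_mean = seq_means.mean()                           # 2.5

print(token_mean.item(), seq_mean_token_mean.item())
```

This token-level weighting is the point of DAPO's token-level loss: with long chain-of-thought rollouts, sample-level averaging down-weights each token in long responses, whereas `token-mean` keeps every generated token equally influential.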

### Overlong Reward Shaping

An example configuration:

```yaml
data:
  max_response_length: 20480 # 16384 + 4096
reward_model:
  overlong_buffer:
    enable: True
    len: 4096
    penalty_factor: 1.0
```

Setting `overlong_buffer.enable` to `True` will penalize outputs that are overlong but still within the hard context limit.

Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` as the output length exceeds `max_response_length - overlong_buffer.len` by `0` to `overlong_buffer.len` tokens.

Core relevant code:

```python
if self.overlong_buffer_cfg.enable:
    overlong_buffer_len = self.overlong_buffer_cfg.len
    expected_len = self.max_resp_len - overlong_buffer_len
    exceed_len = valid_response_length - expected_len
    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
    overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
    reward += overlong_reward
```
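
As a sanity check, the same arithmetic as a standalone function, evaluated with the example configuration above (`max_response_length: 20480`, `len: 4096`, `penalty_factor: 1.0`); the function name and defaults are just for illustration:

```python
def overlong_penalty(valid_response_length: int,
                     max_resp_len: int = 20480,
                     buffer_len: int = 4096,
                     penalty_factor: float = 1.0) -> float:
    """Linear penalty for responses that run into the overlong buffer."""
    expected_len = max_resp_len - buffer_len            # 16384: longest unpenalized response
    exceed_len = valid_response_length - expected_len   # how far into the buffer we are
    return min(-exceed_len / buffer_len * penalty_factor, 0.0)

print(overlong_penalty(16000))  #  0.0  -> within the expected length, no penalty
print(overlong_penalty(18432))  # -0.5  -> halfway through the 4096-token buffer
print(overlong_penalty(20480))  # -1.0  -> at the hard limit, full penalty applied
```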