# Rollout Correction Methods

Rollout correction (e.g., TIS, MIS) through algorithmic methods.


## Quick Takeaway

These methods handle off-policy (offline) rollout scenarios through algorithmic adaptations such as TIS/MIS.

We include three rollout correction algorithms:

1. decoupled, 3-policy PPO with rollout importance sampling
2. direct rollout policy overwriting in standard PPO
3. pure REINFORCE loss (without PPO clipping) with rollout importance sampling


`--use-tis`: use this flag to **turn on TIS/MIS** for rollout correction (details in **Algorithms**).
You may specify the **IS/RS configs** with a config file using `--custom-config-path`.

`--use-rollout-logprobs`: when this flag is set, log probs will **not** be recomputed by the training engine; the rollout log probs are used directly in the PPO/GRPO loss.

`--get-mismatch-metrics`: use this flag when you do not want to apply TIS/MIS but still want to monitor mismatch-related metrics (e.g., the rollout-training KL). It **only returns mismatch metrics** and does not change the loss in any way.


## Algorithms

We give examples of algorithms that address the training-inference mismatch issue.

### [Baseline: No Mismatch Correction] Standard PPO

This is the basic PPO algorithm, which can suffer from the training-inference mismatch issue when the outputs of SGLang and Megatron do not exactly match.

$$
L_{\text{PPO}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
  \min \left(
    \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{Megatron}}}(y \mid x)} A_t,
    \mathrm{clip}\left(
      \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{Megatron}}}(y \mid x)},
      1 - \epsilon,
      1 + \epsilon
    \right) A_t
  \right)
\right].
$$
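As a concrete reference, the clipped objective above can be sketched at the token level over arrays of log probs. This is an illustrative sketch, not the framework's actual implementation; the function name, arguments, and defaults are hypothetical.

```python
import numpy as np

def ppo_clip_loss(logp_theta, logp_old, advantages, eps=0.2):
    """Token-level PPO clipped surrogate (negated, so lower is better).

    logp_theta: log probs of sampled tokens under the current policy pi_theta
    logp_old:   log probs recomputed by the training engine (pi_Megatron here)
    advantages: per-token advantage estimates A_t
    """
    ratio = np.exp(logp_theta - logp_old)           # pi_theta / pi_old
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)  # clip(ratio, 1-eps, 1+eps)
    # pessimistic surrogate: elementwise min, then negate for minimization
    return -np.mean(np.minimum(ratio * advantages, clipped * advantages))
```

When `logp_theta` equals `logp_old` (the first minibatch of a step), the ratio is 1 and the loss reduces to the plain policy-gradient surrogate.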

### Bypassing PPO Importance Sampling

Like REINFORCE, we directly use the rollout engine's log probs as the old policy in offline PPO's importance sampling, rather than the log probs recomputed by the training engine.

$$
L_{\text{PPO-bypass}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
  \min \left(
    \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)} A_t,
    \mathrm{clip}\left(
      \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)},
      1 - \epsilon,
      1 + \epsilon
    \right) A_t
  \right)
\right].
$$

Advantages:

- Efficiency: skips `log_prob` recomputation on the training engine, saving one expensive forward pass over all generated trajectories.

### Decoupled, 3-policy PPO Importance Sampling

[Decoupled PPO](https://arxiv.org/pdf/2110.00641) achieves batch-size-independent PPO by decoupling two roles: the proximal policy (the anchor policy for PPO clipping, which controls the update size) and the behavior policy (used for off-policy correction in importance sampling). In total, three policies are involved in this mode: the **target policy** $\pi_\theta$, the **proximal policy** $\pi_{\textcolor{blue}{\text{old}}}$, and the **behavior policy** $\pi_{\textcolor{red}{\text{SGLang}}}$. $\pi_{\textcolor{blue}{\text{old}}}$ is recomputed with Megatron at the beginning of each training step.

$$
L_{\text{PPO-decoupled}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
  \frac{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)}
  \min \left(
    \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)} A_t,
    \mathrm{clip}\left(
      \frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)},
      1 - \epsilon,
      1 + \epsilon
    \right) A_t
  \right)
\right].
$$

Advantages:

- Achieves batch size invariance and efficient stale data utilization
- Enables accurate off-policy metrics monitoring
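The three-policy objective above can be sketched the same way, with the behavior-correction weight applied outside the clipped term. The truncation bound `is_cap` is a hypothetical stand-in for the TIS bounds configured in **Configs and Recommended Settings**; all names here are illustrative, not the framework's API.

```python
import numpy as np

def decoupled_ppo_loss(logp_theta, logp_old, logp_rollout, advantages,
                       eps=0.2, is_cap=2.0):
    """Decoupled, 3-policy PPO loss sketch.

    logp_theta:   target policy pi_theta
    logp_old:     proximal policy (recomputed by the training engine)
    logp_rollout: behavior policy (log probs reported by the rollout engine)
    """
    # off-policy correction pi_old / pi_SGLang, truncated for stability (TIS)
    behavior_w = np.minimum(np.exp(logp_old - logp_rollout), is_cap)
    # PPO clipping is anchored at the proximal policy, not the behavior policy
    ratio = np.exp(logp_theta - logp_old)
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)
    surrogate = np.minimum(ratio * advantages, clipped * advantages)
    return -np.mean(behavior_w * surrogate)
```

Note that clipping and off-policy correction use different denominators: the clip ratio controls update size relative to $\pi_{\text{old}}$, while the behavior weight corrects for the rollout engine's distribution.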

## APIs of Algorithms

You may choose among the algorithms above by specifying the arguments below:

`--use-rollout-log-probs`: if true, only `rollout_log_probs` are used to compute the loss, bypassing the `old_log_probs` calculated by the training engine;

`--use-rollout-correction`: if true, importance sampling/rejection sampling is applied to the loss.

| `use_rollout_log_probs` | `use_rollout_correction` | Algorithm | Policies | Compute old_log_probs | Batch Invariant | Recommended TIS Mode |
|-----------------|-------------|-----------|--------------|---------------|-----------------|----------------------|
| False | False | Standard PPO (Algorithm 0) | 2 ($\pi_\theta$, $\pi_{\textcolor{blue}{\text{old}}}$) | Yes | No | N/A |
| True | False | Bypassing PPO (Algorithm 3) | 2 ($\pi_\theta$, $\pi_{\textcolor{red}{\text{SGLang}}}$) | 🚀 Skipped | No | N/A |
| False | True | Decoupled PPO (Algorithm 2) | 3 ($\pi_\theta$, $\pi_{\textcolor{blue}{\text{old}}}$, $\pi_{\textcolor{red}{\text{SGLang}}}$) | Yes | Yes | token/seq/geo |

## Configs and Recommended Settings

When using importance sampling or rejection sampling for mismatch correction (`--use-rollout-correction` enabled; Algorithms 2 & 3), you may specify the IS modes and the levels at which they are applied.

### Arguments

`--use-tis`: Enable importance sampling. The IS weight will be multiplied into the policy gradient loss.

- `--tis-mode`: Mode for IS. Allowed modes: **truncate**, **clip**.
- `--tis-lower-bound`, `--tis-upper-bound`: Bounds for IS weights.
- `--tis-level`: Allowed levels: **token**, **sequence**, **geometric**. See explanations below.
- `--tis-batch-normalize`: Normalize IS weights to mean = 1.0 across the batch.


`--use-rs`: Enable rejection sampling. Tokens/sequences with an IS weight outside the threshold will be directly masked, and the rejected tokens/sequences are not considered for loss averaging.

- `--rs-lower-bound`, `--rs-upper-bound`: Bounds for RS weights.
- `--rs-level`: Allowed levels: **token**, **sequence**, **geometric**. See explanations below.
- `--rs-veto-threshold`: Sequence-level rejection threshold for catastrophic mismatches.

### Importance Sampling

For both IS and RS, we provide three levels: **token**, **sequence**, **geometric**.

**Token Level (default)**:

Computes importance weights independently for each token:
$w_i = \exp\left(\log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i)\right)$

Characteristics: Biased but computationally simple, suitable for most scenarios

**Sequence Level**:

Uses the product of all token weights as the sequence weight:
$w_{\text{seq}} = \exp\left( \sum_i \left( \log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i) \right) \right)$

Characteristics: Unbiased but high variance, suitable for sequence-level optimization

**Geometric Level**:

Uses the geometric mean to compute the sequence weight:
$w_{\text{seq}} = \exp\left( \frac{1}{n} \sum_{i=1}^{n} \left( \log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i) \right) \right)$

Characteristics: Biased but low variance, balances bias and variance
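The three weighting levels follow directly from the formulas above. A minimal sketch for a single sequence of per-token log probs; the function and argument names are illustrative, not the framework's API.

```python
import numpy as np

def is_weights(logp_train, logp_rollout, level="token"):
    """IS weights for one sequence at the token/sequence/geometric level.

    Returns a per-token array for "token", a scalar for the other two levels.
    """
    log_ratio = logp_train - logp_rollout
    if level == "token":        # w_i = exp(log-ratio_i): biased, simple
        return np.exp(log_ratio)
    if level == "sequence":     # product of token ratios: unbiased, high variance
        return np.exp(log_ratio.sum())
    if level == "geometric":    # geometric mean of ratios: biased, low variance
        return np.exp(log_ratio.mean())
    raise ValueError(f"unknown level: {level}")
```

Working in log space and exponentiating once avoids the numerical overflow that multiplying many raw ratios would cause at the sequence level.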

### Rejection Sampling

**Token Level**: Reject tokens whose IS weight falls outside the threshold

**Sequence Level**: Reject sequences whose mean IS weight falls outside the threshold

**Geometric Level**: Reject sequences whose geometric-mean IS weight falls outside the threshold

- Extremely selective: requires a near-perfect policy match
- High rejection rate: only suitable for very slight distribution shifts

**Veto Mechanism**:

The veto mechanism rejects sequences that contain tokens with catastrophically low probability.
Reject the entire sequence if $\exists t \in T$ such that $\rho_t < C_{\text{veto}}$, where $\rho_t$ is the token-level importance ratio.

- Prevents catastrophic updates from tokens with near-zero probability under $\pi_{\text{old}}$
- Independent of IS/RS settings

*Typical values: $10^{-4}$ to $10^{-6}$*
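Token-level rejection plus the sequence-level veto can be sketched as a boolean mask over one sequence. The bound and threshold values here are illustrative, not the framework defaults, and the function name is hypothetical.

```python
import numpy as np

def rs_token_mask(logp_train, logp_rollout,
                  lower=0.5, upper=2.0, veto_threshold=1e-5):
    """Token-level rejection mask with a sequence-level veto.

    Returns a boolean array; True means the token is kept for loss averaging.
    """
    ratio = np.exp(logp_train - logp_rollout)     # per-token IS weight
    keep = (ratio >= lower) & (ratio <= upper)    # mask out-of-bound tokens
    # veto: one catastrophically low ratio rejects the whole sequence
    if np.any(ratio < veto_threshold):
        keep[:] = False
    return keep
```

Rejected tokens should also be excluded from the loss denominator, so the average is taken only over kept tokens.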

## Mismatch Metrics

When rollout log probabilities are available, SLIME automatically tracks comprehensive metrics to monitor training-inference mismatch and importance sampling weights. These metrics help diagnose policy divergence and guide hyperparameter tuning.

### Mismatch Monitoring Metrics

These metrics quantify the difference between the training and rollout policies. They are computed automatically when `rollout_log_probs` are provided, regardless of whether TIS/MIS correction is enabled.

| Metric Name | Description |
|------------|-------------|
| `mismatch_training_log_ppl` | Negative mean log probability under the training policy: $-\mathbb{E}[\log \pi_{\text{train}}]$ |
| `mismatch_training_ppl` | Perplexity of the training policy: $\exp(-\mathbb{E}[\log \pi_{\text{train}}])$ |
| `mismatch_rollout_log_ppl` | Negative mean log probability under the rollout policy: $-\mathbb{E}[\log \pi_{\text{rollout}}]$ |
| `mismatch_rollout_ppl` | Perplexity of the rollout policy: $\exp(-\mathbb{E}[\log \pi_{\text{rollout}}])$ |
| `mismatch_kl` | Forward KL divergence estimator: $\mathbb{E}[\log \pi_{\text{rollout}} - \log \pi_{\text{train}}]$ |
| `mismatch_k3_kl` | K3 KL estimator: $\mathbb{E}[\exp(r) - r - 1]$ where $r = \log \pi_{\text{train}} - \log \pi_{\text{rollout}}$ |
| `mismatch_log_ppl_diff` | Log perplexity difference |
| `mismatch_log_ppl_abs_diff` | Absolute log perplexity difference |
| `mismatch_ppl_ratio` | Perplexity ratio |
| `train_rollout_logprob_abs_diff` | Token-level absolute log probability difference |

**Usage**: These metrics help you monitor policy drift. Large values indicate a significant mismatch between the training and rollout engines.
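Several of the monitoring metrics above reduce to one-liners over the two log-prob arrays. A sketch assuming token-level arrays already gathered for the sampled tokens; the dictionary keys mirror the table, but the function itself is illustrative, not SLIME's implementation.

```python
import numpy as np

def mismatch_metrics(logp_train, logp_rollout):
    """Subset of the mismatch monitoring metrics, computed per batch."""
    r = logp_train - logp_rollout                 # log pi_train - log pi_rollout
    return {
        # forward KL estimator: E[log pi_rollout - log pi_train]
        "mismatch_kl": float(np.mean(-r)),
        # K3 estimator: E[exp(r) - r - 1], always non-negative
        "mismatch_k3_kl": float(np.mean(np.exp(r) - r - 1.0)),
        "mismatch_training_ppl": float(np.exp(-np.mean(logp_train))),
        "mismatch_rollout_ppl": float(np.exp(-np.mean(logp_rollout))),
        "train_rollout_logprob_abs_diff": float(np.mean(np.abs(r))),
    }
```

When the two engines agree exactly, both KL estimators are zero and the two perplexities coincide, which makes these metrics a cheap sanity check before enabling any correction.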

### IS/RS Correction Metrics

These metrics track importance sampling weights and corrections. They are only computed when `--use-tis` is enabled.

When `--custom-tis-function-path` points to an MIS implementation (e.g., `mis.py`), additional fine-grained metrics become available:

| Metric Name | Description | Required Args | Optional Control Args |
|------------|-------------|---------------|----------------------|
| `ois` | On-policy importance sampling ratio: $\exp(\log \pi_{\text{train}} - \log \pi_{\text{old}})$ | `--use-tis` | Only for Algorithm 2 (Decoupled PPO) |
| `mis_mean_is_weight_before_clip` | Raw IS weights before any correction: $\exp(\text{log-ratio})$ | `--use-tis` | `--mis-level` (token/sequence/geometric) |
| `mis_ratio_mean_after_mis` | IS weights after correction (bounded or masked) | `--use-tis` | `--mis-mode`, bounds |
| `mis_truncate_fraction` | Fraction of weights truncated (mode-specific) | `--use-tis`, `--mis-mode=truncate` | `--mis-upper-bound` |
| `mis_clip_fraction_low` | Fraction of weights clipped below the lower bound | `--use-tis`, `--mis-mode=clip` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_clip_fraction_high` | Fraction of weights clipped above the upper bound | `--use-tis`, `--mis-mode=clip` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_mask_fraction_low` | Fraction of tokens rejected (below the lower bound) | `--use-tis`, `--mis-mode=mask` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_mask_fraction_high` | Fraction of tokens rejected (above the upper bound) | `--use-tis`, `--mis-mode=mask` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_catastrophic_token_fraction` | Fraction of catastrophic tokens (veto-specific) | `--use-tis`, `--mis-veto-threshold` set | Sequence-level rejection |
| `mis_catastrophic_seq_fraction` | Fraction of sequences with catastrophic tokens | `--use-tis`, `--mis-veto-threshold` set | Sequence-level rejection |
| `mis_batch_norm_factor` | Batch normalization factor applied to weights | `--use-tis`, `--mis-batch-normalize` | Normalizes the mean to 1.0 |

## Reference

We thank the materials below for their excellent findings and theory.

1. [Mathematical Formulations of Rollout Correction Methods in verl (Yingru Li)](https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md)
2. [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)
3. [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)