Commit 9bbe337: More TIS features; skip recompute; mismatch metrics without TIS (#690)
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
1 parent 079a537

6 files changed: +359 −81 lines

# Rollout Correction Methods

Rollout correction (e.g., TIS, MIS) through algorithmic methods.

## Quick Takeaway

These features address offline (off-policy) scenarios through algorithmic adaptations, e.g. TIS/MIS.

We include three rollout correction algorithms:

1. decoupled, 3-policy PPO with rollout importance sampling
2. direct rollout-policy overwriting in standard PPO
3. pure REINFORCE loss (without PPO clipping) with rollout importance sampling

`--use-tis`: use this flag to **turn on TIS/MIS** for rollout correction (details in **Algorithms**).
You may specify the **IS/RS configs** with a config file using `--custom-config-path`.

`--use-rollout-logprobs`: with this flag, log probs will **not** be recomputed by the training engine; the rollout log probs are used directly in the PPO/GRPO loss.

`--get-mismatch-metrics`: use this flag when you don't want to apply TIS/MIS but still want to monitor mismatch-related metrics (e.g. rollout-training KL). It **only reports mismatch metrics** and does not change the loss in any way.

## Algorithms

We give examples of the algorithms for solving the training-inference mismatch issue.

### [Baseline: No Mismatch Correction] Standard PPO

This is the basic PPO algorithm, which suffers from a potential training-inference mismatch when the outputs of SGLang and Megatron do not exactly match.

$$
L_{\text{PPO}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
\min \left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{Megatron}}}(y \mid x)} A_t,
\mathrm{clip}\left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{Megatron}}}(y \mid x)},
1 - \epsilon,
1 + \epsilon
\right) A_t
\right)
\right].
$$

### Bypassing PPO importance sampling

As in REINFORCE, we directly use the rollout engine's log probs as the old policy in offline PPO's importance sampling, rather than the log probs recomputed by the training engine.

$$
L_{\text{PPO-bypass}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
\min \left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)} A_t,
\mathrm{clip}\left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)},
1 - \epsilon,
1 + \epsilon
\right) A_t
\right)
\right].
$$

Advantages:

- Efficiency: skips the `log_prob` recomputation on the training engine, removing one expensive forward pass over all generated trajectories.

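The bypass loss above can be sketched in plain Python over per-token log probs. The helper below is illustrative only (its name and signature are not slime's actual API):

```python
import math

def ppo_bypass_loss(target_logprobs, rollout_logprobs, advantages, eps=0.2):
    """Clipped PPO loss using the rollout engine's log probs as the old policy,
    so the training-engine recompute is skipped (illustrative sketch)."""
    per_token = []
    for lp_new, lp_roll, adv in zip(target_logprobs, rollout_logprobs, advantages):
        ratio = math.exp(lp_new - lp_roll)               # pi_theta / pi_SGLang
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)  # clip(ratio, 1-eps, 1+eps)
        per_token.append(-min(ratio * adv, clipped * adv))
    return sum(per_token) / len(per_token)
```

When the target and rollout log probs coincide, the ratio is 1 and the loss reduces to plain REINFORCE with the given advantages.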
### Decoupled, 3-policy PPO Importance Sampling

[Decoupled PPO](https://arxiv.org/pdf/2110.00641) achieves batch-size invariance by decoupling two roles: the proximal policy (the anchor for PPO clipping, controlling update size) and the behavior policy (used for off-policy correction via importance sampling). Three policies are therefore involved in this mode: the **target policy** $\pi_\theta$, the **proximal policy** $\pi_{\textcolor{blue}{\text{old}}}$, and the **behavior policy** $\pi_{\textcolor{red}{\text{SGLang}}}$. $\pi_{\textcolor{blue}{\text{old}}}$ is recomputed with Megatron at the beginning of each training step.

$$
L_{\text{PPO-decoupled}}(\theta)
= - \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\textcolor{red}{\text{SGLang}}}} \left[
\frac{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)}{\pi_{\textcolor{red}{\text{SGLang}}}(y \mid x)}
\min \left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)} A_t,
\mathrm{clip}\left(
\frac{\pi_\theta(y \mid x)}{\pi_{\textcolor{blue}{\text{old}}}(y \mid x)},
1 - \epsilon,
1 + \epsilon
\right) A_t
\right)
\right].
$$

Advantages:

- Achieves batch-size invariance and efficient use of stale data
- Enables accurate monitoring of off-policy metrics

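The decoupled loss can be sketched in plain Python as follows. This is an illustrative helper, not slime's implementation; `tis_cap` stands in for a truncation bound such as `--tis-upper-bound`:

```python
import math

def ppo_decoupled_loss(target_lp, old_lp, rollout_lp, advantages, eps=0.2, tis_cap=2.0):
    """Decoupled 3-policy PPO sketch: pi_old (recomputed by Megatron) anchors the
    clipping, while a truncated importance weight pi_old / pi_SGLang corrects
    for the rollout behavior policy."""
    per_token = []
    for lp_t, lp_o, lp_r, adv in zip(target_lp, old_lp, rollout_lp, advantages):
        tis_w = min(math.exp(lp_o - lp_r), tis_cap)      # truncated pi_old / pi_SGLang
        ratio = math.exp(lp_t - lp_o)                    # pi_theta / pi_old
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        per_token.append(-tis_w * min(ratio * adv, clipped * adv))
    return sum(per_token) / len(per_token)
```

Note how the clipping ratio uses only $\pi_\theta/\pi_{\text{old}}$, so the update size is controlled independently of how stale the rollout data is.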
## APIs of Algorithms

You may choose among the algorithms above by specifying the arguments below:

`--use-rollout-log-probs`: if true, only `rollout_log_probs` are used to compute the loss, bypassing the `old_log_probs` calculated by the training engine;

`--use-rollout-correction`: if true, importance sampling/rejection sampling is applied to the loss.

| `use_rollout_log_probs` | `use_rollout_correction` | Algorithm | Policies | Compute `old_log_probs` | Batch Invariant | Recommended TIS Mode |
|-----------------|-------------|-----------|--------------|---------------|-----------------|----------------------|
| False | False | Standard PPO (Algorithm 0) | 2 ($\pi_\theta$, $\pi_{\textcolor{blue}{\text{old}}}$) | Yes | No | N/A |
| True | False | Bypassing PPO (Algorithm 3) | 2 ($\pi_\theta$, $\pi_{\textcolor{red}{\text{SGLang}}}$) | 🚀 Skipped | No | N/A |
| False | True | Decoupled PPO (Algorithm 2) | 3 ($\pi_\theta$, $\pi_{\textcolor{blue}{\text{old}}}$, $\pi_{\textcolor{red}{\text{SGLang}}}$) | Yes | Yes | token/seq/geo |

## Configs and Recommended Settings

When using importance sampling or rejection sampling for mismatch correction (`--use-rollout-correction` enabled; Algorithms 2 & 3), you may specify the IS modes and the levels they apply at.

### Arguments

`--use-tis`: Enable importance sampling. The IS weight is multiplied into the policy gradient loss.

- `--tis-mode`: Mode for IS. Allowed modes: **truncate**, **clip**.
- `--tis-lower-bound`, `--tis-upper-bound`: Bounds for IS weights.
- `--tis-level`: Allowed levels: **token**, **sequence**, **geometric**. See explanations below.
- `--tis-batch-normalize`: Normalize IS weights to mean 1.0 across the batch.

`--use-rs`: Enable rejection sampling. Tokens/sequences whose IS weight falls outside the thresholds are directly masked, and the rejected tokens/sequences are excluded from loss averaging.

- `--rs-lower-bound`, `--rs-upper-bound`: Bounds for RS weights.
- `--rs-level`: Allowed levels: **token**, **sequence**, **geometric**. See explanations below.
- `--rs-veto-threshold`: Sequence-level rejection threshold for catastrophic mismatches.

### Importance Sampling

For both IS and RS, we provide three levels: **token**, **sequence**, **geometric**.

**Token Level (default)**:

Computes importance weights independently for each token:
$w_i = \exp\left(\log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i)\right)$

Characteristics: biased but computationally simple; suitable for most scenarios.

**Sequence Level**:

Uses the product of all token weights as the sequence weight:
$w_{\text{seq}} = \exp\left( \sum_i \left( \log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i) \right) \right)$

Characteristics: unbiased but high variance; suitable for sequence-level optimization.

**Geometric Level**:

Uses the geometric mean of the token weights as the sequence weight:
$w_{\text{seq}} = \exp\left( \frac{1}{n} \sum_{i=1}^{n} \left( \log \pi_{\text{train}}(x_i) - \log \pi_{\text{rollout}}(x_i) \right) \right)$

Characteristics: biased but low variance; balances bias and variance.
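The three levels can be summarized in a small sketch over one sequence of per-token log probs. The helper name and shapes are illustrative, not slime's internal API:

```python
import math

def is_weights(train_lp, rollout_lp, level="token"):
    """Per-token IS weights for one sequence at the three supported levels."""
    diffs = [t - r for t, r in zip(train_lp, rollout_lp)]  # log pi_train - log pi_rollout
    if level == "token":
        return [math.exp(d) for d in diffs]                      # one weight per token
    if level == "sequence":
        return [math.exp(sum(diffs))] * len(diffs)               # product of token ratios
    if level == "geometric":
        return [math.exp(sum(diffs) / len(diffs))] * len(diffs)  # geometric mean of ratios
    raise ValueError(f"unknown level: {level}")
```

For a sequence where every token ratio is 2, the token and geometric levels both give weight 2 per token, while the sequence level gives the product (4 for a 2-token sequence), which illustrates its higher variance.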
### Rejection Sampling

**Token Level**: Reject tokens whose IS weight falls outside the thresholds.

**Sequence Level**: Reject sequences whose mean IS weight falls outside the thresholds.

**Geometric Level**: Reject sequences whose geometric-mean IS weight falls outside the thresholds.

- Extremely selective: requires a near-perfect policy match
- High rejection rate: only suitable for very slight distribution shifts

**Veto Mechanism**:

The veto mechanism rejects sequences containing tokens with catastrophically low probabilities: reject the entire sequence if $\exists t \in T$ such that $\rho_t < C_{\text{veto}}$.

- Prevents catastrophic updates from tokens with near-zero probability under $\pi_{\text{old}}$
- Independent of the IS/RS settings

*Typical values: $10^{-4}$ to $10^{-6}$*
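Interpreting $\rho_t$ as the token-level ratio, the veto rule amounts to a per-sequence keep mask. This is an illustrative sketch under that assumption, not slime's implementation:

```python
def veto_keep_mask(token_ratios_per_seq, c_veto=1e-4):
    """Veto sketch: keep a sequence only if no token-level ratio rho_t falls
    below the catastrophic threshold C_veto; otherwise drop the whole sequence."""
    return [all(rho >= c_veto for rho in seq) for seq in token_ratios_per_seq]
```

A single catastrophic token is enough to discard the full sequence, which is what shields the update from near-zero-probability tokens.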
## Mismatch Metrics

When rollout log probabilities are available, SLIME automatically tracks comprehensive metrics to monitor training-inference mismatch and importance sampling weights. These metrics help diagnose policy divergence and guide hyperparameter tuning.

### Mismatch Monitoring Metrics

These metrics quantify the difference between the training and rollout policies. They are computed automatically whenever `rollout_log_probs` are provided, regardless of whether TIS/MIS correction is enabled.

| Metric Name | Description |
|------------|-------------|
| `mismatch_training_log_ppl` | Negative mean log probability under the training policy: $-\mathbb{E}[\log \pi_{\text{train}}]$ |
| `mismatch_training_ppl` | Perplexity of the training policy: $\exp(-\mathbb{E}[\log \pi_{\text{train}}])$ |
| `mismatch_rollout_log_ppl` | Negative mean log probability under the rollout policy: $-\mathbb{E}[\log \pi_{\text{rollout}}]$ |
| `mismatch_rollout_ppl` | Perplexity of the rollout policy: $\exp(-\mathbb{E}[\log \pi_{\text{rollout}}])$ |
| `mismatch_kl` | Forward KL divergence estimator: $\mathbb{E}[\log \pi_{\text{rollout}} - \log \pi_{\text{train}}]$ |
| `mismatch_k3_kl` | K3 KL estimator: $\mathbb{E}[\exp(r) - r - 1]$ where $r = \log \pi_{\text{train}} - \log \pi_{\text{rollout}}$ |
| `mismatch_log_ppl_diff` | Log perplexity difference |
| `mismatch_log_ppl_abs_diff` | Absolute log perplexity difference |
| `mismatch_ppl_ratio` | Perplexity ratio |
| `train_rollout_logprob_abs_diff` | Token-level absolute log probability difference |

**Usage**: These metrics help you monitor policy drift. Large values indicate a significant mismatch between the training and rollout engines.

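A few of these metrics can be computed directly from the per-token log probs, as the definitions in the table suggest. The sketch below is illustrative (flat lists, hypothetical helper), not slime's metric code:

```python
import math

def mismatch_metrics(train_lp, rollout_lp):
    """Compute a subset of the mismatch monitoring metrics over a flat list
    of per-token log probabilities."""
    n = len(train_lp)
    log_ppl_train = -sum(train_lp) / n
    log_ppl_rollout = -sum(rollout_lp) / n
    r = [t - q for t, q in zip(train_lp, rollout_lp)]  # log pi_train - log pi_rollout
    return {
        "mismatch_training_ppl": math.exp(log_ppl_train),
        "mismatch_rollout_ppl": math.exp(log_ppl_rollout),
        "mismatch_kl": -sum(r) / n,                    # E[log pi_rollout - log pi_train]
        "mismatch_k3_kl": sum(math.exp(x) - x - 1 for x in r) / n,
        "mismatch_log_ppl_diff": log_ppl_train - log_ppl_rollout,
    }
```

When the two engines agree token by token, both KL estimators are exactly zero and the perplexities coincide.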
### IS/RS Correction Metrics

These metrics track importance sampling weights and corrections. They are only computed when `--use-tis` is enabled.

When `--custom-tis-function-path` points to an MIS implementation (e.g., `mis.py`), additional fine-grained metrics become available:

| Metric Name | Description | Required Args | Optional Control Args |
|------------|-------------|---------------|----------------------|
| `ois` | On-policy importance sampling ratio: $\exp(\log \pi_{\text{train}} - \log \pi_{\text{old}})$ | `--use-tis` | Only for Algorithm 2 (Decoupled PPO) |
| `mis_mean_is_weight_before_clip` | Raw IS weights before any correction: $\exp(\text{log-ratio})$ | `--use-tis` | `--mis-level` (token/sequence/geometric) |
| `mis_ratio_mean_after_mis` | IS weights after correction (bounded or masked) | `--use-tis` | `--mis-mode`, bounds |
| `mis_truncate_fraction` | Fraction of weights truncated (mode-specific) | `--use-tis`, `--mis-mode=truncate` | `--mis-upper-bound` |
| `mis_clip_fraction_low` | Fraction of weights clipped below the lower bound | `--use-tis`, `--mis-mode=clip` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_clip_fraction_high` | Fraction of weights clipped above the upper bound | `--use-tis`, `--mis-mode=clip` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_mask_fraction_low` | Fraction of tokens rejected (below the lower bound) | `--use-tis`, `--mis-mode=mask` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_mask_fraction_high` | Fraction of tokens rejected (above the upper bound) | `--use-tis`, `--mis-mode=mask` | `--mis-lower-bound`, `--mis-upper-bound` |
| `mis_catastrophic_token_fraction` | Fraction of catastrophic tokens (veto-specific) | `--use-tis`, `--mis-veto-threshold` set | Sequence-level rejection |
| `mis_catastrophic_seq_fraction` | Fraction of sequences containing catastrophic tokens | `--use-tis`, `--mis-veto-threshold` set | Sequence-level rejection |
| `mis_batch_norm_factor` | Batch normalization factor applied to weights | `--use-tis`, `--mis-batch-normalize` | Normalizes mean to 1.0 |

## Reference

We thank the materials below for their excellent findings and theories.

1. [Mathematical Formulations of Rollout Correction Methods in verl (Yingru Li)](https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md)
2. [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)
3. [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)
