Version Requirement: ms-swift>=3.11
TL;DR: While GRPO introduces vLLM to accelerate the sampling process, it also introduces Training-Inference Mismatch issues that may affect training stability. This document explains the background, causes, and solutions to this problem.
The training objective of GRPO (Group Relative Policy Optimization) can be expressed as:

$$
\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}\left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \min\left( r_t(\theta)\, \hat{A}_t,\ \text{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) \hat{A}_t \right) \right]
$$
Where:

- $r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\theta_{\text{old}}}(y_t|x, y_{<t})}$ is the importance sampling ratio
- $\hat{A}_t$ is the advantage, computed from the reward and the group baseline
- $\epsilon$ is the clipping parameter
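As a quick illustration of the objective above, here is a minimal plain-Python sketch of the per-token clipped loss (illustrative only, not the ms-swift implementation; the function name is hypothetical):

```python
import math

def grpo_token_loss(logp_new, logp_old, advantage, eps=0.2):
    """Per-token GRPO loss: -min(r * A, clip(r, 1-eps, 1+eps) * A),
    where r = pi_theta / pi_theta_old is the importance ratio."""
    ratio = math.exp(logp_new - logp_old)
    clipped = max(min(ratio, 1.0 + eps), 1.0 - eps)
    # Take the pessimistic (min) side of the clipped surrogate; the loss
    # is the negative of the objective being maximized.
    return -min(ratio * advantage, clipped * advantage)

# When the new and old log-probs are equal, ratio = 1 and the loss is -advantage
print(grpo_token_loss(-1.0, -1.0, 0.5))  # → -0.5
```

Note that when the ratio drifts past $1+\epsilon$ with a positive advantage, the clipped branch caps the gradient contribution.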
Core Assumption: samples are drawn from the policy being trained, which requires that:

- The rollout model and the training model (policy model) are the same model $\pi_\theta$
- The probability distributions of the two are exactly identical, i.e., $\pi_{\text{rollout}} = \pi_\theta$
GRPO's training speed is largely constrained by the sampling process (rollout). To accelerate it, training frameworks introduce high-performance inference engines (such as vLLM) for sampling. The ideal assumption is that, through weight synchronization, vLLM stays consistent with the training model, i.e.,

$$\pi_{\text{rollout}} = \pi_\theta$$
However, in practice, even with fully synchronized weights, differences in operator implementations mean the probability distributions still deviate:

$$\pi_{\text{rollout}} \neq \pi_\theta$$
At this point, the actual training objective becomes:

$$
\mathcal{J}_{\text{mismatch}}(\theta) = \mathbb{E}_{y \sim \pi_{\text{rollout}}}\left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \min\left( r_t(\theta)\, \hat{A}_t,\ \text{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) \hat{A}_t \right) \right]
$$

Where samples come from $\pi_{\text{rollout}}$ rather than from $\pi_\theta$ as the objective assumes, so the gradient estimate is biased.
To address training-inference mismatch, we can introduce Importance Sampling (IS) correction mechanisms.
The basic idea of importance sampling is: when samples come from a distribution $q$ but we want to estimate an expectation under a distribution $p$, we can reweight each sample by the ratio $p(x)/q(x)$:

$$\mathbb{E}_{x \sim p}[f(x)] = \mathbb{E}_{x \sim q}\left[ \frac{p(x)}{q(x)}\, f(x) \right]$$
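The identity above can be sanity-checked numerically. This toy sketch (illustrative values only) samples from $q$ and recovers an expectation under $p$ by reweighting:

```python
import random

random.seed(42)

# Two discrete distributions over {0, 1}
p = [0.8, 0.2]   # the distribution we care about (training policy analogue)
q = [0.5, 0.5]   # the distribution we sampled from (rollout policy analogue)
f = lambda x: float(x)

# Draw from q, then reweight each sample by p(x)/q(x)
samples = [random.random() < q[1] for _ in range(100_000)]
est = sum((p[int(x)] / q[int(x)]) * f(x) for x in samples) / len(samples)
print(round(est, 2))  # close to the true value E_p[f] = 0.2
```

Without the $p/q$ weight, the naive average would estimate $\mathbb{E}_q[f] = 0.5$ instead, which is exactly the bias the correction removes.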
Applied to the GRPO scenario, the corrected loss function is:

$$
\mathcal{J}_{\text{corrected}}(\theta) = \mathbb{E}_{y \sim \pi_{\text{rollout}}}\left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \rho_t \cdot \min\left( r_t(\theta)\, \hat{A}_t,\ \text{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) \hat{A}_t \right) \right]
$$

Where $\rho_t = \frac{\pi_{\theta_{\text{old}}}(y_t|x, y_{<t})}{\pi_{\text{rollout}}(y_t|x, y_{<t})}$ is the importance sampling correction weight.
Importance sampling weights can be computed and applied at different granularities:
- Token-Level

Compute the importance sampling ratio at each token:

$$\rho_t = \frac{\pi_{\theta_{\text{old}}}(y_t|x, y_{<t})}{\pi_{\text{rollout}}(y_t|x, y_{<t})}$$
- Sequence-Level

Compute the sequence-level importance sampling ratio as the product of the token ratios, then broadcast it to each token:

$$\rho_{\text{seq}} = \prod_{t=1}^{T} \rho_t = \exp\left( \sum_{t=1}^{T} \log \rho_t \right)$$
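The two granularities can be sketched from per-token log-probabilities as follows (a plain-Python illustration; the function names are hypothetical, not ms-swift APIs):

```python
import math

def token_ratios(old_logps, rollout_logps):
    """Token-level IS ratio: one weight per token."""
    return [math.exp(a - b) for a, b in zip(old_logps, rollout_logps)]

def sequence_ratio(old_logps, rollout_logps):
    """Sequence-level IS ratio: product of token ratios (sum in log space),
    used as a single weight broadcast to every token in the sequence."""
    return math.exp(sum(old_logps) - sum(rollout_logps))

train = [-1.0, -2.0, -0.5]   # log-probs under pi_theta_old
roll  = [-1.1, -1.8, -0.5]   # log-probs under pi_rollout
print(token_ratios(train, roll))    # per-token weights
print(sequence_ratio(train, roll))  # one weight shared by all tokens
```

Because the sequence-level ratio is a product, it grows (or shrinks) exponentially with sequence length, which is why weight control (next section) matters most in that mode.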
Excessively large importance sampling weights can cause gradient explosion and destabilize training. Therefore, weight control is necessary:
Truncate the importance sampling weight to the range $[0, \tau]$, where $\tau$ is the threshold:

$$\tilde{\rho} = \min(\rho, \tau)$$

This method retains all samples but limits their influence.
Discard (zero out in the loss) token/sequence data whose weight exceeds the threshold:

$$m = \mathbb{1}[\rho \leq \tau]$$

This method is more aggressive: over-deviating data contributes no gradient at all.
Combining granularity and control strategies, there are four correction modes (selected via --rollout_importance_sampling_mode parameter):
| Mode | Description |
|---|---|
| token_truncate | Token-level truncation |
| token_mask | Token-level masking |
| sequence_truncate | Sequence-level truncation |
| sequence_mask | Sequence-level masking |
The threshold is set via the --rollout_importance_sampling_threshold parameter.
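A minimal sketch of what the four modes do to per-token log-ratios (illustrative only; function and argument names are hypothetical and do not mirror the ms-swift internals):

```python
import math

def apply_is_correction(log_ratios, mode, threshold=2.0):
    """Sketch of the four correction modes on per-token log(pi_old/pi_rollout).
    Returns (per-token weights, per-token keep mask)."""
    ratios = [math.exp(lr) for lr in log_ratios]
    if mode == "token_truncate":
        return [min(r, threshold) for r in ratios], [1] * len(ratios)
    if mode == "token_mask":
        return ratios, [1 if r <= threshold else 0 for r in ratios]
    # Sequence-level: one shared weight, the product of token ratios
    seq_ratio = math.exp(sum(log_ratios))
    if mode == "sequence_truncate":
        return [min(seq_ratio, threshold)] * len(ratios), [1] * len(ratios)
    if mode == "sequence_mask":
        keep = 1 if seq_ratio <= threshold else 0
        return [seq_ratio] * len(ratios), [keep] * len(ratios)
    raise ValueError(mode)

weights, mask = apply_is_correction([0.5, 1.0, -0.2], "token_truncate")
print([round(w, 3) for w in weights], mask)
```

Truncation keeps every token but caps its weight at the threshold; masking leaves weights untouched and instead drops the offending tokens (or the whole sequence) from the loss.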
To monitor the degree of training-inference mismatch during training, we add the following metrics to the logs (prefixed with rollout_correction/):
KL divergence measures the deviation between the rollout policy and the training policy. Both metrics estimate $D_{\text{KL}}(\pi_{\text{rollout}} \| \pi_\theta)$ from tokens sampled from the rollout policy.

Direct estimator kl:

$$k_1 = \log \frac{\pi_{\text{rollout}}(y_t|x, y_{<t})}{\pi_\theta(y_t|x, y_{<t})}$$

K3 estimator k3_kl (with $r = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\text{rollout}}(y_t|x, y_{<t})}$):

$$k_3 = r - 1 - \log r$$
The K3 estimator is more numerically stable when KL values are small and is always non-negative.
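Both estimators can be computed directly from the two sets of per-token log-probabilities. A small sketch (illustrative function name, not an ms-swift API):

```python
import math

def kl_estimates(rollout_logps, train_logps):
    """Per-batch estimates of KL(pi_rollout || pi_theta) from rollout samples.
    For each sampled token, log r = log pi_theta - log pi_rollout."""
    log_r = [t - q for t, q in zip(train_logps, rollout_logps)]
    k1 = [-lr for lr in log_r]                    # direct estimator (can be negative)
    k3 = [math.exp(lr) - 1 - lr for lr in log_r]  # K3: always >= 0, lower variance
    return sum(k1) / len(k1), sum(k3) / len(k3)

k1, k3 = kl_estimates([-1.0, -2.0], [-1.2, -1.9])
print(round(k1, 3), round(k3, 3))
```

Note that individual $k_1$ terms can be negative even though KL is non-negative, whereas every $k_3$ term satisfies $r - 1 - \log r \geq 0$, which is what makes it the more stable diagnostic near zero.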
Perplexity measures the model's prediction uncertainty for a sequence:

$$\text{PPL} = \exp\left( -\frac{1}{T} \sum_{t=1}^{T} \log \pi(y_t|x, y_{<t}) \right)$$
Related metrics:

- training_ppl / training_log_ppl: Training policy PPL and its logarithm
- rollout_ppl / rollout_log_ppl: Rollout policy PPL and its logarithm
- log_ppl_diff: Log PPL difference; a positive value means the training policy assigns lower probability
- log_ppl_abs_diff: Absolute log PPL difference
- log_ppl_diff_max / log_ppl_diff_min: Max/min of the log PPL difference
- ppl_ratio: PPL ratio $\frac{\text{PPL}_{\text{training}}}{\text{PPL}_{\text{rollout}}}$
χ² divergence measures the variance of importance sampling weights:
- chi2_token: Token-level χ² divergence, $\mathbb{E}[\rho_t^2] - 1$
- chi2_seq: Sequence-level χ² divergence (geometric-mean based), $\mathbb{E}[\rho_{\text{geo}}^2] - 1$, where $\rho_{\text{geo}} = \exp\left(\frac{1}{T}\sum_t \log \rho_t\right)$
Higher χ² divergence indicates larger IS weight variance and less stable training. chi2_seq uses geometric mean instead of product, making it comparable in scale to chi2_token.
Effective sample size (ESS) measures the fraction of samples that actually contribute after importance sampling. For $N$ weights $\rho_i$, the normalized form is:

$$\text{ESS} = \frac{\left( \sum_{i=1}^{N} \rho_i \right)^2}{N \sum_{i=1}^{N} \rho_i^2}$$
A larger ESS value (closer to 1) indicates more uniform importance sampling weight distribution and higher sample utilization efficiency. When all weights are equal (on-policy), ESS = 1; when weights differ significantly (severely off-policy), ESS becomes small.
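The formula above is short enough to verify by hand (a sketch; the function name is illustrative):

```python
def effective_sample_size(weights):
    """Normalized ESS in (0, 1]: (sum w)^2 / (N * sum w^2)."""
    n = len(weights)
    s = sum(weights)
    s2 = sum(w * w for w in weights)
    return (s * s) / (n * s2)

print(effective_sample_size([1.0, 1.0, 1.0, 1.0]))  # uniform (on-policy) → 1.0
print(effective_sample_size([4.0, 0.1, 0.1, 0.1]))  # one dominant weight → small
```

In the second case a single sample carries almost all the weight, so the batch behaves as if it contained far fewer samples, which is exactly the degradation the metric flags.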
- is_weight_mean: Average importance sampling weight; the ideal value is 1.0
- clipped_frac: Fraction of samples that were truncated or masked
If you only want to monitor the degree of training-inference mismatch without enabling importance sampling correction, you can set:
--log_rollout_offpolicy_metrics true
This will log all diagnostic metrics (KL, PPL, χ², etc.) without modifying the loss function.
Enable the correction mechanism with the following parameters:
--rollout_importance_sampling_mode (default None)
--rollout_importance_sampling_threshold (default 2)
When rollout_importance_sampling_mode is set, diagnostic metrics are automatically logged without needing to set log_rollout_offpolicy_metrics.
In addition to importance sampling correction, you can use Off-Policy Sequence Masking to address training-inference mismatch. This technique comes from the DeepSeek-V3.2 paper.
The core idea of Off-Policy Sequence Masking is: when the current policy deviates significantly from the behavior policy (the rollout policy, or the old policy as a fallback), that sequence is discarded (masked) from loss computation entirely. The technique specifically targets sequences with negative advantage, since these are the most likely to destabilize training when the policy shift is large.
Specifically, for each sequence $i$, compute the average per-token log-probability gap between the behavior policy and the current policy:

$$\delta_i = \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \left( \log \pi_{\text{old}}(y_{i,t}|x_i, y_{i,<t}) - \log \pi_\theta(y_{i,t}|x_i, y_{i,<t}) \right)$$

A sequence is masked out of the loss (its tokens are excluded even where completion_mask=1) when both of the following hold:

- $\delta_i > \tau$
- $\hat{A}_i < 0$
Where:
- $\pi_{\text{old}}$ preferentially uses rollout_per_token_logps (log probabilities from the rollout/behavior policy); if unavailable, it falls back to old_per_token_logps
- $\tau$ is the user-set threshold (--off_policy_sequence_mask_delta, default None = disabled)
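The masking rule can be sketched as follows (illustrative only; the function name is hypothetical, and the definition of the drift $\delta$ as a mean per-token log-prob gap is an assumption of this sketch):

```python
def off_policy_sequence_mask(old_logps, train_logps, advantage, delta_threshold):
    """Return True if the sequence should be kept in the loss, False if masked.
    delta: mean per-token gap (log pi_old - log pi_theta) over the sequence."""
    delta = sum(o - t for o, t in zip(old_logps, train_logps)) / len(old_logps)
    # Mask only when the policy has drifted far AND the advantage is negative
    return not (delta > delta_threshold and advantage < 0)

# Large drift + negative advantage -> sequence is masked out
print(off_policy_sequence_mask([-1.0, -1.0], [-3.0, -3.0], -0.5, 1.0))  # → False
# Same drift but positive advantage -> sequence is kept
print(off_policy_sequence_mask([-1.0, -1.0], [-3.0, -3.0], 0.5, 1.0))   # → True
```

Unlike importance-sampling truncation, this is an all-or-nothing decision per sequence, which is why it is restricted to the negative-advantage case where large drift is most harmful.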
- https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
- https://fengyao.notion.site/off-policy-rl
- https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/rollout_corr_helper.py
- DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models