Add MiniRL Support #4491
Conversation
Code Review
This pull request introduces an initial implementation for the MiniRL policy loss. My review focuses on two main areas: the correctness of the new loss function in core_algos.py and the robustness of the example script run_llama3.2_1b_minrl.sh.
I've identified a critical issue in the compute_policy_loss_minirl function where the gradient calculation appears to be incorrect due to premature detaching of tensors and a non-standard loss formulation. I've provided a detailed explanation and a code suggestion to align it with a standard importance-sampled policy gradient objective, which is likely the intended behavior. Additionally, I've found a high-severity issue in the new shell script, where it fails to validate required environment variables, potentially leading to hard-to-debug errors. A suggestion is provided to add checks and fail early with a clear message.
clip_ratio = config.clip_ratio
eps_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
eps_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio

# log r_t = log π_θ - log π_θ_old
negative_approx_kl = (log_prob - old_log_prob).detach()
# Clamp for numerical stability before exponentiating
negative_approx_kl_clamped = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
staleness_ratio = torch.exp(negative_approx_kl_clamped)  # r_t(θ)

# -------------------------------------------------------------------------
# 2) MiniRL gate M_t (Eq. 7), using r_t(θ) **detached**:
#
#        M_t = 0   if A > 0 and r_t > 1 + ε_high
#            = 0   if A < 0 and r_t < 1 - ε_low
#            = 1   otherwise
#
#    This prevents aggressive policy updates from stale trajectories, but
#    does not change the gradient formula itself.
# -------------------------------------------------------------------------
r_detached = staleness_ratio #.detach()
ones = torch.ones_like(advantages)
gate = ones.clone()

too_large_pos = (advantages > 0) & (r_detached > 1.0 + eps_high)
too_small_neg = (advantages < 0) & (r_detached < 1.0 - eps_low)
gate = gate.masked_fill(too_large_pos | too_small_neg, 0.0)

# -------------------------------------------------------------------------
# 3) Inference discrepancy weights:
#        rollout_is_weights ≈ π_θ_old / μ_rollout   (from rollout_corr_helper)
#
#    This is already:
#      * computed in log-space as old_log_prob - rollout_log_prob
#      * truncated, optionally batch-normalized
#      * detached (no gradient), as required by IS theory
#
#    If not provided, we set it to 1 and get pure on-policy MiniRL.
# -------------------------------------------------------------------------
if rollout_is_weights is None:
    inference_discrepancy = torch.ones_like(advantages)
else:
    # Mask is already applied inside rollout_corr_helper, but we re-mask to be safe.
    inference_discrepancy = rollout_is_weights * response_mask

# -------------------------------------------------------------------------
# 4) Total IS weight:
#        w_t(θ) ≈ (π_θ / π_θ_old) ⋅ (π_θ_old / μ_rollout)
#               = staleness_ratio ⋅ inference_discrepancy
#
#    This matches Eq. (5) in the paper while reusing rollout_is_weights,
#    so the loss never needs rollout_log_prob directly.
# -------------------------------------------------------------------------
total_is_weight = staleness_ratio * inference_discrepancy

# -------------------------------------------------------------------------
# 5) MiniRL objective:
#        J ≈ E[ M_t ⋅ w_t(θ) ⋅ Â(x,y) ⋅ log π_θ ]
#    so loss is:
#        L = - M_t ⋅ w_t(θ) ⋅ Â(x,y) ⋅ log π_θ
#
#    Gradients:
#        ∇_θ L = - E[ M_t ⋅ w_t(θ) ⋅ Â ⋅ ∇ log π_θ ]
#
#    where:
#      - M_t only gates tokens (built from r_detached)
#      - w_t(θ) combines staleness + rollout mismatch
# -------------------------------------------------------------------------
pg_losses = -gate * total_is_weight * advantages * log_prob
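For intuition, here is a small, self-contained illustration (not part of the diff) of how the gate in step 2 behaves, assuming eps_low = eps_high = 0.2: tokens whose detached ratio has drifted outside the trust band are simply zeroed out.

import torch

advantages = torch.tensor([1.0, 1.0, -1.0, -1.0])
r_detached = torch.tensor([1.35, 1.10, 0.90, 0.70])  # detached r_t values
eps_low, eps_high = 0.2, 0.2

gate = torch.ones_like(advantages)
too_large_pos = (advantages > 0) & (r_detached > 1.0 + eps_high)
too_small_neg = (advantages < 0) & (r_detached < 1.0 - eps_low)
gate = gate.masked_fill(too_large_pos | too_small_neg, 0.0)

print(gate)  # tensor([0., 1., 1., 0.]) -> tokens 0 and 3 are gated out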
The implementation of the MiniRL policy loss appears to have a few inconsistencies that could lead to incorrect gradients and behavior. The core issues are:
- Premature detaching: On line 1135, negative_approx_kl is detached immediately. This prevents gradients from the current policy log_prob from flowing into staleness_ratio and total_is_weight. The docstring and the commented-out .detach() on line 1150 suggest that the staleness ratio was intended to have a gradient, with a detached version used only for the gate.
- Non-standard loss objective: On line 1198, the loss pg_losses is multiplied by log_prob. Standard policy gradient objectives with importance sampling (like PPO) define the loss as a function of advantages * ratio, not advantages * ratio * log_prob. The latter leads to a non-standard gradient update.
To align this with a more standard and likely correct implementation of a gated, importance-sampled policy gradient, I suggest the following changes:
- Compute staleness_ratio with gradients.
- Create a detached version of it specifically for the gate.
- Define the loss as -gate * total_is_weight * advantages to follow the standard policy gradient theorem.
clip_ratio = config.clip_ratio
eps_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
eps_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
# log r_t = log π_θ - log π_θ_old. This should carry gradients.
negative_approx_kl = log_prob - old_log_prob
# Clamp for numerical stability before exponentiating
negative_approx_kl_clamped = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
staleness_ratio = torch.exp(negative_approx_kl_clamped) # r_t(θ)
# -------------------------------------------------------------------------
# 2) MiniRL gate M_t (Eq. 7), using r_t(θ) **detached**:
#
# M_t = 0 if A>0 and r_t > 1+ε_high
# = 0 if A<0 and r_t < 1-ε_low
# = 1 otherwise
#
# This prevents aggressive policy updates from stale trajectories, but
# does not change the gradient formula itself.
# -------------------------------------------------------------------------
r_detached = staleness_ratio.detach()
ones = torch.ones_like(advantages)
gate = ones.clone()
too_large_pos = (advantages > 0) & (r_detached > 1.0 + eps_high)
too_small_neg = (advantages < 0) & (r_detached < 1.0 - eps_low)
gate = gate.masked_fill(too_large_pos | too_small_neg, 0.0)
# -------------------------------------------------------------------------
# 3) Inference discrepancy weights:
# rollout_is_weights ≈ π_θ_old / μ_rollout (from rollout_corr_helper)
#
# This is already:
# * computed in log-space as old_log_prob - rollout_log_prob
# * truncated, optionally batch-normalized
# * detached (no gradient), as required by IS theory
#
# If not provided, we set it to 1 and get pure on-policy MiniRL.
# -------------------------------------------------------------------------
if rollout_is_weights is None:
    inference_discrepancy = torch.ones_like(advantages)
else:
    # Mask is already applied inside rollout_corr_helper, but we re-mask to be safe.
    inference_discrepancy = rollout_is_weights * response_mask
# -------------------------------------------------------------------------
# 4) Total IS weight:
# w_t(θ) ≈ (π_θ / π_θ_old) ⋅ (π_θ_old / μ_rollout)
# = staleness_ratio ⋅ inference_discrepancy
#
# This matches Eq. (5) in the paper while reusing rollout_is_weights,
# so the loss never needs rollout_log_prob directly.
# -------------------------------------------------------------------------
total_is_weight = staleness_ratio * inference_discrepancy
# -------------------------------------------------------------------------
# 5) MiniRL objective (standard IS policy gradient):
# J ≈ E[ M_t ⋅ w_t(θ) ⋅ Â(x,y) ]
# so loss is:
# L = - M_t ⋅ w_t(θ) ⋅ Â(x,y)
#
# Gradients:
# ∇_θ L = - E[ M_t ⋅ Â ⋅ ∇_θ w_t(θ) ]
# = - E[ M_t ⋅ Â ⋅ w_t(θ) ⋅ ∇_θ log π_θ ]
#
# where:
# - M_t only gates tokens (built from r_detached)
# - w_t(θ) combines staleness + rollout mismatch and has gradients
# -------------------------------------------------------------------------
pg_losses = -gate * total_is_weight * advantages
According to the attached paper, this loss function is more of a REINFORCE-like objective, right? Then the ratio is only being used for a decoupled clip.
QWEN-lessons.pdf
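For concreteness, below is a minimal sketch of the interpretation above: a REINFORCE-style surrogate in which the ratio is used only as a decoupled, detached gate. This is not the PR's implementation; the function name, the default epsilons, and the mean reduction are assumptions.

import torch

def reinforce_with_decoupled_gate(log_prob, old_log_prob, advantages,
                                  response_mask, eps_low=0.2, eps_high=0.2):
    # The ratio is detached: it only decides which tokens to keep
    # (a decoupled clip/gate) and never appears in the gradient path.
    ratio = torch.exp((log_prob - old_log_prob).detach())
    gate = torch.ones_like(advantages)
    gate = gate.masked_fill((advantages > 0) & (ratio > 1.0 + eps_high), 0.0)
    gate = gate.masked_fill((advantages < 0) & (ratio < 1.0 - eps_low), 0.0)
    # REINFORCE-style term: the gradient flows only through log_prob.
    pg_losses = -gate * advantages * log_prob
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)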
MODEL_PATH=""

DATA_DIR=""
CKPTS_DIR=""

GPUS_PER_NODE=8
NNODES=1
The script initializes MODEL_PATH, DATA_DIR, and CKPTS_DIR to empty strings. If a user runs the script without setting these variables externally, they will be passed as empty strings to the python trainer, which will likely cause a hard-to-debug error deep inside the training code. It's better to fail early with a clear error message. I suggest adding checks to ensure these required variables are set to non-empty values.
Suggested change:

MODEL_PATH=""
DATA_DIR=""
CKPTS_DIR=""

if [[ -z "${MODEL_PATH}" || -z "${DATA_DIR}" || -z "${CKPTS_DIR}" ]]; then
    echo "Error: MODEL_PATH, DATA_DIR, and CKPTS_DIR must be set to non-empty values." >&2
    exit 1
fi

GPUS_PER_NODE=8
NNODES=1
@shamanez please take a deep look at the rollout correction module in VeRL. https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html
@szrlee Hello Yingru, I actually used the rollout correction module, but according to this paper we also need to compute the staleness ratio, which is the ratio between pi_old and pi_theta (what we update during the PPO mini-batch). That is why I computed it inside the loss function.
@szrlee I actually reused the rollout correction module, and explained the logic inside the loss function. Please let me know if there's anything missing.
@shamanez what is the difference from decoupled_token_is() in https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html#method-summary-table, i.e. using token_is for \pi_{rollout} -> \pi_{old} and ppo_clip for \pi_{old} -> \pi_{new}? There is no need to implement another loss for MiniRL accounting for \pi_{old} -> \pi_{new}, as it is essentially the ppo_clip loss. You can check either the math definition, the code, or Qwen's paper. It is just rebranding.
@szrlee OK, I understand. But could you please help me understand whether the current PPO clip loss method also addresses the following constraints? Please correct me if I am wrong, but according to their logic, the importance sampling weight depends on two ratios, as follows.
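For readability, here is the two-ratio decomposition being referred to, written out in LaTeX (it matches the step-4 comment in the diff):

w_t(\theta)
  = \frac{\pi_\theta(y_t \mid x, y_{<t})}{\mu_{\mathrm{rollout}}(y_t \mid x, y_{<t})}
  = \underbrace{\frac{\pi_\theta}{\pi_{\theta_{\mathrm{old}}}}}_{\text{staleness ratio}}
    \cdot
    \underbrace{\frac{\pi_{\theta_{\mathrm{old}}}}{\mu_{\mathrm{rollout}}}}_{\text{rollout\_is\_weights}}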
If I understand it correctly, here pi_theta is actually the policy being updated inside the mini-batch PPO step, while the rollout correction module only computes the training-inference discrepancy once per batch. The reason I thought of adding this as a new loss is that, as you can see in their experiments, pi_theta is basically the policy that is getting updated in the mini-batch.
@shamanez Please look at the preset methods for decoupled_token_is() in algorithm.py and see how it works, thank you! We already address both the 3-policy case (\pi_{rollout} -> \pi_{old}, \pi_{old} -> \pi_{new}) and the 2-policy case (\pi_{rollout} -> \pi_{new}) in the rollout correction module. Qwen's paper also cites our work.
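To make the comparison concrete, here is a minimal sketch of the general pattern being described: a detached token-level IS weight for \pi_{rollout} -> \pi_{old} combined with the usual PPO clip for \pi_{old} -> \pi_{new}. This is an illustrative assumption, not verl's actual decoupled_token_is() implementation; the function name, the truncation cap, and the mean reduction are made up here.

import torch

def token_is_plus_ppo_clip(log_prob, old_log_prob, rollout_log_prob,
                           advantages, response_mask, eps=0.2, is_cap=2.0):
    # pi_old / mu_rollout: detached, truncated rollout-correction weight.
    is_weight = torch.exp(old_log_prob - rollout_log_prob).detach().clamp(max=is_cap)
    # pi_new / pi_old: the usual PPO ratio, which carries gradients.
    ratio = torch.exp(log_prob - old_log_prob)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    pg_losses = -is_weight * torch.min(surr1, surr2)
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)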










What does this PR do?
This PR adds an initial implementation of the MINRL objective described in
arXiv:2512.01374.
Concretely, it introduces a new loss helper compute_policy_loss_minirl in core_algos.py, which is intended to be used as an alternative policy loss for future experiments.
This is still work-in-progress: the implementation compiles and runs on small
tests, but I have not yet trained a full model due to limited compute.
Feedback and suggestions on the design and integration are very welcome.
Checklist Before Starting
- Searched for similar PRs: https://github.com/volcengine/verl/pulls?q=MINRL
- PR title format: [{modules}] {type}: {description} (using [algo] feat: add MINRL policy loss)
Test
I do not have enough compute to run a full training run yet.
Current checks:
compute_policy_loss_minirl runs without errors and returns finite losses.
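A minimal sketch of that smoke check is below; the tensor shapes are arbitrary, and the exact call signature and return value of compute_policy_loss_minirl are assumptions, so adjust it to the real helper.

import torch

B, T = 2, 8
log_prob = 0.1 * torch.randn(B, T)
old_log_prob = log_prob + 0.05 * torch.randn(B, T)
advantages = torch.randn(B, T)
response_mask = torch.ones(B, T)

# Assumed call; the real helper also takes a config with clip_ratio fields
# and an optional rollout_is_weights tensor.
# loss = compute_policy_loss_minirl(old_log_prob, log_prob, advantages, response_mask, config)
# assert torch.isfinite(loss).all()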
Planned: