Conversation

@shamanez

What does this PR do?

This PR adds an initial implementation of the MiniRL objective described in
arXiv:2512.01374.

Concretely, it introduces a new loss helper compute_policy_loss_minirl in
core_algos.py, which is intended to be used as an alternative policy loss
for future experiments.

This is still a work in progress: the implementation compiles and runs on small
tests, but I have not yet trained a full model due to limited compute.
Feedback and suggestions on the design and integration are very welcome.
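
A smoke test along the lines of the one mentioned in the checklist below might look roughly like this (the import path, argument names/order, and return shape are assumptions for illustration, not the exact signature):

    # Rough smoke-test sketch. The import path, argument names/order, and return
    # value of compute_policy_loss_minirl are assumptions for illustration only.
    from types import SimpleNamespace

    import torch

    from verl.trainer.ppo.core_algos import compute_policy_loss_minirl  # assumed module path

    B, T = 4, 16  # tiny batch: 4 responses, 16 tokens each
    log_prob = 0.1 * torch.randn(B, T)
    old_log_prob = log_prob + 0.01 * torch.randn(B, T)  # slightly stale policy
    advantages = torch.randn(B, T)
    response_mask = torch.ones(B, T)
    config = SimpleNamespace(clip_ratio=0.2, clip_ratio_low=None, clip_ratio_high=None)

    out = compute_policy_loss_minirl(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        rollout_is_weights=None,  # falls back to the pure on-policy MiniRL path
        config=config,
    )
    loss = out[0] if isinstance(out, tuple) else out
    assert torch.isfinite(loss).all(), "loss should be finite on a small random batch"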

Checklist Before Starting

Test

I do not have enough compute to run a full training run yet.

Current checks:

  • Code passes local linting/formatting.
  • Basic smoke test on a small batch to ensure compute_policy_loss_minirl
    runs without errors and returns finite losses.

Planned:

  • Full training run once compute is available.
  • Compare learning curves against the default policy loss.

@CLAassistant

CLAassistant commented Dec 11, 2025

CLA assistant check
All committers have signed the CLA.

@shamanez shamanez mentioned this pull request Dec 11, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an initial implementation for the MiniRL policy loss. My review focuses on two main areas: the correctness of the new loss function in core_algos.py and the robustness of the example script run_llama3.2_1b_minrl.sh.

I've identified a critical issue in the compute_policy_loss_minirl function where the gradient calculation appears to be incorrect due to premature detaching of tensors and a non-standard loss formulation. I've provided a detailed explanation and a code suggestion to align it with a standard importance-sampled policy gradient objective, which is likely the intended behavior. Additionally, I've found a high-severity issue in the new shell script, where it fails to validate required environment variables, potentially leading to hard-to-debug errors. A suggestion is provided to add checks and fail early with a clear message.

Comment on lines +1130 to +1198
clip_ratio = config.clip_ratio
eps_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
eps_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio

# log r_t = log π_θ - log π_θ_old
negative_approx_kl = (log_prob - old_log_prob).detach()
# Clamp for numerical stability before exponentiating
negative_approx_kl_clamped = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
staleness_ratio = torch.exp(negative_approx_kl_clamped) # r_t(θ)

# -------------------------------------------------------------------------
# 2) MiniRL gate M_t (Eq. 7), using r_t(θ) **detached**:
#
# M_t = 0 if A>0 and r_t > 1+ε_high
# = 0 if A<0 and r_t < 1-ε_low
# = 1 otherwise
#
# This prevents aggressive policy updates from stale trajectories, but
# does not change the gradient formula itself.
# -------------------------------------------------------------------------
r_detached = staleness_ratio #.detach()
ones = torch.ones_like(advantages)
gate = ones.clone()

too_large_pos = (advantages > 0) & (r_detached > 1.0 + eps_high)
too_small_neg = (advantages < 0) & (r_detached < 1.0 - eps_low)
gate = gate.masked_fill(too_large_pos | too_small_neg, 0.0)

# -------------------------------------------------------------------------
# 3) Inference discrepancy weights:
# rollout_is_weights ≈ π_θ_old / μ_rollout (from rollout_corr_helper)
#
# This is already:
# * computed in log-space as old_log_prob - rollout_log_prob
# * truncated, optionally batch-normalized
# * detached (no gradient), as required by IS theory
#
# If not provided, we set it to 1 and get pure on-policy MiniRL.
# -------------------------------------------------------------------------
if rollout_is_weights is None:
    inference_discrepancy = torch.ones_like(advantages)
else:
    # Mask is already applied inside rollout_corr_helper, but we re-mask to be safe.
    inference_discrepancy = rollout_is_weights * response_mask

# -------------------------------------------------------------------------
# 4) Total IS weight:
# w_t(θ) ≈ (π_θ / π_θ_old) ⋅ (π_θ_old / μ_rollout)
# = staleness_ratio ⋅ inference_discrepancy
#
# This matches Eq. (5) in the paper while reusing rollout_is_weights,
# so the loss never needs rollout_log_prob directly.
# -------------------------------------------------------------------------
total_is_weight = staleness_ratio * inference_discrepancy

# -------------------------------------------------------------------------
# 5) MiniRL objective:
# J ≈ E[ M_t ⋅ w_t(θ) ⋅ Â(x,y) ⋅ log π_θ ]
# so loss is:
# L = - M_t ⋅ w_t(θ) ⋅ Â(x,y) ⋅ log π_θ
#
# Gradients:
# ∇_θ L = - E[ M_t ⋅ w_t(θ) ⋅ Â ⋅ ∇ log π_θ ]
#
# where:
# - M_t only gates tokens (built from r_detached)
# - w_t(θ) combines staleness + rollout mismatch
# -------------------------------------------------------------------------
pg_losses = -gate * total_is_weight * advantages * log_prob
Contributor


critical

The implementation of the MiniRL policy loss appears to have a few inconsistencies that could lead to incorrect gradients and behavior. The core issues are:

  1. Premature Detaching: On line 1135, negative_approx_kl is detached immediately. This prevents gradients from the current policy log_prob from flowing into the staleness_ratio and total_is_weight. The docstring and the commented-out .detach() on line 1150 suggest that the staleness ratio was intended to have a gradient, with a detached version used only for the gate.
  2. Non-standard Loss Objective: On line 1198, the loss pg_losses is multiplied by log_prob. Standard policy gradient objectives with importance sampling (like PPO) define the loss as a function of advantages * ratio, not advantages * ratio * log_prob. The latter leads to a non-standard gradient update.

To align this with a more standard and likely correct implementation of a gated, importance-sampled policy gradient, I suggest the following changes:

  • Compute staleness_ratio with gradients.
  • Create a detached version of it specifically for the gate.
  • Define the loss as -gate * total_is_weight * advantages to follow the standard policy gradient theorem.
    clip_ratio = config.clip_ratio
    eps_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    eps_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio

    # log r_t = log π_θ - log π_θ_old. This should carry gradients.
    negative_approx_kl = log_prob - old_log_prob
    # Clamp for numerical stability before exponentiating
    negative_approx_kl_clamped = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    staleness_ratio = torch.exp(negative_approx_kl_clamped)  # r_t(θ)

    # -------------------------------------------------------------------------
    # 2) MiniRL gate M_t (Eq. 7), using r_t(θ) **detached**:
    #
    #       M_t = 0 if  A>0 and r_t > 1+ε_high
    #           = 0 if  A<0 and r_t < 1-ε_low
    #           = 1 otherwise
    #
    #    This prevents aggressive policy updates from stale trajectories, but
    #    does not change the gradient formula itself.
    # -------------------------------------------------------------------------
    r_detached = staleness_ratio.detach()
    ones = torch.ones_like(advantages)
    gate = ones.clone()

    too_large_pos = (advantages > 0) & (r_detached > 1.0 + eps_high)
    too_small_neg = (advantages < 0) & (r_detached < 1.0 - eps_low)
    gate = gate.masked_fill(too_large_pos | too_small_neg, 0.0)

    # -------------------------------------------------------------------------
    # 3) Inference discrepancy weights:
    #       rollout_is_weights ≈ π_θ_old / μ_rollout  (from rollout_corr_helper)
    #
    #    This is already:
    #      * computed in log-space as old_log_prob - rollout_log_prob
    #      * truncated, optionally batch-normalized
    #      * detached (no gradient), as required by IS theory
    #
    #    If not provided, we set it to 1 and get pure on-policy MiniRL.
    # -------------------------------------------------------------------------
    if rollout_is_weights is None:
        inference_discrepancy = torch.ones_like(advantages)
    else:
        # Mask is already applied inside rollout_corr_helper, but we re-mask to be safe.
        inference_discrepancy = rollout_is_weights * response_mask

    # -------------------------------------------------------------------------
    # 4) Total IS weight:
    #       w_t(θ) ≈ (π_θ / π_θ_old) ⋅ (π_θ_old / μ_rollout)
    #               = staleness_ratio ⋅ inference_discrepancy
    #
    #    This matches Eq. (5) in the paper while reusing rollout_is_weights,
    #    so the loss never needs rollout_log_prob directly.
    # -------------------------------------------------------------------------
    total_is_weight = staleness_ratio * inference_discrepancy

    # -------------------------------------------------------------------------
    # 5) MiniRL objective (standard IS policy gradient):
    #       J ≈ E[ M_t ⋅ w_t(θ) ⋅ Â(x,y) ]
    #    so loss is:
    #       L = - M_t ⋅ w_t(θ) ⋅ Â(x,y)
    #
    #    Gradients:
    #       ∇_θ L = - E[ M_t ⋅ Â ⋅ ∇_θ w_t(θ) ]
    #             = - E[ M_t ⋅ Â ⋅ w_t(θ) ⋅ ∇_θ log π_θ ]
    #
    #    where:
    #       - M_t only gates tokens (built from r_detached)
    #       - w_t(θ) combines staleness + rollout mismatch and has gradients
    # -------------------------------------------------------------------------
    pg_losses = -gate * total_is_weight * advantages

Author


According to the attached paper, this loss function is more of a REINFORCE-like objective, right? The ratio is then only used for a decoupled clip.
QWEN-lessons.pdf

Comment on lines +8 to +14
MODEL_PATH=""

DATA_DIR=""
CKPTS_DIR=""

GPUS_PER_NODE=8
NNODES=1
Contributor


high

The script initializes MODEL_PATH, DATA_DIR, and CKPTS_DIR to empty strings. If a user runs the script without setting these variables externally, they will be passed as empty strings to the python trainer, which will likely cause a hard-to-debug error deep inside the training code. It's better to fail early with a clear error message. I suggest adding checks to ensure these required variables are set to non-empty values.

Suggested change

MODEL_PATH=""
DATA_DIR=""
CKPTS_DIR=""

if [[ -z "${MODEL_PATH}" || -z "${DATA_DIR}" || -z "${CKPTS_DIR}" ]]; then
    echo "Error: MODEL_PATH, DATA_DIR, and CKPTS_DIR must be set to non-empty values." >&2
    exit 1
fi

GPUS_PER_NODE=8
NNODES=1

@shamanez
Author

Some convergence results.

[Four screenshots attached, 2025-12-17, 11:39 am to 11:43 am]

@eric-haibin-lin @vermouth1992

@shamanez shamanez changed the title from "Add MiniRL Support" to "[WIP] Add MiniRL Support" on Dec 17, 2025
@szrlee
Collaborator

szrlee commented Dec 17, 2025

@shamanez please take a deep look at the rollout correction module in VeRL. https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html
You may reuse this module.

@shamanez
Author

@szrlee Hello Yingru, I did use the rollout correction module, but according to this paper we also need to compute the staleness ratio, i.e. the ratio between pi_theta (the policy being updated during the PPO mini-batch) and pi_old.

That is why I compute it inside the loss function.
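
To make that concrete, a minimal sketch (tensor names assumed) of how the per-step staleness ratio combines with the batch-level rollout correction weight:

    import torch

    def minirl_total_is_weight(log_prob, old_log_prob, rollout_is_weights):
        # pi_theta / pi_old: recomputed on every PPO mini-batch update,
        # clamped in log-space for numerical stability before exponentiating.
        staleness_ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20.0, 20.0))
        # rollout_is_weights ~ pi_old / mu_rollout: computed once per batch by the
        # rollout correction helper and passed in detached.
        return staleness_ratio * rollout_is_weights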

@shamanez
Author

@szrlee I actually reused the rollout correction module and explained the logic inside the loss function. Please let me know if there's anything missing.

@szrlee
Collaborator

szrlee commented Dec 18, 2025

@szrlee I actually reused the rollout correction module and explained the logic inside the loss function. Please let me know if there's anything missing.

@shamanez what is the difference from decoupled_token_is() in https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html#method-summary-table?

That is, using token_is for \pi_{rollout} -> \pi_{old} and ppo_clip for \pi_{old} -> \pi_{new}?

There is no need to implement another loss for MiniRL to account for \pi_{old} -> \pi_{new}, as that part is essentially the ppo_clip loss. You can check the math definition, the code, or Qwen's paper. It is just rebranding.
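
For readers following the thread, here is a rough conceptual sketch of that decoupled combination (a detached token-level IS weight for \pi_{rollout} -> \pi_{old} plus a standard PPO-clip surrogate for \pi_{old} -> \pi_{new}); it is illustrative only, not verl's actual implementation:

    import torch

    def decoupled_token_is_ppo_clip_sketch(log_prob, old_log_prob, rollout_log_prob,
                                            advantages, clip_eps=0.2):
        # Inference/training discrepancy weight pi_old / mu_rollout, no gradient.
        token_is = torch.exp(old_log_prob - rollout_log_prob).detach()
        # Staleness ratio pi_new / pi_old, carries gradient through log_prob.
        ratio = torch.exp(log_prob - old_log_prob)
        # Standard PPO clipped surrogate for pi_old -> pi_new.
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        clipped = torch.minimum(surr1, surr2)
        # Re-weight each token by the detached rollout-correction term; negate for a loss.
        return -(token_is * clipped)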

@szrlee szrlee self-assigned this Dec 18, 2025
@shamanez
Author

shamanez commented Dec 18, 2025

@szrlee OK, I understand: the Bypass Mode Presets (PPO-clip) do almost the same thing. Does that mean the loss function compute_policy_loss_with_rollout_correction in core_algos.py implements this logic?

But could you please help me understand whether the current PPO-clip loss also addresses the following constraints?

Please correct me if I am wrong, but according to their logic, the importance sampling weight depends on two ratios, as follows.

[image: importance sampling weight equation from the paper]
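
For reference, the decomposition under discussion, written out from the comments in the diff above (a reconstruction, not a verbatim copy of the paper's equation):

    w_t(\theta) \;=\; \underbrace{\frac{\pi_\theta(y_t \mid x, y_{<t})}{\pi_{\theta_{\mathrm{old}}}(y_t \mid x, y_{<t})}}_{\text{staleness ratio, recomputed per mini-batch step}} \;\cdot\; \underbrace{\frac{\pi_{\theta_{\mathrm{old}}}(y_t \mid x, y_{<t})}{\mu_{\mathrm{rollout}}(y_t \mid x, y_{<t})}}_{\text{inference discrepancy, fixed per batch and detached}}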

If I understand it correctly, pi_theta here is the policy that gets updated inside the PPO mini-batch loop, whereas the rollout correction module only computes the training-inference discrepancy once per batch.
They also use decoupled PPO.
[image: decoupled PPO objective from the paper]

The reason I thought of adding it as a new loss is that, as you can see in their experiments, pi_theta is the policy that is being updated in the mini-batch.

[image: excerpt from the paper's experiments]

@szrlee
Collaborator

szrlee commented Dec 19, 2025

@szrlee OK, I understand: the Bypass Mode Presets (PPO-clip) do almost the same thing. Does that mean the loss function compute_policy_loss_with_rollout_correction in core_algos.py implements this logic? [...]

@shamanez Please look at the preset methods for decoupled_token_is() in algorithm.py and see how it works. Thank you!

We already address both the 3-policy case (\pi_{rollout} -> \pi_{old}, \pi_{old} -> \pi_{new}) and the 2-policy case (\pi_{rollout} -> \pi_{new}) in the rollout correction module.

Qwen's paper also cites our work.
