
Decouple IS Weights from Rejection Sampling in MIS#657

Merged
zhuzilin merged 5 commits into THUDM:main from yueming-yuan:is_rejection
Nov 2, 2025

Conversation

@yueming-yuan
Collaborator

References

  1. This refactoring follows the design of verl#3915. Thanks for the insights!
  2. Thanks to the great contribution of this paper When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

Summary

Refactors the Masked Importance Sampling (MIS) implementation to properly separate IS weight correction from rejection sampling. This fixes a critical gradient normalization bug where rejected tokens were incorrectly included in the loss denominator.

Motivation

The previous implementation conflated two distinct mechanisms by zeroing IS weights at rejected positions:

  1. IS weight correction: Applies π_train/π_rollout ratios to correct for distribution mismatch
  2. Rejection sampling: Excludes outlier samples from training

This led to an issue: rejected tokens had zero weights (numerator) but were still counted in the loss denominator, causing incorrect gradient scaling.

Example

Consider a sequence with 5 tokens where 3 are rejected by mask mode (ratios outside [0.5, 2.0]):

# Input
log_ratios = [0.1, -1.5, 0.8, -10.0, 0.3]
ratios     = [1.11, 0.22, 2.23, 0.00005, 1.35]  # tokens 1,2,3 rejected

# previous implementation
is_weights = [1.11, 0.0, 0.0, 0.0, 1.35]
loss_mask  = [1, 1, 1, 1, 1] 
pg_loss = sum(loss * is_weights) / sum(loss_mask)
            = (1.11 + 0 + 0 + 0 + 1.35) / 5 = 0.49  # ⚠️ denominator includes masked entries -> Smaller loss norm

# NEW Implementation
is_weights       = [1.11, 0.22, 2.23, 0.00005, 1.35]  # weights preserved
modified_mask    = [1, 0, 0, 0, 1]                      # rejection separate
pg_loss = sum(loss * is_weights * modified_mask) / sum(modified_mask)
            = (1.11 + 0 + 0 + 0 + 1.35) / 2 = 1.23   # ✅ denominator excludes masked entries
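The arithmetic above can be checked with a small standalone sketch. This is plain Python, not the actual mis.py code; the per-token loss is taken as 1.0 for illustration, matching the sums in the example:

```python
import math

log_ratios = [0.1, -1.5, 0.8, -10.0, 0.3]
ratios = [math.exp(lr) for lr in log_ratios]
lower, upper = 0.5, 2.0
loss = [1.0] * len(ratios)  # per-token loss taken as 1.0 for illustration

# Previous behavior: zero the IS weight at rejected positions, but keep
# every token in the denominator.
old_weights = [r if lower <= r <= upper else 0.0 for r in ratios]
old_mask = [1] * len(ratios)
old_loss = sum(l * w for l, w in zip(loss, old_weights)) / sum(old_mask)

# New behavior: preserve the IS weights and move rejection into the mask,
# so rejected tokens leave the denominator as well.
new_mask = [1 if lower <= r <= upper else 0 for r in ratios]
new_loss = sum(l * r * m for l, r, m in zip(loss, ratios, new_mask)) / sum(new_mask)

print(round(old_loss, 2), round(new_loss, 2))  # 0.49 1.23
```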

Main Changes

1. API Change

compute_mis_weights_with_cp() now returns 3 values instead of 2:

# Before
pg_loss, metrics = compute_mis_weights_with_cp(...)

# After
pg_loss, modified_response_masks, metrics = compute_mis_weights_with_cp(...)

2. Separation of IS Weights and Rejection Sampling

IS Weights (is_weights):

  • Safety-bounded to [exp(-20), exp(20)] to prevent overflow
  • Mode-specific processing:
    • truncate: Upper clamped to mis_upper_bound
    • mask: Safety-bounded only (no threshold clamping)
    • clip: Clamped to [lower, upper]
  • Zeroed at padding positions only
  • Used for weighting policy gradient

Rejection Sampling (modified_response_masks):

  • mask mode: Excludes tokens with IS ratios outside [lower, upper]
  • veto: Excludes entire sequences with catastrophic tokens (ratio < veto_threshold)
  • Used for loss aggregation denominator
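The separation described by the two bullet lists above can be sketched as a single helper. This is a hypothetical function (the name `split_is_and_rejection`, the `SAFETY_BOUND` constant, and the signature are made up for illustration, not taken from mis.py), but its behavior follows the bullets:

```python
import math

SAFETY_BOUND = 20.0  # log-ratios clamped so ratios stay in [exp(-20), exp(20)]

def split_is_and_rejection(log_ratios, loss_mask, mode, lower, upper, veto_threshold=None):
    """Hypothetical helper: return per-token (is_weights, modified_mask)."""
    clamped = [max(-SAFETY_BOUND, min(SAFETY_BOUND, lr)) for lr in log_ratios]
    ratios = [math.exp(lr) for lr in clamped]

    if mode == "truncate":
        weights = [min(r, upper) for r in ratios]   # upper clamp only
        mask = list(loss_mask)                      # no rejection
    elif mode == "clip":
        weights = [max(lower, min(r, upper)) for r in ratios]
        mask = list(loss_mask)
    elif mode == "mask":
        weights = ratios                            # weights preserved
        mask = [m if lower <= r <= upper else 0
               for m, r in zip(loss_mask, ratios)]  # rejection lives in the mask
    else:
        raise ValueError(mode)

    # Veto: exclude the entire sequence if any token ratio is catastrophic.
    if veto_threshold is not None and any(r < veto_threshold for r in ratios):
        mask = [0] * len(mask)

    # Zero weights at padding positions only.
    weights = [w * m for w, m in zip(weights, loss_mask)]
    return weights, mask
```

With the example log-ratios and `mode="mask"`, the returned mask is `[1, 0, 0, 0, 1]` while the weights keep their raw values.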

3. Correct Loss Normalization

# In loss.py (Line 463-470)
pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs)

# Rebuild sum_of_sample_mean with modified masks for correct denominator
sum_of_sample_mean = get_sum_of_sample_mean(
    total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
)

pg_loss = sum_of_sample_mean(pg_loss)  # Now uses correct denominator
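The key point is that the normalizer closure is rebuilt from the modified masks. A simplified, hypothetical version of what `get_sum_of_sample_mean` might look like (the real loss.py helper also takes total and response lengths; this sketch keeps only the mask-driven denominator logic):

```python
def get_sum_of_sample_mean(response_masks, per_token_loss=True):
    """Hypothetical, simplified normalizer builder for illustration."""
    if per_token_loss:
        denom = sum(sum(m) for m in response_masks)  # total active tokens
        return lambda losses: sum(sum(l) for l in losses) / max(denom, 1)

    def per_sample(losses):
        # Mean over each sample's active tokens, then mean over samples.
        means = [sum(l) / max(sum(m), 1) for l, m in zip(losses, response_masks)]
        return sum(means) / len(means)

    return per_sample
```

Building this closure from `modified_response_masks` instead of the original loss mask is exactly what removes rejected tokens from the denominator.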

Files Changed

Modified Files

  • slime/backends/megatron_utils/loss.py: Updated to correct loss normalization with modified masks
  • examples/train_infer_mismatch_helper/mis.py: Refactored to return 3-tuple, separated IS weights from rejection

Impact by Mode

  • truncate mode: No behavioral change
  • mask mode: Gradient scale will change (increase) when rejection rate > 0
  • clip mode: No behavioral change
  • With veto: Gradient scale will change for affected sequences
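For mask mode, the size of the change is predictable: with per-token loss, the denominator shrinks from the full token count to the accepted token count, so the loss (and hence the gradient) grows by roughly the ratio of the two. A quick sketch with a hypothetical helper name:

```python
def gradient_scale_factor(loss_mask, modified_mask):
    """Factor by which mask-mode loss normalization grows after this change
    (per-token loss; hypothetical helper for illustration)."""
    return sum(loss_mask) / sum(modified_mask)

# The 5-token example above, with 3 of 5 tokens rejected:
print(gradient_scale_factor([1, 1, 1, 1, 1], [1, 0, 0, 0, 1]))  # 2.5
```

This matches the worked example: 0.49 × 2.5 ≈ 1.23.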

@zhaochenyang20
Collaborator

great catch

@yueming-yuan yueming-yuan marked this pull request as draft October 31, 2025 18:06
@@ -1,5 +1,5 @@
# Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py
use_mis: false
use_tis: true
Collaborator

shall we change the mis -> tis here?

Collaborator Author

It just seems "use_mis" is not used anywhere - we may delete it?

)

pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
Collaborator

If not use_tis, then pg_loss would rely on the passed-in sum_of_sample_mean. If using tis, then we will create a new sum_of_sample_mean with modified_response_masks by:

        sum_of_sample_mean = get_sum_of_sample_mean(
            total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
        )

Right?

Collaborator Author

yes, if we don't use TIS then we do not update this sum_of_sample_mean function, which was originally created from loss_mask

Collaborator

@yitianlian yitianlian left a comment

I think adding a more scalable interface for different TIS methods in slime would be a great feature: first, improve the MIS example, and second, introduce the new return value in the TIS function, since some methods may mask more tokens.

@yueming-yuan yueming-yuan marked this pull request as ready for review November 1, 2025 16:37
@zhuzilin zhuzilin merged commit ad2ada3 into THUDM:main Nov 2, 2025
4 checks passed
@szrlee

szrlee commented Nov 2, 2025

@yueming-yuan please have a look on verl-project/verl#3984

Additional to verl-project/verl#3915, we have now fully separated rejection sampling masks from importance weights, allowing them to be combined independently.

@yueming-yuan
Collaborator Author

@yueming-yuan please have a look on volcengine/verl#3984

Additional to volcengine/verl#3915, we have now fully separated rejection sampling masks from importance weights, allowing them to be combined independently.

Thanks!! We'll check this new version and look into integration.

@zhaochenyang20
Collaborator

Nice examples:


llltttwww pushed a commit to llltttwww/slime that referenced this pull request Nov 30, 2025
Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026
