Decouple IS Weights from Rejection Sampling in MIS #657
Conversation
great catch
```diff
@@ -1,5 +1,5 @@
 # Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py
-use_mis: false
+use_tis: true
```
shall we change the mis -> tis here?
It just seems `use_mis` is not used anywhere; maybe we can delete it?
```python
)

pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
```
If not use_tis, then pg_loss relies on the passed-in sum_of_sample_mean. If using tis, then we create a new sum_of_sample_mean with modified_response_masks by:

```python
sum_of_sample_mean = get_sum_of_sample_mean(
    total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
)
```

Right?
Yes, if we don't use TIS, then we do not update this sum_of_sample_mean function, which was originally created from loss_mask.
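The branching described above can be sketched as follows. This is a simplified stand-in: slime's real get_sum_of_sample_mean takes total_lengths, response_lengths, the mask, and args.calculate_per_token_loss, while the reducer below is an assumption kept to the masked-mean behavior discussed in the thread.

```python
def get_sum_of_sample_mean(mask):
    """Hypothetical stand-in for slime's helper: builds a reducer that
    averages per-token values over the positions kept by `mask`."""
    def reduce(values):
        kept = sum(v * m for v, m in zip(values, mask))
        return kept / max(sum(mask), 1)
    return reduce

use_tis = True
loss_mask = [1, 1, 1, 1, 1]                # reducer originally built from this
modified_response_masks = [1, 0, 0, 0, 1]  # rejection info from the TIS path

sum_of_sample_mean = get_sum_of_sample_mean(loss_mask)
if use_tis:
    # TIS enabled: rebuild the reducer so rejected tokens also drop out
    # of the denominator, not just the numerator.
    sum_of_sample_mean = get_sum_of_sample_mean(modified_response_masks)

pg_loss = sum_of_sample_mean([1.11, 0.22, 2.23, 0.00005, 1.35])
```

With TIS off, the same call would average over all five positions instead of the two surviving ones.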
yitianlian left a comment
I think adding a more scalable function for different TIS methods in Slime would be a great feature. First, improve the MIS example; second, introduce a new return value in the TIS function, since some methods may mask more tokens.
@yueming-yuan please have a look at verl-project/verl#3984. In addition to verl-project/verl#3915, we have now fully separated rejection sampling masks from importance weights, allowing them to be combined independently.
Thanks!! We'll check this new version and look into integration.
Nice examples:

```python
# Input
log_ratios = [0.1, -1.5, 0.8, -10.0, 0.3]
ratios = [1.11, 0.22, 2.23, 0.00005, 1.35]  # tokens 1, 2, 3 rejected

# Previous implementation
is_weights = [1.11, 0.0, 0.0, 0.0, 1.35]
loss_mask = [1, 1, 1, 1, 1]
pg_loss = sum(loss * is_weights) / sum(loss_mask)
        = (1.11 + 0 + 0 + 0 + 1.35) / 5 = 0.49  # ⚠️ denominator includes masked entries -> smaller loss norm

# New implementation
is_weights = [1.11, 0.22, 2.23, 0.00005, 1.35]  # weights preserved
modified_mask = [1, 0, 0, 0, 1]                 # rejection handled separately
pg_loss = sum(loss * is_weights * modified_mask) / sum(modified_mask)
        = (1.11 + 0 + 0 + 0 + 1.35) / 2 = 1.23  # ✅ denominator excludes masked entries
```
Summary
Refactors the Masked Importance Sampling (MIS) implementation to properly separate IS weight correction from rejection sampling. This fixes a critical gradient normalization bug where rejected tokens were incorrectly included in the loss denominator.
Motivation
The previous implementation conflated two distinct mechanisms, IS weight correction and rejection sampling, by zeroing IS weights at rejected positions.
This led to an issue: rejected tokens had zero weights (numerator) but were still counted in the loss denominator, causing incorrect gradient scaling.
Example
Consider a sequence with 5 tokens where 3 are rejected by mask mode (ratios outside [0.5, 2.0]), as in the worked example above.
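A minimal sketch of how such a rejection mask can be derived from per-token log-ratios. The [0.5, 2.0] bounds and the log-ratios come from the example; the list-comprehension form is illustrative, not slime's actual code.

```python
import math

# Per-token importance log-ratios from the example
log_ratios = [0.1, -1.5, 0.8, -10.0, 0.3]
ratios = [math.exp(lr) for lr in log_ratios]  # ≈ [1.11, 0.22, 2.23, 0.00005, 1.35]

# mask mode: reject tokens whose ratio falls outside [lower, upper]
lower, upper = 0.5, 2.0
modified_mask = [1 if lower <= r <= upper else 0 for r in ratios]
# tokens 1, 2, 3 rejected -> mask [1, 0, 0, 0, 1]
```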
Main Changes
1. API Change
compute_mis_weights_with_cp() now returns 3 values instead of 2.
2. Separation of IS Weights and Rejection Sampling
IS Weights (is_weights):
- truncate: Upper clamped to mis_upper_bound
- mask: Safety-bounded only (no threshold clamping)
- clip: Clamped to [lower, upper]
Rejection Sampling (modified_response_masks):
- mask mode: Excludes tokens with IS ratios outside [lower, upper]
- veto: Excludes entire sequences with catastrophic tokens (ratio < veto_threshold)
3. Correct Loss Normalization
Files Changed
Modified Files
- slime/backends/megatron_utils/loss.py: Updated to correct loss normalization with modified masks
- examples/train_infer_mismatch_helper/mis.py: Refactored to return a 3-tuple, separating IS weights from rejection
Impact by Mode
- truncate mode: No behavioral change
- mask mode: Gradient scale will change (increase) when rejection rate > 0
- clip mode: No behavioral change
- veto: Gradient scale will change for affected sequences