Skip to content

Commit ad2ada3

Browse files
authored
Decouple IS Weights from Rejection Sampling in MIS (#657)
1 parent 96895b3 commit ad2ada3

File tree

3 files changed

+42
-26
lines changed

3 files changed

+42
-26
lines changed

examples/train_infer_mismatch_helper/mis.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,15 @@ def mask(
101101
metrics: Dict[str, list[torch.Tensor]],
102102
lower_bound: float,
103103
upper_bound: float,
104-
) -> torch.Tensor:
104+
) -> Tuple[torch.Tensor, torch.Tensor]:
105105
assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
106106
metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int())
107107
metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int())
108-
mask = (weights >= lower_bound) & (weights <= upper_bound)
109-
return weights * mask * loss_mask
108+
in_range = (weights >= lower_bound) & (weights <= upper_bound)
109+
modified_mask = loss_mask * in_range.float()
110+
# Zero out weights where loss_mask is 0, but keep them at out-of-range positions; rejection is applied via modified_mask
111+
weights = weights * loss_mask
112+
return weights, modified_mask
110113

111114

112115
def compute_mis_weights(
@@ -115,7 +118,7 @@ def compute_mis_weights(
115118
train_log_probs: list[torch.Tensor],
116119
rollout_log_probs: list[torch.Tensor],
117120
loss_masks: list[torch.Tensor],
118-
) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]:
121+
) -> Tuple[list[torch.Tensor], list[torch.Tensor], Dict[str, list[torch.Tensor]]]:
119122
"""
120123
Compute the importance sampling (IS) weights and metrics between the inference and training engines.
121124
Args:
@@ -126,7 +129,8 @@ def compute_mis_weights(
126129
For multi-turn RL, the tool response will be marked as 0 in the loss_mask.
127130
128131
Returns:
129-
weights: List of importance sampling weights. 1D tensor each.
132+
weights: List of importance sampling weights (safety-bounded; zeroed at padding only). 1D tensor each.
133+
modified_response_masks: List of rejection masks to apply in aggregation (mask mode + veto). 1D tensor each.
130134
metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
131135
"""
132136

@@ -148,6 +152,7 @@ def compute_mis_weights(
148152

149153
SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow
150154
all_weights = []
155+
all_modified_masks = []
151156

152157
# handle each sequence independently
153158
for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
@@ -172,19 +177,17 @@ def compute_mis_weights(
172177
weights = torch.exp(log_ratio_safe)
173178
metrics_append(metrics, "mean_is_weight_before_clip", weights)
174179

175-
# mask out catastrophic tokens
176-
if args.mis_veto_threshold is not None:
177-
veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics)
180+
modified_mask = loss_mask.clone().float()
178181

179182
# mode: how to handle the importance sampling weights exceeding the thresholds.
180183
if args.mis_mode == "truncate":
181184
# Cap the importance sampling weights at the upper threshold
182185
# https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33
183186
weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound)
184187
elif args.mis_mode == "mask":
185-
# Zero the importance sampling weights outside the [lower, upper] range.
188+
# Preserve safety-bounded weights; apply thresholds via modified_mask
186189
# https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
187-
weights = mask(
190+
weights, modified_mask = mask(
188191
weights,
189192
loss_mask,
190193
metrics,
@@ -204,15 +207,20 @@ def compute_mis_weights(
204207
else:
205208
raise ValueError(f"Unsupported mis_mode: {args.mis_mode}")
206209

207-
metrics_append(metrics, "ratio_mean_after_mis", weights)
210+
# Veto on raw per-token ratios (sequence-wise rejection)
211+
# Works independently of truncate/mask mode and does NOT modify IS weights
208212
if args.mis_veto_threshold is not None:
209-
weights = weights * veto_mask
210-
metrics_append(metrics, "ratio_mean_after_veto_mask", weights)
213+
veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics)
214+
modified_mask = modified_mask * veto_mask
215+
216+
metrics_append(metrics, "ratio_mean_after_mis", weights)
211217

212218
weights = weights.detach()
219+
modified_mask = modified_mask.detach()
213220
all_weights.append(weights)
221+
all_modified_masks.append(modified_mask)
214222

215-
return all_weights, metrics
223+
return all_weights, all_modified_masks, metrics
216224

217225

218226
def compute_mis_weights_with_cp(
@@ -225,7 +233,7 @@ def compute_mis_weights_with_cp(
225233
total_lengths: list[int],
226234
response_lengths: list[int],
227235
**kwargs: Any,
228-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
236+
) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]:
229237
"""
230238
Compute the importance sampling (IS) weights and metrics with context parallel.
231239
Args:
@@ -235,9 +243,9 @@ def compute_mis_weights_with_cp(
235243
total_lengths: List of total lengths.
236244
response_lengths: List of response lengths.
237245
Returns:
238-
is_weights: Importance sampling weights on this CP rank and flattened along dim=0.
239-
is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
240-
Also flattened along dim=0.
246+
pg_loss: Policy gradient loss with IS weights applied (flattened along dim=0).
247+
modified_masks: List of modified response masks with rejection applied (one per sequence).
248+
is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
241249
"""
242250
# Gather cp slice from other cp ranks
243251
full_rollout_log_probs = [
@@ -249,8 +257,8 @@ def compute_mis_weights_with_cp(
249257
for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths)
250258
]
251259

252-
# Main logic for is
253-
is_weights, is_metrics = compute_mis_weights(
260+
# Main IS logic (IS weights are decoupled from the rejection masks)
261+
is_weights, modified_masks, is_metrics = compute_mis_weights(
254262
args=args,
255263
train_log_probs=full_old_log_probs,
256264
rollout_log_probs=full_rollout_log_probs,
@@ -270,14 +278,15 @@ def slice_cp_and_concat(
270278

271279
result_metrics = {}
272280
is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths)
281+
273282
for key, values in is_metrics.items():
274283
key_name = f"mis_{key}"
275284
values = slice_cp_and_concat(values, total_lengths, response_lengths)
276285
result_metrics[key_name] = values
277286

278287
pg_loss = pg_loss * is_weights
279288

280-
return pg_loss, result_metrics
289+
return pg_loss, modified_masks, result_metrics
281290

282291

283292
def add_ppl_metrics(

examples/train_infer_mismatch_helper/mis.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Enable importance sampling; for details, refer to the comments of compute_mis_weights in mis.py
2-
use_mis: false
2+
use_tis: true
33

44
# Aggregation level for importance sampling weights:
55
# token: per-token
@@ -11,7 +11,7 @@ mis_level: "token"
1111
# truncate: cap to upper bound, TIS
1212
# mask: zero outside [lower, upper], MIS
1313
# clip: clip to [lower, upper], CIS
14-
mis_mode: "truncate"
14+
mis_mode: "mask"
1515

1616
# For mask or clip mode, the lower bound of the IS weights.
1717
# For truncate mode, it will not be used.

slime/backends/megatron_utils/loss.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,9 @@ def vanilla_tis_function(
427427
pg_loss: torch.Tensor,
428428
train_log_probs: list[torch.Tensor],
429429
rollout_log_probs: list[torch.Tensor],
430+
loss_masks: list[torch.Tensor],
430431
**kwargs: Any,
431-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
432+
) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]:
432433
rollout_log_probs = torch.cat(rollout_log_probs, dim=0)
433434
old_log_probs = torch.cat(train_log_probs, dim=0)
434435
tis = torch.exp(old_log_probs - rollout_log_probs)
@@ -441,7 +442,7 @@ def vanilla_tis_function(
441442
"tis_abs": tis_abs.clone().detach(),
442443
}
443444
pg_loss = pg_loss * tis_weights
444-
return pg_loss, metrics
445+
return pg_loss, loss_masks, metrics
445446

446447
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
447448

@@ -460,7 +461,13 @@ def vanilla_tis_function(
460461
tis_func = load_function(args.custom_tis_function_path)
461462
else:
462463
tis_func = vanilla_tis_function
463-
pg_loss, tis_metrics = tis_func(**tis_kwargs)
464+
pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs)
465+
466+
# [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction
467+
# modified_response_masks will be sliced with cp in get_sum_of_sample_mean
468+
sum_of_sample_mean = get_sum_of_sample_mean(
469+
total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
470+
)
464471

465472
pg_loss = sum_of_sample_mean(pg_loss)
466473
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)

0 commit comments

Comments
 (0)