@@ -101,12 +101,15 @@ def mask(
101101 metrics : Dict [str , list [torch .Tensor ]],
102102 lower_bound : float ,
103103 upper_bound : float ,
104- ) -> torch .Tensor :
104+ ) -> Tuple [ torch .Tensor , torch . Tensor ] :
105105 assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
106106 metrics_append (metrics , "mask_fraction_low" , (weights < lower_bound ).int ())
107107 metrics_append (metrics , "mask_fraction_high" , (weights > upper_bound ).int ())
108- mask = (weights >= lower_bound ) & (weights <= upper_bound )
109- return weights * mask * loss_mask
108+ in_range = (weights >= lower_bound ) & (weights <= upper_bound )
109+ modified_mask = loss_mask * in_range .float ()
110+ # Zero out weights where loss_mask is 0; out-of-range weights keep their values and are rejected via modified_mask instead
111+ weights = weights * loss_mask
112+ return weights , modified_mask
110113
111114
112115def compute_mis_weights (
@@ -115,7 +118,7 @@ def compute_mis_weights(
115118 train_log_probs : list [torch .Tensor ],
116119 rollout_log_probs : list [torch .Tensor ],
117120 loss_masks : list [torch .Tensor ],
118- ) -> Tuple [list [torch .Tensor ], Dict [str , list [torch .Tensor ]]]:
121+ ) -> Tuple [list [torch .Tensor ], list [ torch . Tensor ], Dict [str , list [torch .Tensor ]]]:
119122 """
120123 Compute the importance sampling (IS) weights and metrics between the inference and training engine.
121124 Args:
@@ -126,7 +129,8 @@ def compute_mis_weights(
126129 For multi-turn RL, the tool response will be marked as 0 in the loss_mask.
127130
128131 Returns:
129- weights: List of importance sampling weights. 1D tensor each.
132+ weights: List of importance sampling weights (safety-bounded; zeroed at padding only). 1D tensor each.
133+ modified_response_masks: List of rejection masks to apply in aggregation (mask mode + veto). 1D tensor each.
130134 metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
131135 """
132136
@@ -148,6 +152,7 @@ def compute_mis_weights(
148152
149153 SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow
150154 all_weights = []
155+ all_modified_masks = []
151156
152157 # handle each sequence independently
153158 for train_log_prob , rollout_log_prob , loss_mask in zip (train_log_probs , rollout_log_probs , loss_masks ):
@@ -172,19 +177,17 @@ def compute_mis_weights(
172177 weights = torch .exp (log_ratio_safe )
173178 metrics_append (metrics , "mean_is_weight_before_clip" , weights )
174179
175- # mask out catastrophic tokens
176- if args .mis_veto_threshold is not None :
177- veto_mask = calculate_veto_mask (raw_log_ratio_diff , loss_mask , args .mis_veto_threshold , metrics )
180+ modified_mask = loss_mask .clone ().float ()
178181
179182 # mode: how to handle the importance sampling weights exceeding the thresholds.
180183 if args .mis_mode == "truncate" :
181184 # Cap the importance sampling weights at the upper threshold
182185 # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33
183186 weights = truncate (weights , loss_mask , metrics , args .mis_upper_bound )
184187 elif args .mis_mode == "mask" :
185- # Zero the importance sampling weights outside the [lower, upper] range.
188+ # Preserve safety-bounded weights; apply thresholds via modified_mask
186189 # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
187- weights = mask (
190+ weights , modified_mask = mask (
188191 weights ,
189192 loss_mask ,
190193 metrics ,
@@ -204,15 +207,20 @@ def compute_mis_weights(
204207 else :
205208 raise ValueError (f"Unsupported mis_mode: { args .mis_mode } " )
206209
207- metrics_append (metrics , "ratio_mean_after_mis" , weights )
210+ # Veto on raw per-token ratios (sequence-wise rejection)
211+ # Works independently of truncate/mask mode and does NOT modify IS weights
208212 if args .mis_veto_threshold is not None :
209- weights = weights * veto_mask
210- metrics_append (metrics , "ratio_mean_after_veto_mask" , weights )
213+ veto_mask = calculate_veto_mask (raw_log_ratio_diff , loss_mask , args .mis_veto_threshold , metrics )
214+ modified_mask = modified_mask * veto_mask
215+
216+ metrics_append (metrics , "ratio_mean_after_mis" , weights )
211217
212218 weights = weights .detach ()
219+ modified_mask = modified_mask .detach ()
213220 all_weights .append (weights )
221+ all_modified_masks .append (modified_mask )
214222
215- return all_weights , metrics
223+ return all_weights , all_modified_masks , metrics
216224
217225
218226def compute_mis_weights_with_cp (
@@ -225,7 +233,7 @@ def compute_mis_weights_with_cp(
225233 total_lengths : list [int ],
226234 response_lengths : list [int ],
227235 ** kwargs : Any ,
228- ) -> Tuple [torch .Tensor , Dict [str , torch .Tensor ]]:
236+ ) -> Tuple [torch .Tensor , list [ torch . Tensor ], Dict [str , torch .Tensor ]]:
229237 """
230238 Compute the importance sampling (IS) weights and metrics with context parallel.
231239 Args:
@@ -235,9 +243,9 @@ def compute_mis_weights_with_cp(
235243 total_lengths: List of total lengths.
236244 response_lengths: List of response lengths.
237245 Returns:
238- is_weights: Importance sampling weights on this CP rank and flattened along dim=0.
239- is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each .
240- Also flattened along dim=0 .
246+ pg_loss: Policy gradient loss with IS weights applied ( flattened along dim=0) .
247+ modified_masks: List of modified response masks with rejection applied (one per sequence; returned as produced by compute_mis_weights, i.e. not CP-sliced, unlike the IS weights and metrics).
248+ is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors .
241249 """
242250 # Gather cp slice from other cp ranks
243251 full_rollout_log_probs = [
@@ -249,8 +257,8 @@ def compute_mis_weights_with_cp(
249257 for old_log_prob , total_length , response_length in zip (train_log_probs , total_lengths , response_lengths )
250258 ]
251259
252- # Main logic for is
253- is_weights , is_metrics = compute_mis_weights (
260+ # Main IS logic (decoupled): IS weights and rejection masks are returned separately
261+ is_weights , modified_masks , is_metrics = compute_mis_weights (
254262 args = args ,
255263 train_log_probs = full_old_log_probs ,
256264 rollout_log_probs = full_rollout_log_probs ,
@@ -270,14 +278,15 @@ def slice_cp_and_concat(
270278
271279 result_metrics = {}
272280 is_weights = slice_cp_and_concat (is_weights , total_lengths , response_lengths )
281+
273282 for key , values in is_metrics .items ():
274283 key_name = f"mis_{ key } "
275284 values = slice_cp_and_concat (values , total_lengths , response_lengths )
276285 result_metrics [key_name ] = values
277286
278287 pg_loss = pg_loss * is_weights
279288
280- return pg_loss , result_metrics
289+ return pg_loss , modified_masks , result_metrics
281290
282291
283292def add_ppl_metrics (
0 commit comments