@@ -145,7 +145,7 @@ def compute_mis_weights(
145145 len (train_log_probs ) == len (rollout_log_probs ) == len (loss_masks )
146146 ), f"Input lists must have the same number of sequences: { len (train_log_probs )} vs { len (rollout_log_probs )} vs { len (loss_masks )} "
147147
148- for i , (train , rollout , loss_mask ) in enumerate (zip (train_log_probs , rollout_log_probs , loss_masks )):
148+ for i , (train , rollout , loss_mask ) in enumerate (zip (train_log_probs , rollout_log_probs , loss_masks , strict = False )):
149149 assert (
150150 train .shape == rollout .shape == loss_mask .shape
151151 ), f"Sequence { i } : shapes must match - train: { train .shape } , rollout: { rollout .shape } , loss_mask: { loss_mask .shape } "
@@ -164,15 +164,19 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
164164 else :
165165 raise ValueError (f"Invalid level: { level } " )
166166
167- for train_log_prob , rollout_log_prob , loss_mask in zip (train_log_probs , rollout_log_probs , loss_masks ):
167+ for train_log_prob , rollout_log_prob , loss_mask in zip (
168+ train_log_probs , rollout_log_probs , loss_masks , strict = False
169+ ):
168170 add_ppl_metrics (train_log_prob , rollout_log_prob , loss_mask , metrics )
169171
170172 # only calculate mismatch metrics if TIS is not used
171173 if not args .use_tis :
172174 return None , loss_masks , metrics
173175
174176 # handle each sequence independently
175- for train_log_prob , rollout_log_prob , loss_mask in zip (train_log_probs , rollout_log_probs , loss_masks ):
177+ for train_log_prob , rollout_log_prob , loss_mask in zip (
178+ train_log_probs , rollout_log_probs , loss_masks , strict = False
179+ ):
176180 loss_mask = loss_mask .float ()
177181 raw_log_ratio_diff = train_log_prob - rollout_log_prob
178182 modified_mask = loss_mask .clone ().float ()
@@ -228,14 +232,14 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
228232 tis_level = args .tis_level if args .use_tis else "token"
229233 if tis_level == "token" :
230234 # Token-level: normalize over all token weights
231- total_weights_sum = sum (masked_sum (w , m ) for w , m in zip (all_weights , loss_masks ))
235+ total_weights_sum = sum (masked_sum (w , m ) for w , m in zip (all_weights , loss_masks , strict = False ))
232236 total_mask_count = sum (m .sum () for m in loss_masks )
233237 weights_mean = total_weights_sum / torch .clamp_min (total_mask_count , 1 )
234238 elif tis_level == "sequence" :
235239 # Sequence-level: normalize over sequence weights (one weight per sequence)
236240 # For each sequence, compute mean over valid tokens (they all have the same weight)
237241 # then average across sequences
238- seq_weights_means = [masked_mean (w , m ) for w , m in zip (all_weights , loss_masks )]
242+ seq_weights_means = [masked_mean (w , m ) for w , m in zip (all_weights , loss_masks , strict = False )]
239243 weights_mean = sum (seq_weights_means ) / len (seq_weights_means )
240244 else :
241245 raise ValueError (f"Unsupported tis_level: { tis_level } " )
@@ -279,11 +283,15 @@ def compute_mis_weights_with_cp(
279283 # Gather cp slice from other cp ranks
280284 full_rollout_log_probs = [
281285 all_gather_with_cp (log_prob , total_length , response_length )
282- for log_prob , total_length , response_length in zip (rollout_log_probs , total_lengths , response_lengths )
286+ for log_prob , total_length , response_length in zip (
287+ rollout_log_probs , total_lengths , response_lengths , strict = False
288+ )
283289 ]
284290 full_old_log_probs = [
285291 all_gather_with_cp (old_log_prob , total_length , response_length )
286- for old_log_prob , total_length , response_length in zip (train_log_probs , total_lengths , response_lengths )
292+ for old_log_prob , total_length , response_length in zip (
293+ train_log_probs , total_lengths , response_lengths , strict = False
294+ )
287295 ]
288296
289297 # Main logic for is (decoupled)
0 commit comments