1- import re
21from typing import Any , Dict , Optional , Tuple
32
43import torch
54
65
def assert_tis_input_format(
    full_old_log_probs: torch.Tensor,
    full_log_probs: torch.Tensor,
    full_loss_masks: torch.Tensor,
) -> None:
    """Validate the flattened TIS inputs: all three tensors must be 1-D,
    shape-aligned, and the loss mask must follow the expected run layout
    (a leading run of 1s, then alternating runs of 0s and 1s)."""
    tensors = (full_old_log_probs, full_log_probs, full_loss_masks)

    dims_ok = all(t.dim() == 1 for t in tensors)
    assert dims_ok, f"{full_old_log_probs.dim()} vs {full_log_probs.dim()} vs {full_loss_masks.dim()}"

    shapes_ok = (
        full_old_log_probs.shape == full_log_probs.shape
        and full_old_log_probs.shape == full_loss_masks.shape
    )
    assert shapes_ok, f"{full_old_log_probs.shape} vs {full_log_probs.shape} vs {full_loss_masks.shape}"

    # Render the mask as a 0/1 string and match it against the expected run pattern.
    loss_mask_str = "".join(str(int(x)) for x in full_loss_masks)
    assert re.fullmatch(r"^1+(0+1+)*0*1*$", loss_mask_str), "loss_mask format is not expected!"
23-
24-
def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sum of ``x`` along ``dim``, counting only positions where ``mask`` is set.

    Example:
        x = [[1, 2, 3], [4, 5, 6]]
        mask = [[1, 1, 1], [1, 1, 0]]
        masked_sum(x, mask, dim=-1) = [6, 9]
    """
    token_counts = mask.sum(dim=dim)
    assert token_counts.min() > 0, "any sequence must have at least one valid token"
    assert x.shape == mask.shape, "x and mask must have the same shape"
    masked_values = x * mask
    return masked_values.sum(dim=dim)
37-
38-
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Mean of ``x`` along ``dim``, restricted to positions where ``mask`` is set.

    Example:
        x = [[1, 2, 3], [4, 5, 6]]
        mask = [[1, 1, 1], [1, 1, 0]]
        masked_mean(x, mask, dim=-1) = [2, 4.5]
    """
    token_counts = mask.sum(dim=dim)
    assert token_counts.min() > 0, "any sequence must have at least one valid token"
    # Delegate the shape guard and the masked reduction to masked_sum.
    return masked_sum(x, mask, dim=dim) / token_counts
50-
51-
def per_seq_masked_mean(
    x: torch.Tensor,
    mask: torch.Tensor,
    response_lengths: Optional[list[int]] = None,
) -> torch.Tensor:
    """Per-sequence masked mean, summed into one additive scalar (suited to DP reduction).

    Accepts either a flattened tensor together with ``response_lengths``
    (one masked mean per sequence, then summed), or — when lengths are
    absent — treats the whole input as a single sequence.
    """
    if not response_lengths:
        # Fallback: the entire input is one sequence.
        return masked_mean(x, mask).unsqueeze(0).sum()

    lengths = [int(length) for length in response_lengths]
    seq_means = [
        masked_mean(seq_log_ratio, seq_loss_mask)
        for seq_log_ratio, seq_loss_mask in zip(
            torch.split(x, lengths, dim=0), torch.split(mask, lengths, dim=0)
        )
    ]
    return torch.stack(seq_means).sum()
71-
72-
736def compute_tis_weights (
747 * ,
75- old_log_prob : torch .Tensor ,
76- rollout_log_prob : torch .Tensor ,
77- loss_mask : torch .Tensor ,
8+ old_log_prob_flat : torch .Tensor ,
9+ rollout_log_prob_flat : torch .Tensor ,
10+ loss_mask_flat : torch .Tensor ,
7811 level : str = "token" ,
7912 mode : str = "truncate" ,
8013 upper_threshold : Optional [float ] = None ,
8114 lower_threshold : Optional [float ] = None ,
8215 veto_threshold : float = 1e-4 ,
8316 safety_bound : float = 20.0 ,
8417 response_lengths : Optional [list [int ]] = None ,
18+ total_lengths : Optional [list [int ]] = None ,
8519) -> Tuple [Optional [torch .Tensor ], Dict [str , Any ]]:
8620 """
8721 Compute the truncated importance sampling (TIS) weights and metrics.
@@ -92,9 +26,11 @@ def compute_tis_weights(
9226 https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
9327
9428 Args:
95- old_log_prob: Flattened log probs from training backend. Shape: [sum(response_lengths)]
96- rollout_log_prob: Flattened log probs from rollout backend. Shape: [sum(response_lengths)]
97- loss_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)]
29+ old_log_prob_flat: Flattened log probs from training backend. Shape: [sum(response_lengths)]
30+ rollout_log_prob_flat: Flattened log probs from rollout backend. Shape: [sum(response_lengths)]
31+ loss_mask_flat: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)]
32+ Note that for single turn RL, the loss_mask_flat is [1] * sum(response_lengths)
33+ For multi turn RL, the tool response will be marked as 0 in the loss_mask_flat.
9834 level: The aggregation level for the importance sampling weights.
9935 - "token": per-token importance sampling weights, biased low variance.
10036 - "sequence": product over tokens, unbiased but high variance.
@@ -107,43 +43,51 @@ def compute_tis_weights(
10743 If not provided, it will be set to 1.0 / upper_threshold.
10844 veto_threshold: If any token's importance sampling weight is less than this, zero the entire sequence weight.
10945 safety_bound: The safety bound for the log-space ratio to avoid numerical overflow.
46+ response_lengths: The length of the response for each sequence.
47+ total_lengths: The total length of the whole sequence for each sequence.
11048
11149 Returns:
11250 weights: The importance sampling weights. [batch_size, seq_len]
11351 metrics: The metrics for the importance sampling weights.
11452 """
53+
54+ assert all (
55+ tensor .dim () == 1 for tensor in [old_log_prob_flat , rollout_log_prob_flat , loss_mask_flat ]
56+ ), f"{ old_log_prob_flat .dim ()} vs { rollout_log_prob_flat .dim ()} vs { loss_mask_flat .dim ()} "
57+
11558 assert (
116- loss_mask .shape == old_log_prob .shape and loss_mask .shape == rollout_log_prob .shape
117- ), "loss_mask, old_log_prob, and rollout_log_prob must have the same shape"
118- assert response_lengths is not None and len (response_lengths ) > 0 , "response_lengths must be provided"
59+ old_log_prob_flat .shape == rollout_log_prob_flat .shape and old_log_prob_flat .shape == loss_mask_flat .shape
60+ ), f"{ old_log_prob_flat .shape } vs { rollout_log_prob_flat .shape } vs { loss_mask_flat .shape } "
11961
12062 if upper_threshold is None :
12163 return None , {}
12264 if lower_threshold is None :
12365 lower_threshold = 1.0 / upper_threshold
12466
125- device = old_log_prob .device
126- log_ratio = old_log_prob - rollout_log_prob
67+ device = old_log_prob_flat .device
68+ log_ratio = old_log_prob_flat - rollout_log_prob_flat
12769
12870 log_upper_threshold = torch .log (torch .tensor (upper_threshold , device = device ))
12971 log_lower_threshold = torch .log (torch .tensor (lower_threshold , device = device ))
13072
73+ # compute TIS weights without truncation/clipping
74+
13175 if level == "token" :
132- # Token-level IS: π_training(a|s) / π_rollout(a|s) per token
76+ # Token-level IS: π_training(a|s) / π_rollout(a|s) per token
13377 # The truncation will be applied later.
13478 log_ratio_for_metrics = log_ratio # [sum(response_lengths)]
13579 log_ratio_safe = torch .clamp (log_ratio , min = - safety_bound , max = safety_bound )
13680 weights = torch .exp (log_ratio_safe )
13781 elif level in ["sequence" , "geometric" ]:
13882 # Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens
13983 sequence_log_ratios = torch .split (log_ratio , [int (l ) for l in response_lengths ], dim = 0 )
140- sequence_loss_masks = torch .split (loss_mask , [int (l ) for l in response_lengths ], dim = 0 )
84+ sequence_loss_masks = torch .split (loss_mask_flat , [int (l ) for l in response_lengths ], dim = 0 )
14185 per_seq_vals = []
14286 for sequence_log_ratio , sequence_loss_mask in zip (sequence_log_ratios , sequence_loss_masks ):
14387 if level == "sequence" :
14488 val = (sequence_log_ratio * sequence_loss_mask ).sum ()
14589 else : # geometric
146- val = masked_mean (sequence_log_ratio , sequence_loss_mask )
90+ val = (sequence_log_ratio * sequence_loss_mask ). sum () / sequence_loss_mask . sum ( )
14791 per_seq_vals .append (torch .clamp (val , min = - safety_bound , max = safety_bound ))
14892 per_seq_vals = torch .stack (per_seq_vals ) # [num_sequences]
14993 per_seq_weights = torch .exp (per_seq_vals )
@@ -157,10 +101,12 @@ def compute_tis_weights(
157101 else :
158102 raise ValueError (f"Invalid importance sampling level: { level } " )
159103
104+ # TODO:继续 filter out catastrophic tokens
105+
160106 log_veto_threshold = torch .log (torch .tensor (veto_threshold , device = device ))
161107 # Veto sequences with any token's log ratio below the threshold.
162108 # log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold
163- catastrophic_tokens = (log_ratio < log_veto_threshold ) & loss_mask .bool ()
109+ catastrophic_tokens = (log_ratio < log_veto_threshold ) & loss_mask_flat .bool ()
164110 # Build per-sequence veto and expand to tokens
165111 cat_chunks = torch .split (catastrophic_tokens , [int (l ) for l in response_lengths ], dim = 0 )
166112 has_catastrophic_per_seq = torch .tensor ([chunk .any () for chunk in cat_chunks ], device = device )
0 commit comments