Skip to content

Commit 92e6e97

Browse files
[todo] filter out catastrophic tokens
1 parent 71194c3 commit 92e6e97

File tree

2 files changed

+37
-92
lines changed

2 files changed

+37
-92
lines changed

slime/backends/megatron_utils/loss.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
get_reinforce_plus_plus_baseline_advantages,
1515
get_reinforce_plus_plus_returns,
1616
)
17-
from slime.utils.tis import assert_tis_input_format, compute_tis_weights
17+
from slime.utils.tis import compute_tis_weights
1818

1919
from .cp_utils import (
2020
all_gather_with_cp,
@@ -324,24 +324,23 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
324324
]
325325

326326
# old_log_probs, log_probs, loss_masks are all concated into 1D tensor
327-
full_old_log_probs = torch.cat(full_old_log_probs, dim=0)
328-
full_log_probs = torch.cat(full_log_probs, dim=0)
327+
full_old_log_probs_flat = torch.cat(full_old_log_probs, dim=0)
328+
full_log_probs_flat = torch.cat(full_log_probs, dim=0)
329329
# loss_mask is not sliced by cp, so no need to all_gather
330-
full_loss_masks = torch.cat(batch["loss_masks"], dim=0)
331-
332-
assert_tis_input_format(full_old_log_probs, full_log_probs, full_loss_masks)
330+
full_loss_masks_flat = torch.cat(batch["loss_masks"], dim=0)
333331

334332
tis_weights, tis_metrics = compute_tis_weights(
335-
old_log_prob=full_old_log_probs,
336-
rollout_log_prob=full_log_probs,
337-
loss_mask=full_loss_masks,
333+
old_log_prob_flat=full_old_log_probs_flat,
334+
rollout_log_prob_flat=full_log_probs_flat,
335+
loss_mask_flat=full_loss_masks_flat,
338336
level=getattr(args, "tis_level", "token"),
339337
mode=getattr(args, "tis_mode", "truncate"),
340338
upper_threshold=getattr(args, "tis_threshold_upper", 2.0),
341339
lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, "tis_threshold_upper", 2.0)),
342340
veto_threshold=getattr(args, "tis_veto_threshold", 1e-4),
343341
safety_bound=getattr(args, "tis_safety_bound", 20.0),
344-
response_lengths=total_lengths,
342+
response_lengths=response_lengths,
343+
total_lengths=total_lengths,
345344
)
346345

347346
ois = (-ppo_kl).exp()

slime/utils/tis.py

Lines changed: 28 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -1,87 +1,21 @@
1-
import re
21
from typing import Any, Dict, Optional, Tuple
32

43
import torch
54

65

7-
def assert_tis_input_format(
8-
full_old_log_probs: torch.Tensor,
9-
full_log_probs: torch.Tensor,
10-
full_loss_masks: torch.Tensor,
11-
) -> None:
12-
assert all(
13-
tensor.dim() == 1 for tensor in [full_old_log_probs, full_log_probs, full_loss_masks]
14-
), f"{full_old_log_probs.dim()} vs {full_log_probs.dim()} vs {full_loss_masks.dim()}"
15-
16-
assert (
17-
full_old_log_probs.shape == full_log_probs.shape and full_old_log_probs.shape == full_loss_masks.shape
18-
), f"{full_old_log_probs.shape} vs {full_log_probs.shape} vs {full_loss_masks.shape}"
19-
20-
loss_mask_str = "".join([str(int(x)) for x in full_loss_masks])
21-
pattern = r"^1+(0+1+)*0*1*$"
22-
assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!"
23-
24-
25-
def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
26-
"""
27-
Computes the sum of the tensor x, masked by the mask.
28-
29-
x = [[1, 2, 3], [4, 5, 6]]
30-
mask = [[1, 1, 1], [1, 1, 0]]
31-
masked_sum(x, mask, dim=-1) = [6, 9]
32-
"""
33-
valid_tokens = mask.sum(dim=dim)
34-
assert valid_tokens.min() > 0, "any sequence must have at least one valid token"
35-
assert x.shape == mask.shape, "x and mask must have the same shape"
36-
return (x * mask).sum(dim=dim)
37-
38-
39-
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
40-
"""
41-
Computes the mean of the tensor x, masked by the mask.
42-
43-
x = [[1, 2, 3], [4, 5, 6]]
44-
mask = [[1, 1, 1], [1, 1, 0]]
45-
masked_mean(x, mask, dim=-1) = [2, 4.5]
46-
"""
47-
valid_tokens = mask.sum(dim=dim)
48-
assert valid_tokens.min() > 0, "any sequence must have at least one valid token"
49-
return masked_sum(x, mask, dim=dim) / valid_tokens
50-
51-
52-
def per_seq_masked_mean(
53-
x: torch.Tensor,
54-
mask: torch.Tensor,
55-
response_lengths: Optional[list[int]] = None,
56-
) -> torch.Tensor:
57-
"""
58-
计算按样本的 masked mean 后再求和,返回一个可加性的标量(适配 DP 汇总)。
59-
支持二维 [B, T] 与拍平后一维、并提供 response_lengths 的两种输入形态。
60-
"""
61-
if response_lengths is not None and len(response_lengths) > 0:
62-
sequence_log_ratios = torch.split(x, [int(l) for l in response_lengths], dim=0)
63-
sequence_loss_masks = torch.split(mask, [int(l) for l in response_lengths], dim=0)
64-
seq_means = [
65-
masked_mean(sequence_log_ratio, sequence_loss_mask)
66-
for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks)
67-
]
68-
return torch.stack(seq_means).sum()
69-
# fallback:视为单一样本
70-
return masked_mean(x, mask).unsqueeze(0).sum()
71-
72-
736
def compute_tis_weights(
747
*,
75-
old_log_prob: torch.Tensor,
76-
rollout_log_prob: torch.Tensor,
77-
loss_mask: torch.Tensor,
8+
old_log_prob_flat: torch.Tensor,
9+
rollout_log_prob_flat: torch.Tensor,
10+
loss_mask_flat: torch.Tensor,
7811
level: str = "token",
7912
mode: str = "truncate",
8013
upper_threshold: Optional[float] = None,
8114
lower_threshold: Optional[float] = None,
8215
veto_threshold: float = 1e-4,
8316
safety_bound: float = 20.0,
8417
response_lengths: Optional[list[int]] = None,
18+
total_lengths: Optional[list[int]] = None,
8519
) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]:
8620
"""
8721
Compute the truncated importance sampling (TIS) weights and metrics.
@@ -92,9 +26,11 @@ def compute_tis_weights(
9226
https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
9327
9428
Args:
95-
old_log_prob: Flattened log probs from training backend. Shape: [sum(response_lengths)]
96-
rollout_log_prob: Flattened log probs from rollout backend. Shape: [sum(response_lengths)]
97-
loss_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)]
29+
old_log_prob_flat: Flattened log probs from training backend. Shape: [sum(response_lengths)]
30+
rollout_log_prob_flat: Flattened log probs from rollout backend. Shape: [sum(response_lengths)]
31+
loss_mask_flat: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)]
32+
Note that for single turn RL, the loss_mask_flat is [1] * sum(response_lengths)
33+
For multi turn RL, the tool response will be marked as 0 in the loss_mask_flat.
9834
level: The aggregation level for the importance sampling weights.
9935
- "token": per-token importance sampling weights, biased low variance.
10036
- "sequence": product over tokens, unbiased but high variance.
@@ -107,43 +43,51 @@ def compute_tis_weights(
10743
If not provided, it will be set to 1.0 / upper_threshold.
10844
veto_threshold: If any token's importance sampling weight is less than this, zero the entire sequence weight.
10945
safety_bound: The safety bound for the log-space ratio to avoid numerical overflow.
46+
response_lengths: The length of the response for each sequence.
47+
total_lengths: The total length of the whole sequence for each sequence.
11048
11149
Returns:
11250
weights: The importance sampling weights. [batch_size, seq_len]
11351
metrics: The metrics for the importance sampling weights.
11452
"""
53+
54+
assert all(
55+
tensor.dim() == 1 for tensor in [old_log_prob_flat, rollout_log_prob_flat, loss_mask_flat]
56+
), f"{old_log_prob_flat.dim()} vs {rollout_log_prob_flat.dim()} vs {loss_mask_flat.dim()}"
57+
11558
assert (
116-
loss_mask.shape == old_log_prob.shape and loss_mask.shape == rollout_log_prob.shape
117-
), "loss_mask, old_log_prob, and rollout_log_prob must have the same shape"
118-
assert response_lengths is not None and len(response_lengths) > 0, "response_lengths must be provided"
59+
old_log_prob_flat.shape == rollout_log_prob_flat.shape and old_log_prob_flat.shape == loss_mask_flat.shape
60+
), f"{old_log_prob_flat.shape} vs {rollout_log_prob_flat.shape} vs {loss_mask_flat.shape}"
11961

12062
if upper_threshold is None:
12163
return None, {}
12264
if lower_threshold is None:
12365
lower_threshold = 1.0 / upper_threshold
12466

125-
device = old_log_prob.device
126-
log_ratio = old_log_prob - rollout_log_prob
67+
device = old_log_prob_flat.device
68+
log_ratio = old_log_prob_flat - rollout_log_prob_flat
12769

12870
log_upper_threshold = torch.log(torch.tensor(upper_threshold, device=device))
12971
log_lower_threshold = torch.log(torch.tensor(lower_threshold, device=device))
13072

73+
# compute TIS weights without truncation/clipping
74+
13175
if level == "token":
132-
# Token-level IS: π_training(a|s) / π_rollout(a|s) per token
76+
# Token-level IS: π_training(a|s) / π_rollout(a|s) per token
13377
# The truncation will be applied later.
13478
log_ratio_for_metrics = log_ratio # [sum(response_lengths)]
13579
log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound)
13680
weights = torch.exp(log_ratio_safe)
13781
elif level in ["sequence", "geometric"]:
13882
# Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens
13983
sequence_log_ratios = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0)
140-
sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0)
84+
sequence_loss_masks = torch.split(loss_mask_flat, [int(l) for l in response_lengths], dim=0)
14185
per_seq_vals = []
14286
for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks):
14387
if level == "sequence":
14488
val = (sequence_log_ratio * sequence_loss_mask).sum()
14589
else: # geometric
146-
val = masked_mean(sequence_log_ratio, sequence_loss_mask)
90+
val = (sequence_log_ratio * sequence_loss_mask).sum() / sequence_loss_mask.sum()
14791
per_seq_vals.append(torch.clamp(val, min=-safety_bound, max=safety_bound))
14892
per_seq_vals = torch.stack(per_seq_vals) # [num_sequences]
14993
per_seq_weights = torch.exp(per_seq_vals)
@@ -157,10 +101,12 @@ def compute_tis_weights(
157101
else:
158102
raise ValueError(f"Invalid importance sampling level: {level}")
159103

104+
# TODO: continue here — filter out catastrophic tokens
105+
160106
log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device))
161107
# Veto sequences with any token's log ratio below the threshold.
162108
# log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold
163-
catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask.bool()
109+
catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask_flat.bool()
164110
# Build per-sequence veto and expand to tokens
165111
cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0)
166112
has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device)

0 commit comments

Comments (0)