Skip to content

Commit 71194c3

Browse files
Slice TIS weights with slice_log_prob_with_cp
1 parent 5cac6e0 commit 71194c3

File tree

2 files changed

+56
-64
lines changed

2 files changed

+56
-64
lines changed

slime/backends/megatron_utils/loss.py

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from typing import Union
32

43
import torch
@@ -15,7 +14,7 @@
1514
get_reinforce_plus_plus_baseline_advantages,
1615
get_reinforce_plus_plus_returns,
1716
)
18-
from slime.utils.tis import compute_kl_metrics, compute_tis_weights
17+
from slime.utils.tis import assert_tis_input_format, compute_tis_weights
1918

2019
from .cp_utils import (
2120
all_gather_with_cp,
@@ -314,78 +313,52 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
314313
# Apply TIS off-policy correction using importance sampling if enabled
315314
if args.use_tis:
316315
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
317-
cp_size = mpu.get_context_parallel_world_size()
318-
upper = args.tis_threshold_upper
319-
lower = args.tis_threshold_lower
320-
assert upper == 2.0
321316

322-
total_lengths = batch["total_lengths"]
323-
response_lengths = batch["response_lengths"]
317+
full_log_probs = [
318+
all_gather_with_cp(log_prob, total_length, response_length)
319+
for log_prob, total_length, response_length in zip(log_probs, total_lengths, response_lengths)
320+
]
321+
full_old_log_probs = [
322+
all_gather_with_cp(old_log_prob, total_length, response_length)
323+
for old_log_prob, total_length, response_length in zip(old_log_probs, total_lengths, response_lengths)
324+
]
324325

325-
# 1) 组装全序列 old/rollout/mask(CP=1 直接拼接;CP>1 用 all_gather 重建)
326-
if cp_size == 1:
327-
full_old_list = batch["log_probs"]
328-
full_rollout_list = batch["rollout_log_probs"]
329-
full_mask_list = batch["loss_masks"]
330-
else:
331-
full_old_list = [
332-
all_gather_with_cp(lp, total_len, resp_len)
333-
for lp, total_len, resp_len in zip(batch["log_probs"], total_lengths, response_lengths)
334-
]
335-
full_rollout_list = [
336-
all_gather_with_cp(lp, total_len, resp_len)
337-
for lp, total_len, resp_len in zip(batch["rollout_log_probs"], total_lengths, response_lengths)
338-
]
339-
# loss_masks 已是每样本全序列
340-
full_mask_list = batch["loss_masks"]
341-
342-
old_full_flat = torch.cat(full_old_list, dim=0)
343-
rollout_full_flat = torch.cat(full_rollout_list, dim=0)
344-
mask_full_flat = torch.cat(full_mask_list, dim=0).to(device=log_probs.device)
345-
346-
# 2) 基本一致性与格式校验
347-
assert old_full_flat.shape == rollout_full_flat.shape == mask_full_flat.shape
348-
loss_mask_str = "".join([str(int(x)) for x in mask_full_flat])
349-
pattern = r"^1+(0+1+)*0*1*$"
350-
assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!"
351-
352-
# 3) 全序列上计算 TIS 权重和指标
353-
tis_weights_full_flat, tis_metrics = compute_tis_weights(
354-
old_log_prob=old_full_flat,
355-
rollout_log_prob=rollout_full_flat,
356-
loss_mask=mask_full_flat,
326+
# old_log_probs, log_probs, loss_masks are all concated into 1D tensor
327+
full_old_log_probs = torch.cat(full_old_log_probs, dim=0)
328+
full_log_probs = torch.cat(full_log_probs, dim=0)
329+
# loss_mask is not sliced by cp, so no need to all_gather
330+
full_loss_masks = torch.cat(batch["loss_masks"], dim=0)
331+
332+
assert_tis_input_format(full_old_log_probs, full_log_probs, full_loss_masks)
333+
334+
tis_weights, tis_metrics = compute_tis_weights(
335+
old_log_prob=full_old_log_probs,
336+
rollout_log_prob=full_log_probs,
337+
loss_mask=full_loss_masks,
357338
level=getattr(args, "tis_level", "token"),
358339
mode=getattr(args, "tis_mode", "truncate"),
359-
upper_threshold=upper,
360-
lower_threshold=lower,
340+
upper_threshold=getattr(args, "tis_threshold_upper", 2.0),
341+
lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, "tis_threshold_upper", 2.0)),
361342
veto_threshold=getattr(args, "tis_veto_threshold", 1e-4),
362343
safety_bound=getattr(args, "tis_safety_bound", 20.0),
363-
response_lengths=response_lengths,
344+
response_lengths=total_lengths,
364345
)
365346

366-
# On-policy ratio for monitoring (π_new/π_old)
367347
ois = (-ppo_kl).exp()
368348

369-
# 4) 应用权重(CP>1 时回切至本地切片)
370-
if tis_weights_full_flat is not None:
371-
if cp_size == 1:
372-
pg_loss = pg_loss * tis_weights_full_flat
373-
else:
374-
per_seq_weights = list(torch.split(tis_weights_full_flat, [int(l) for l in response_lengths], dim=0))
375-
local_weight_chunks = [
376-
slice_log_prob_with_cp(w, total_len, resp_len)
377-
for w, total_len, resp_len in zip(per_seq_weights, total_lengths, response_lengths)
378-
]
379-
tis_weights_local_flat = torch.cat(local_weight_chunks, dim=0)
380-
pg_loss = pg_loss * tis_weights_local_flat
381-
382-
# 5) KL 指标统一基于全序列
383-
kl_metrics = compute_kl_metrics(
384-
old_log_prob=old_full_flat,
385-
rollout_log_prob=rollout_full_flat,
386-
loss_mask=mask_full_flat,
387-
response_lengths=response_lengths,
388-
)
349+
# tis_weights is a 1D tensor, should be sliced to the local cp rank
350+
local_tis_chunks = []
351+
start = 0
352+
for total_len, response_len in zip(total_lengths, response_lengths):
353+
end = start + int(response_len)
354+
seq_weights = tis_weights[start:end]
355+
# Slice to the two local chunks of this CP rank
356+
local_chunk = slice_log_prob_with_cp(seq_weights, int(total_len), int(response_len))
357+
local_tis_chunks.append(local_chunk)
358+
start = end
359+
tis_weights = torch.cat(local_tis_chunks, dim=0)
360+
361+
pg_loss = pg_loss * tis_weights
389362

390363
pg_loss = sum_of_sample_mean(pg_loss)
391364
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)

slime/utils/tis.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,27 @@
1+
import re
12
from typing import Any, Dict, Optional, Tuple
23

34
import torch
45

56

7+
def assert_tis_input_format(
8+
full_old_log_probs: torch.Tensor,
9+
full_log_probs: torch.Tensor,
10+
full_loss_masks: torch.Tensor,
11+
) -> None:
12+
assert all(
13+
tensor.dim() == 1 for tensor in [full_old_log_probs, full_log_probs, full_loss_masks]
14+
), f"{full_old_log_probs.dim()} vs {full_log_probs.dim()} vs {full_loss_masks.dim()}"
15+
16+
assert (
17+
full_old_log_probs.shape == full_log_probs.shape and full_old_log_probs.shape == full_loss_masks.shape
18+
), f"{full_old_log_probs.shape} vs {full_log_probs.shape} vs {full_loss_masks.shape}"
19+
20+
loss_mask_str = "".join([str(int(x)) for x in full_loss_masks])
21+
pattern = r"^1+(0+1+)*0*1*$"
22+
assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!"
23+
24+
625
def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
726
"""
827
Computes the sum of the tensor x, masked by the mask.

0 commit comments

Comments (0)