|
1 | | -import re |
2 | 1 | from typing import Union |
3 | 2 |
|
4 | 3 | import torch |
|
15 | 14 | get_reinforce_plus_plus_baseline_advantages, |
16 | 15 | get_reinforce_plus_plus_returns, |
17 | 16 | ) |
18 | | -from slime.utils.tis import compute_kl_metrics, compute_tis_weights |
| 17 | +from slime.utils.tis import assert_tis_input_format, compute_tis_weights |
19 | 18 |
|
20 | 19 | from .cp_utils import ( |
21 | 20 | all_gather_with_cp, |
@@ -314,78 +313,52 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): |
314 | 313 | # Apply TIS off-policy correction using importance sampling if enabled |
315 | 314 | if args.use_tis: |
316 | 315 | assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" |
317 | | - cp_size = mpu.get_context_parallel_world_size() |
318 | | - upper = args.tis_threshold_upper |
319 | | - lower = args.tis_threshold_lower |
320 | | - assert upper == 2.0 |
321 | 316 |
|
322 | | - total_lengths = batch["total_lengths"] |
323 | | - response_lengths = batch["response_lengths"] |
|     317 | +        full_rollout_log_probs = [
|     318 | +            all_gather_with_cp(rollout_log_prob, total_length, response_length)
|     319 | +            for rollout_log_prob, total_length, response_length in zip(batch["rollout_log_probs"], total_lengths, response_lengths)
|     320 | +        ]
|     321 | +        full_old_log_probs = [
|     322 | +            all_gather_with_cp(old_log_prob, total_length, response_length)
|     323 | +            for old_log_prob, total_length, response_length in zip(old_log_probs, total_lengths, response_lengths)
|     324 | +        ]
324 | 325 |

|     326 | +        # old_log_probs and rollout_log_probs are concated into 1D tensors
|     327 | +        full_old_log_probs = torch.cat(full_old_log_probs, dim=0)
|     328 | +        full_rollout_log_probs = torch.cat(full_rollout_log_probs, dim=0)
|     329 | +        # loss_mask is not sliced by cp, so no need to all_gather
|     330 | +        full_loss_masks = torch.cat(batch["loss_masks"], dim=0)
|     331 | +
|     332 | +        assert_tis_input_format(full_old_log_probs, full_rollout_log_probs, full_loss_masks)
|     333 | +
|     334 | +        tis_weights, tis_metrics = compute_tis_weights(
|     335 | +            old_log_prob=full_old_log_probs,
|     336 | +            rollout_log_prob=full_rollout_log_probs,
|     337 | +            loss_mask=full_loss_masks,
357 | 338 | level=getattr(args, "tis_level", "token"), |
358 | 339 | mode=getattr(args, "tis_mode", "truncate"), |
359 | | - upper_threshold=upper, |
360 | | - lower_threshold=lower, |
| 340 | + upper_threshold=getattr(args, "tis_threshold_upper", 2.0), |
| 341 | + lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, "tis_threshold_upper", 2.0)), |
361 | 342 | veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), |
362 | 343 | safety_bound=getattr(args, "tis_safety_bound", 20.0), |
363 | | - response_lengths=response_lengths, |
|     344 | +            response_lengths=response_lengths,
364 | 345 | ) |
365 | 346 |
|
366 | | - # On-policy ratio for monitoring (π_new/π_old) |
367 | 347 | ois = (-ppo_kl).exp() |
368 | 348 |
|
369 | | - # 4) 应用权重(CP>1 时回切至本地切片) |
370 | | - if tis_weights_full_flat is not None: |
371 | | - if cp_size == 1: |
372 | | - pg_loss = pg_loss * tis_weights_full_flat |
373 | | - else: |
374 | | - per_seq_weights = list(torch.split(tis_weights_full_flat, [int(l) for l in response_lengths], dim=0)) |
375 | | - local_weight_chunks = [ |
376 | | - slice_log_prob_with_cp(w, total_len, resp_len) |
377 | | - for w, total_len, resp_len in zip(per_seq_weights, total_lengths, response_lengths) |
378 | | - ] |
379 | | - tis_weights_local_flat = torch.cat(local_weight_chunks, dim=0) |
380 | | - pg_loss = pg_loss * tis_weights_local_flat |
381 | | - |
382 | | - # 5) KL 指标统一基于全序列 |
383 | | - kl_metrics = compute_kl_metrics( |
384 | | - old_log_prob=old_full_flat, |
385 | | - rollout_log_prob=rollout_full_flat, |
386 | | - loss_mask=mask_full_flat, |
387 | | - response_lengths=response_lengths, |
388 | | - ) |
| 349 | + # tis_weights is a 1D tensor, should be sliced to the local cp rank |
| 350 | + local_tis_chunks = [] |
| 351 | + start = 0 |
| 352 | + for total_len, response_len in zip(total_lengths, response_lengths): |
| 353 | + end = start + int(response_len) |
| 354 | + seq_weights = tis_weights[start:end] |
| 355 | + # Slice to the two local chunks of this CP rank |
| 356 | + local_chunk = slice_log_prob_with_cp(seq_weights, int(total_len), int(response_len)) |
| 357 | + local_tis_chunks.append(local_chunk) |
| 358 | + start = end |
| 359 | + tis_weights = torch.cat(local_tis_chunks, dim=0) |
| 360 | + |
| 361 | + pg_loss = pg_loss * tis_weights |
389 | 362 |
|
390 | 363 | pg_loss = sum_of_sample_mean(pg_loss) |
391 | 364 | pg_clipfrac = sum_of_sample_mean(pg_clipfrac) |
|
0 commit comments