Skip to content

Commit a06adef

Browse files
authored
move pg_loss into tis_function for icepop (#635)
1 parent 08118ec commit a06adef

File tree

2 files changed

+9
-5
lines changed
  • examples/train_infer_mismatch_helper
  • slime/backends/megatron_utils

2 files changed

+9
-5
lines changed

examples/train_infer_mismatch_helper/mis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def compute_mis_weights(
218218
def compute_mis_weights_with_cp(
219219
args,
220220
*,
221+
pg_loss: torch.Tensor,
221222
train_log_probs: list[torch.Tensor],
222223
rollout_log_probs: list[torch.Tensor],
223224
loss_masks: list[torch.Tensor],
@@ -274,7 +275,9 @@ def slice_cp_and_concat(
274275
values = slice_cp_and_concat(values, total_lengths, response_lengths)
275276
result_metrics[key_name] = values
276277

277-
return is_weights, result_metrics
278+
pg_loss = pg_loss * is_weights
279+
280+
return pg_loss, result_metrics
278281

279282

280283
def add_ppl_metrics(

slime/backends/megatron_utils/loss.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ def policy_loss_function(
424424
def vanilla_tis_function(
425425
args,
426426
*,
427+
pg_loss: torch.Tensor,
427428
train_log_probs: list[torch.Tensor],
428429
rollout_log_probs: list[torch.Tensor],
429430
**kwargs: Any,
@@ -439,13 +440,15 @@ def vanilla_tis_function(
439440
"tis_clipfrac": tis_clipfrac.clone().detach(),
440441
"tis_abs": tis_abs.clone().detach(),
441442
}
442-
return tis_weights, metrics
443+
pg_loss = pg_loss * tis_weights
444+
return pg_loss, metrics
443445

444446
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
445447

446448
ois = (-ppo_kl).exp()
447449
tis_kwargs = {
448450
"args": args,
451+
"pg_loss": pg_loss,
449452
"train_log_probs": batch["log_probs"],
450453
"rollout_log_probs": batch["rollout_log_probs"],
451454
"loss_masks": batch["loss_masks"],
@@ -457,9 +460,7 @@ def vanilla_tis_function(
457460
tis_func = load_function(args.custom_tis_function_path)
458461
else:
459462
tis_func = vanilla_tis_function
460-
tis_weights, tis_metrics = tis_func(**tis_kwargs)
461-
462-
pg_loss = pg_loss * tis_weights
463+
pg_loss, tis_metrics = tis_func(**tis_kwargs)
463464

464465
pg_loss = sum_of_sample_mean(pg_loss)
465466
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)

0 commit comments

Comments
 (0)