Skip to content

Commit 667c40b

Browse files
committed
Add a use_rollout_logprobs arg to use rollout logprobs when calculating the IS (importance sampling) ratio
1 parent 1736c4b commit 667c40b

File tree

1 file changed

+2
-1
lines changed
  • slime/backends/megatron_utils

1 file changed

+2
-1
lines changed

slime/backends/megatron_utils/loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,13 @@ def vanilla_tis_function(
431431
rollout_log_probs = torch.cat(rollout_log_probs, dim=0)
432432
old_log_probs = torch.cat(train_log_probs, dim=0)
433433
tis = torch.exp(old_log_probs - rollout_log_probs)
434+
tis_abs = torch.exp((old_log_probs - rollout_log_probs).abs())
434435
tis_weights = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip)
435436
tis_clipfrac = (tis_weights != tis).float()
436437
metrics = {
437438
"tis": tis.clone().detach(),
438439
"tis_clipfrac": tis_clipfrac.clone().detach(),
439-
"tis_abs": (1 - tis).abs().clone().detach(),
440+
"tis_abs": tis_abs.clone().detach(),
440441
}
441442
return tis_weights, metrics
442443

0 commit comments

Comments
 (0)