We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1736c4b, commit 667c40b (copy full SHA for 667c40b)
slime/backends/megatron_utils/loss.py
@@ -431,12 +431,13 @@ def vanilla_tis_function(
431
rollout_log_probs = torch.cat(rollout_log_probs, dim=0)
432
old_log_probs = torch.cat(train_log_probs, dim=0)
433
tis = torch.exp(old_log_probs - rollout_log_probs)
434
+ tis_abs = torch.exp((old_log_probs - rollout_log_probs).abs())
435
tis_weights = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip)
436
tis_clipfrac = (tis_weights != tis).float()
437
metrics = {
438
"tis": tis.clone().detach(),
439
"tis_clipfrac": tis_clipfrac.clone().detach(),
- "tis_abs": (1 - tis).abs().clone().detach(),
440
+ "tis_abs": tis_abs.clone().detach(),
441
}
442
return tis_weights, metrics
443
0 commit comments