We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d9c1aac, commit 22f4e22 (Copy full SHA for 22f4e22)
slime/backends/megatron_utils/loss.py
@@ -410,6 +410,7 @@ def policy_loss_function(
410
]
411
ppo_kl = [kl.expand_as(log_prob) for kl, log_prob in zip(ppo_kl, log_probs)]
412
ppo_kl = torch.cat(ppo_kl, dim=0)
413
+ old_log_probs = torch.cat(old_log_probs, dim=0)
414
log_probs = torch.cat(log_probs, dim=0)
415
else:
416
old_log_probs = torch.cat(old_log_probs, dim=0)
0 commit comments