Skip to content

Commit 079a537

Browse files
authored
[FSDP] fix the rollout/raw_reward metrics calculation (#806)
1 parent d8dcb00 commit 079a537

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,12 @@ def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None:
446446

447447
self.compute_log_prob("actor", packed_batches)
448448

449-
for metric_key in ["log_probs", "ref_log_probs", "advantages", "returns", "raw_reward"]:
449+
if "raw_reward" in rollout_data and dist.get_rank() == 0:
450+
raw_reward_list = rollout_data["raw_reward"]
451+
if raw_reward_list:
452+
log_dict["rollout/raw_reward"] = sum(raw_reward_list) / len(raw_reward_list)
453+
454+
for metric_key in ["log_probs", "ref_log_probs", "advantages", "returns"]:
450455
if metric_key not in packed_batches[0]:
451456
continue
452457
val = torch.tensor([0.0], device=torch.cuda.current_device())

0 commit comments

Comments
 (0)