Commit 8ec2e42

Online fixes
1 parent 218d493 commit 8ec2e42

1 file changed: +3 -27 lines changed

trl/trainer/grpo_trainer.py

@@ -832,8 +832,7 @@ def _generate_and_score_completions(
         completions = completions_text
 
         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-        has_valid_rewards = torch.zeros(len(prompts), dtype=torch.bool, device=device)
-
+
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
@@ -856,7 +855,6 @@ def _generate_and_score_completions(
                 reward_inputs = super()._prepare_inputs(reward_inputs)
                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
-                    has_valid_rewards = has_valid_rewards | ~torch.isnan(rewards_per_func[:, i])
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                 keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
@@ -868,14 +866,6 @@ def _generate_and_score_completions(
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
                 # End of Selection
 
-        # Check if any sample has no valid rewards
-        if not has_valid_rewards.all():
-            invalid_count = (~has_valid_rewards).sum().item()
-            warnings.warn(
-                f"Found {invalid_count} samples with no valid rewards. "
-                f"Please ensure at least one reward function returns valid rewards for each sample."
-            )
-
         # If all reward functions return None for a given row, issue a detailed warning
         if torch.isnan(rewards_per_func).all(dim=1).any():
             nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
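Note on the hunk above: the removed has_valid_rewards flag was only updated in the model-based reward branch, while the per-row check that is kept (torch.isnan(rewards_per_func).all(dim=1)) covers both branches. A minimal sketch, using made-up tensors rather than trainer state, of why the two conditions agree when the flag is updated for every reward function:

    import torch

    # Hypothetical (B*G, num_reward_funcs) rewards; NaN marks "reward function returned None".
    rewards_per_func = torch.tensor([[1.0, float("nan")],
                                     [float("nan"), float("nan")]])

    # Removed bookkeeping: a sample is valid if any reward function produced a value.
    has_valid_rewards = (~torch.isnan(rewards_per_func)).any(dim=1)

    # Retained check: a row is problematic when every reward function returned NaN.
    all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)

    assert torch.equal(~has_valid_rewards, all_nan_rows)  # both are tensor([False, True])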
@@ -899,21 +889,9 @@ def _generate_and_score_completions(
         # Calculate the weighted sum, ignoring NaN values
         rewards = torch.nansum(weighted_rewards, dim=1)
 
-        # Check if all rewards are NaN
-        if torch.isnan(rewards).all():
-            warnings.warn(
-                "No valid rewards found for any samples. All reward functions returned None. "
-                "Training will not be effective without valid rewards."
-            )
-            # Set all rewards to 0 to avoid NaN propagation
-            rewards = torch.zeros_like(rewards)
-
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-
-        # Replace any NaN values in the standard deviation with a small positive number
-        std_grouped_rewards = torch.nan_to_num(std_grouped_rewards, nan=1e-8)
 
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
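For context on the fallback removed above: torch.nansum treats NaN entries as zero, so even a row in which every reward function returned None yields a reward of 0 rather than NaN, leaving nothing to propagate into the group statistics. A small sketch with assumed values:

    import torch

    weighted_rewards = torch.tensor([[float("nan"), float("nan")],
                                     [0.5, float("nan")]])

    rewards = torch.nansum(weighted_rewards, dim=1)   # NaN entries are treated as 0
    print(rewards)                                    # tensor([0.0000, 0.5000])
    print(torch.isnan(rewards).any())                 # tensor(False)

With the rewards guaranteed finite, the nan_to_num on the group standard deviation (also removed above) would presumably only matter in the degenerate num_generations == 1 case, where torch.std of a single element is NaN.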
@@ -939,11 +917,9 @@ def _generate_and_score_completions(
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
             else:
                 reward_func_name = reward_func.__name__
-
             # Only calculate mean for samples where this reward function was applied (non-NaN values)
-            mean_rewards = torch.nanmean(rewards_per_func[:, i])
-
-            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards.mean().item())
+            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
         self._metrics[mode]["reward"].append(rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
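On the metric-logging change in the last hunk: torch.nanmean already reduces a column to a 0-dim tensor, so the old .mean() before .item() was redundant and the new form logs the same scalar. A quick illustration with a made-up per-function reward column:

    import torch

    col = torch.tensor([1.0, float("nan"), 3.0])

    old_value = torch.nanmean(col).mean().item()  # extra .mean() on a 0-dim tensor is a no-op
    new_value = torch.nanmean(col).item()

    assert old_value == new_value == 2.0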
