@@ -832,8 +832,7 @@ def _generate_and_score_completions(
         completions = completions_text
 
         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-        has_valid_rewards = torch.zeros(len(prompts), dtype=torch.bool, device=device)
-
+
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
@@ -856,7 +855,6 @@ def _generate_and_score_completions(
                 reward_inputs = super()._prepare_inputs(reward_inputs)
                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
-                    has_valid_rewards = has_valid_rewards | ~torch.isnan(rewards_per_func[:, i])
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                 keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
@@ -868,14 +866,6 @@ def _generate_and_score_completions(
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
             # End of Selection
 
-        # Check if any sample has no valid rewards
-        if not has_valid_rewards.all():
-            invalid_count = (~has_valid_rewards).sum().item()
-            warnings.warn(
-                f"Found {invalid_count} samples with no valid rewards. "
-                f"Please ensure at least one reward function returns valid rewards for each sample."
-            )
-
         # If all reward functions return None for a given row, issue a detailed warning
         if torch.isnan(rewards_per_func).all(dim=1).any():
             nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
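As a side note, here is a small self-contained sketch (illustrative values, not the trainer code) of the NaN conventions the surrounding hunks rely on: a reward function that returns None for a sample leaves a NaN in its column of rewards_per_func, a row that is entirely NaN means no reward function scored that sample, and the later torch.nansum simply treats NaN entries as zero.

import torch

nan = float("nan")
# 3 samples scored by 2 reward functions; NaN marks "no reward returned"
rewards_per_func = torch.tensor([[1.0, nan],
                                 [0.5, 2.0],
                                 [nan, nan]])

# Rows where every reward function returned None
all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)      # tensor([False, False, True])
if all_nan_rows.any():
    nan_row_idx = all_nan_rows.nonzero(as_tuple=True)[0][0]  # first offending row
    print(f"Sample {nan_row_idx.item()} received no reward from any function")

# NaN entries contribute nothing to the per-sample sum
print(torch.nansum(rewards_per_func, dim=1))                 # -> 1.0, 2.5, 0.0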
@@ -899,21 +889,9 @@ def _generate_and_score_completions(
         # Calculate the weighted sum, ignoring NaN values
         rewards = torch.nansum(weighted_rewards, dim=1)
 
-        # Check if all rewards are NaN
-        if torch.isnan(rewards).all():
-            warnings.warn(
-                "No valid rewards found for any samples. All reward functions returned None. "
-                "Training will not be effective without valid rewards."
-            )
-            # Set all rewards to 0 to avoid NaN propagation
-            rewards = torch.zeros_like(rewards)
-
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-
-        # Replace any NaN values in the standard deviation with a small positive number
-        std_grouped_rewards = torch.nan_to_num(std_grouped_rewards, nan=1e-8)
 
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
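For reference, a minimal standalone sketch of the group-wise normalization kept by this hunk (the group size, reward values, and epsilon are illustrative assumptions, not values taken from the trainer): each prompt's group of completions is reduced to a mean and standard deviation, the statistics are broadcast back to one entry per completion, and the rewards are normalized into advantages.

import torch

num_generations = 4  # completions sampled per prompt (assumed)
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5,   # group for prompt 0
                        2.0, 2.0, 2.0, 2.0])  # group for prompt 1 (zero variance)

# One mean/std per prompt group
mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
std_grouped = rewards.view(-1, num_generations).std(dim=1)

# Broadcast the group statistics back to one value per completion
mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
std_grouped = std_grouped.repeat_interleave(num_generations, dim=0)

# A small epsilon (assumed value) keeps the division finite for zero-variance groups
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
print(advantages)  # the zero-variance group gets advantages of 0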
@@ -939,11 +917,9 @@ def _generate_and_score_completions(
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
             else:
                 reward_func_name = reward_func.__name__
-
             # Only calculate mean for samples where this reward function was applied (non-NaN values)
-            mean_rewards = torch.nanmean(rewards_per_func[:, i])
-
-            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards.mean().item())
+            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
         self._metrics[mode]["reward"].append(rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
 
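Finally, a quick sketch of the metric change in the last hunk, again with made-up numbers: torch.nanmean over a single column already yields a scalar tensor, so calling .mean() on it again was a no-op and .item() alone is enough for logging.

import torch

nan = float("nan")
rewards_per_func = torch.tensor([[1.0, nan],
                                 [0.5, 2.0],
                                 [nan, nan]])

# Mean over only the samples this reward function actually scored (column 0)
mean_rewards = torch.nanmean(rewards_per_func[:, 0]).item()
print(mean_rewards)  # (1.0 + 0.5) / 2 = 0.75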