@@ -376,6 +376,64 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
376376 if args .log_passrate :
377377 log_passrate (rollout_id , args , rollout_data )
378378
379+ if args .log_correct_samples :
380+ if mpu .get_tensor_model_parallel_rank () == 0 and mpu .is_pipeline_last_stage ():
381+ cp_size = mpu .get_context_parallel_world_size ()
382+ log_dict = {}
383+ response_lengths = rollout_data ["response_lengths" ]
384+ loss_masks = rollout_data ["loss_masks" ]
385+ total_lengths = rollout_data ["total_lengths" ]
386+
387+ def quantile (total_value , n_quantiles , data ) -> dict :
388+ import math
389+
390+ assert n_quantiles > 1 , f"n_quantiles({ n_quantiles } ) must be greater than 1."
391+
392+ quantiles = [((i + 1 ) / n_quantiles ) for i in range (n_quantiles )]
393+ cut_points = [total_value * q for q in quantiles ]
394+ cut_points [- 1 ] = total_value
395+
396+ count = [0 ] * n_quantiles
397+ for d in data :
398+ for i , point in enumerate (cut_points ):
399+ if d <= point :
400+ count [i ] += 1
401+ break
402+
403+ total = sum (count ) + 1e-9
404+ percentile = [c / total for c in count ]
405+
406+ percentile = {f"p{ min (math .ceil (q * 100 ),100 )} " : p for q , p in zip (quantiles , percentile , strict = True )}
407+ return percentile
408+
409+ raw_rewards = rollout_data ["raw_reward" ]
410+ # Additional metrics for correct cases are calculated separately below.
411+ correct_response_lengths = []
412+ correct_total_lengths = []
413+ correct_loss_masks = []
414+ correct_entropy = []
415+ for i , raw_reward in enumerate (raw_rewards ):
416+ if raw_reward == 1 :
417+ correct_response_lengths .append (response_lengths [i ])
418+ correct_total_lengths .append (total_lengths [i ])
419+ correct_loss_masks .append (loss_masks [i ])
420+ correct_entropy .append (- rollout_data ["log_probs" ][i ])
421+ num_correct_responses = len (correct_total_lengths )
422+ rollout_data ["correct_response_lengths" ] = correct_response_lengths
423+ correct_response_length_percentile = quantile (
424+ args .rollout_max_response_len , 4 , rollout_data ["correct_response_lengths" ]
425+ )
426+ for p , val in correct_response_length_percentile .items ():
427+ rollout_data [f"correct_length/{ p } " ] = [val ] * num_correct_responses
428+ if len (correct_entropy ) > 0 :
429+ sum_of_sample_mean = get_sum_of_sample_mean (
430+ correct_total_lengths , correct_response_lengths , correct_loss_masks
431+ )
432+ correct_entropy = sum_of_sample_mean (torch .cat (correct_entropy , dim = 0 ))
433+ rollout_data ["correct_entropy" ] = [correct_entropy .item ()] * num_correct_responses
434+ else :
435+ rollout_data ["correct_entropy" ] = [0 ] * num_correct_responses
436+
379437
380438def log_multi_turn_data (rollout_id : int , args : Namespace , rollout_data : RolloutBatch ) -> None :
381439 """
0 commit comments