import time
import uuid
from collections.abc import Mapping
+from collections import defaultdict
from typing import Any, Dict, List, Literal, Optional, Tuple

import numpy as np

configure_logger()

-
+from transformers import AutoTokenizer
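+
+# Debug helper: a locally cached tokenizer used to decode token IDs back to
+# text, so multi-turn contexts can be compared as strings in
+# get_train_data_batch below.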
+model_dir = '/mnt/teamdrive/RAG_RL/models/meta-llama/Llama-3.2-3B'
+tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True)
+def _decode(ids, skip_special_tokens=True):
+    return tok.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False)
+
def get_left_padded_ids_and_attention_mask(
    ids: List[int], max_length: int, pad_token_id: int
) -> Tuple[List[int], List[int]]:
@@ -555,13 +561,28 @@ def get_test_metrics(self):
        assert len(self._completed_rollouts_v0) == self._total_tasks_queued

        sample_stat_list: List[Dict[str, Any]] = []
-        for _, rollout in self._completed_rollouts_v0.items():
+        sample_stat_list_by_source: Dict[str, List[Dict[str, Any]]] = defaultdict(
+            list
+        )  # FIXME: Evaluate whether grouping stats by source is actually needed.
+
+        for rollout_id, rollout in self._completed_rollouts_v0.items():
            final_reward = self._fillna_reward(rollout)
            if not rollout.triplets:
                print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
                sample_stat_list.append({"reward": final_reward})
                continue
            response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
+            if "data_source" in self._task_id_to_original_sample[rollout_id]:
+                # When a test sample includes a 'data_source' field, record per-source statistics for the test results.
+                data_source = self._task_id_to_original_sample[rollout_id]["data_source"]
+                sample_stat_list_by_source[data_source].append(
+                    {
+                        "sum_response_length": np.sum(response_length_list),
+                        "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
+                        "turn_count": len(rollout.triplets),
+                        "reward": final_reward,
+                    }
+                )
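+            # The same stats are also appended to the overall sample_stat_list below.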
            sample_stat_list.append(
                {
                    "sum_response_length": np.sum(response_length_list),
@@ -570,18 +591,45 @@ def get_test_metrics(self):
                    "reward": final_reward,
                }
            )
+        metric_dict: Dict[str, Any] = {}
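+        # Collects per-source ("val/<source>/...") and overall ("val/...") metrics.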

        stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
-        return {
-            "val/n_rollouts": len(sample_stat_list),
-            "val/n_rollouts_w_trace": len(stats_w_trace),
-            "val/reward": np.mean(
-                [stat["reward"] for stat in sample_stat_list]
-            ),  # each rollout must have a reward (fillna if missing)
-            "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
-            "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
-            "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
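+        # Keep only stats that include a trace when averaging per-source length/turn metrics.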
+        stats_w_trace_by_source = {
+            data_source: [stat for stat in sample_stats if "sum_response_length" in stat]
+            for data_source, sample_stats in sample_stat_list_by_source.items()
        }
+        for data_source, sample_stats in sample_stat_list_by_source.items():
+            metric_dict.update(
+                {
+                    f"val/{data_source}/n_rollouts": len(sample_stats),
+                    f"val/{data_source}/n_rollouts_w_trace": len(stats_w_trace_by_source[data_source]),
+                    f"val/{data_source}/reward": np.mean(
+                        [stat["reward"] for stat in sample_stats]
+                    ),  # each rollout must have a reward (fillna if missing)
+                    f"val/{data_source}/mean_response_length": np.mean(
+                        [stat["mean_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/sum_response_length": np.mean(
+                        [stat["sum_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/turn_count": np.mean(
+                        [stat["turn_count"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                }
+            )
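+        # Overall metrics aggregated across all rollouts.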
+        metric_dict.update(
+            {
+                "val/n_rollouts": len(sample_stat_list),
+                "val/n_rollouts_w_trace": len(stats_w_trace),
+                "val/reward": np.mean(
+                    [stat["reward"] for stat in sample_stat_list]
+                ),  # each rollout must have a reward (fillna if missing)
+                "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
+                "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
+                "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
+            }
+        )
+        return metric_dict

    def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, device: torch.device):
        """
@@ -684,19 +732,32 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
            for rollout_id, sample_info in finished_id_to_sample_info.items():
                merged_trace_idx: List[List[int]] = []
                current_merged_trace_idx: List[int] = []
-                current_context: List[int] = []
+                current_context: str = ""
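+                # Track the running context as decoded text rather than raw token IDs,
+                # so the prefix check still matches when the same text tokenizes
+                # differently across turns (the old token-ID check is kept below as a comment).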
                for turn_index, trace in enumerate(sample_info["trace_list"]):
-                    if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context:
-                        current_context = trace["prompt_ids"] + trace["response_ids"]
+                    # print('~' * 20)
+                    # print((trace["prompt_ids"] + trace["response_ids"]))
+                    # print(current_context)
+                    # print(f'|START|{_decode((trace["prompt_ids"] + trace["response_ids"]))}|END|')
+                    # print(f'|START|{_decode(current_context)}|END|')
+
+                    temp_combined = _decode(trace["prompt_ids"] + trace["response_ids"])
+                    if temp_combined[: len(current_context)] == current_context:
+                    # if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context:
+                        current_context = temp_combined
                        current_merged_trace_idx.append(turn_index)
                    else:
                        # assert len(current_merged_trace_idx) > 0
                        merged_trace_idx.append(current_merged_trace_idx)
                        current_merged_trace_idx = [turn_index]
-                        current_context = trace["prompt_ids"] + trace["response_ids"]
+                        current_context = temp_combined
                if current_merged_trace_idx not in merged_trace_idx:
                    merged_trace_idx.append(current_merged_trace_idx)

+                print('-' * 20)
+                print(merged_trace_idx)
+                # assert len(merged_trace_idx) == 1
+                # assert sum(len(x) for x in merged_trace_idx) == len(sample_info["trace_list"])
+
                for current_merged_trace_idx in merged_trace_idx:
                    prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]
                    response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]