88import time
99import uuid
1010from collections .abc import Mapping
11+ from collections import defaultdict
1112from typing import Any , Dict , List , Literal , Optional , Tuple
1213
1314import numpy as np
2526
2627configure_logger ()
2728
28-
# NOTE(review): import is not at the top of the file and the tokenizer is
# loaded eagerly at module import time — every importer pays the load cost
# even if `_decode` is never called. Consider moving the import up and
# making the load lazy; confirm no other module relies on `tok` existing
# immediately at import.
from transformers import AutoTokenizer

# HACK: hard-coded absolute path to a local Llama-3.2-3B checkpoint; this
# module cannot run on machines without this mount. TODO(review): make the
# path configurable (env var / config) instead of a literal.
model_dir = '/mnt/teamdrive/RAG_RL/models/meta-llama/Llama-3.2-3B'
# local_files_only=True forbids any hub download; use_fast=True selects the
# Rust-backed tokenizer implementation.
tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True)
def _decode(ids, skip_special_tokens=True):
    """Decode token ids to text via the module-level tokenizer ``tok``.

    ``clean_up_tokenization_spaces`` is pinned to ``False`` so decoded text
    is returned without tokenizer-side whitespace normalization.
    """
    decode_kwargs = {
        "skip_special_tokens": skip_special_tokens,
        "clean_up_tokenization_spaces": False,
    }
    return tok.decode(ids, **decode_kwargs)
34+
2935def get_left_padded_ids_and_attention_mask (
3036 ids : List [int ], max_length : int , pad_token_id : int
3137) -> Tuple [List [int ], List [int ]]:
@@ -555,13 +561,28 @@ def get_test_metrics(self):
555561 assert len (self ._completed_rollouts_v0 ) == self ._total_tasks_queued
556562
557563 sample_stat_list : List [Dict [str , Any ]] = []
558- for _ , rollout in self ._completed_rollouts_v0 .items ():
564+ sample_stat_list_by_source : Dict [str , List [Dict [str , Any ]]] = defaultdict (
565+ list
566+ ) # FIXME: Evaluate whether grouping stats by source is actually needed.
567+
568+ for rollout_id , rollout in self ._completed_rollouts_v0 .items ():
559569 final_reward = self ._fillna_reward (rollout )
560570 if not rollout .triplets :
561571 print (f"Warning: No triplets found for test rollout { rollout .rollout_id } ." )
562572 sample_stat_list .append ({"reward" : final_reward })
563573 continue
564574 response_length_list = [len (triplet .response .get ("token_ids" , [])) for triplet in rollout .triplets ]
575+ if "data_source" in self ._task_id_to_original_sample [rollout_id ]:
576+ # When a test sample includes a 'data_source' field, record per-source statistics for test results.
577+ data_source = self ._task_id_to_original_sample [rollout_id ]["data_source" ]
578+ sample_stat_list_by_source [data_source ].append (
579+ {
580+ "sum_response_length" : np .sum (response_length_list ),
581+ "mean_response_length" : np .mean (response_length_list ) if response_length_list else 0 ,
582+ "turn_count" : len (rollout .triplets ),
583+ "reward" : final_reward ,
584+ }
585+ )
565586 sample_stat_list .append (
566587 {
567588 "sum_response_length" : np .sum (response_length_list ),
@@ -570,18 +591,45 @@ def get_test_metrics(self):
570591 "reward" : final_reward ,
571592 }
572593 )
594+ metric_dict : Dict [str , Any ] = {}
573595
574596 stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat ]
575- return {
576- "val/n_rollouts" : len (sample_stat_list ),
577- "val/n_rollouts_w_trace" : len (stats_w_trace ),
578- "val/reward" : np .mean (
579- [stat ["reward" ] for stat in sample_stat_list ]
580- ), # each rollout must have a reward (fillna if missing)
581- "val/mean_response_length" : np .mean ([stat ["mean_response_length" ] for stat in stats_w_trace ]),
582- "val/sum_response_length" : np .mean ([stat ["sum_response_length" ] for stat in stats_w_trace ]),
583- "val/turn_count" : np .mean ([stat ["turn_count" ] for stat in stats_w_trace ]),
597+ stats_w_trace_by_source = {
598+ data_source : [stat for stat in sample_stats if "sum_response_length" in stat ]
599+ for data_source , sample_stats in sample_stat_list_by_source .items ()
584600 }
601+ for data_source , sample_stats in sample_stat_list_by_source .items ():
602+ metric_dict .update (
603+ {
604+ f"val/{ data_source } /n_rollouts" : len (sample_stats ),
605+ f"val/{ data_source } /n_rollouts_w_trace" : len (stats_w_trace_by_source [data_source ]),
606+ f"val/{ data_source } /reward" : np .mean (
607+ [stat ["reward" ] for stat in sample_stats ]
608+ ), # each rollout must have a reward (fillna if missing)
609+ f"val/{ data_source } /mean_response_length" : np .mean (
610+ [stat ["mean_response_length" ] for stat in stats_w_trace_by_source [data_source ]]
611+ ),
612+ f"val/{ data_source } /sum_response_length" : np .mean (
613+ [stat ["sum_response_length" ] for stat in stats_w_trace_by_source [data_source ]]
614+ ),
615+ f"val/{ data_source } /turn_count" : np .mean (
616+ [stat ["turn_count" ] for stat in stats_w_trace_by_source [data_source ]]
617+ ),
618+ }
619+ )
620+ metric_dict .update (
621+ {
622+ "val/n_rollouts" : len (sample_stat_list ),
623+ "val/n_rollouts_w_trace" : len (stats_w_trace ),
624+ "val/reward" : np .mean (
625+ [stat ["reward" ] for stat in sample_stat_list ]
626+ ), # each rollout must have a reward (fillna if missing)
627+ "val/mean_response_length" : np .mean ([stat ["mean_response_length" ] for stat in stats_w_trace ]),
628+ "val/sum_response_length" : np .mean ([stat ["sum_response_length" ] for stat in stats_w_trace ]),
629+ "val/turn_count" : np .mean ([stat ["turn_count" ] for stat in stats_w_trace ]),
630+ }
631+ )
632+ return metric_dict
585633
586634 def get_train_data_batch (self , max_prompt_length : int , max_response_length : int , device : torch .device ):
587635 """
@@ -684,19 +732,32 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
684732 for rollout_id , sample_info in finished_id_to_sample_info .items ():
685733 merged_trace_idx : List [List [int ]] = []
686734 current_merged_trace_idx : List [int ] = []
687- current_context : List [ int ] = []
735+ current_context : str = ""
688736 for turn_index , trace in enumerate (sample_info ["trace_list" ]):
689- if (trace ["prompt_ids" ] + trace ["response_ids" ])[: len (current_context )] == current_context :
690- current_context = trace ["prompt_ids" ] + trace ["response_ids" ]
737+ # print('~' * 20)
738+ # print((trace["prompt_ids"] + trace["response_ids"]))
739+ # print(current_context)
740+ # print(f'|START|{_decode((trace["prompt_ids"] + trace["response_ids"]))}|END|')
741+ # print(f'|START|{_decode(current_context)}|END|')
742+
743+ temp_combined = _decode (trace ["prompt_ids" ] + trace ["response_ids" ])
744+ if temp_combined [: len (current_context )] == current_context :
745+ # if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context:
746+ current_context = temp_combined
691747 current_merged_trace_idx .append (turn_index )
692748 else :
693749 # assert len(current_merged_trace_idx) > 0
694750 merged_trace_idx .append (current_merged_trace_idx )
695751 current_merged_trace_idx = [turn_index ]
696- current_context = trace [ "prompt_ids" ] + trace [ "response_ids" ]
752+ current_context = temp_combined
697753 if current_merged_trace_idx not in merged_trace_idx :
698754 merged_trace_idx .append (current_merged_trace_idx )
699755
756+ print ('-' * 20 )
757+ print (merged_trace_idx )
758+ # assert len(merged_trace_idx) == 1
759+ # assert sum(len(x) for x in merged_trace_idx) == len(sample_info["trace_list"])
760+
700761 for current_merged_trace_idx in merged_trace_idx :
701762 prompt_ids = sample_info ["trace_list" ][current_merged_trace_idx [0 ]]["prompt_ids" ]
702763 response_ids = sample_info ["trace_list" ][current_merged_trace_idx [0 ]]["response_ids" ]
0 commit comments