Commit fa72297: change to text string match
1 parent: 0d2dece

File tree: 1 file changed (+76 -15 lines)

agentlightning/verl/daemon.py

Lines changed: 76 additions & 15 deletions
@@ -8,6 +8,7 @@
 import time
 import uuid
 from collections.abc import Mapping
+from collections import defaultdict
 from typing import Any, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
@@ -25,7 +26,12 @@
 
 configure_logger()
 
-
+from transformers import AutoTokenizer
+model_dir = '/mnt/teamdrive/RAG_RL/models/meta-llama/Llama-3.2-3B'
+tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True)
+def _decode(ids, skip_special_tokens=True):
+    return tok.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False)
+
 def get_left_padded_ids_and_attention_mask(
     ids: List[int], max_length: int, pad_token_id: int
 ) -> Tuple[List[int], List[int]]:
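
The `_decode` helper added above exists so that the context-merging check later in this commit can compare decoded text rather than raw token-id lists. A minimal sketch of why that matters, using a made-up toy vocabulary instead of the hard-coded Llama tokenizer: the same text can tokenize in more than one way, so an id-level prefix check can fail where a text-level check succeeds.

```python
# Toy stand-in for tok.decode(); the vocabulary and ids are hypothetical.
vocab = {1: "Hel", 2: "lo", 3: "Hello", 4: " world"}

def decode(ids):
    return "".join(vocab[i] for i in ids)

prev_turn = [1, 2]   # decodes to "Hello"
next_turn = [3, 4]   # decodes to "Hello world" via a different segmentation

# Old behaviour: compare token-id prefixes. Fails despite matching text.
print(next_turn[: len(prev_turn)] == prev_turn)            # False

# New behaviour: compare decoded-text prefixes. Succeeds.
prev_text = decode(prev_turn)
print(decode(next_turn)[: len(prev_text)] == prev_text)    # True
```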
@@ -555,13 +561,28 @@ def get_test_metrics(self):
         assert len(self._completed_rollouts_v0) == self._total_tasks_queued
 
         sample_stat_list: List[Dict[str, Any]] = []
-        for _, rollout in self._completed_rollouts_v0.items():
+        sample_stat_list_by_source: Dict[str, List[Dict[str, Any]]] = defaultdict(
+            list
+        )  # FIXME: Evaluate whether grouping stats by source is actually needed.
+
+        for rollout_id, rollout in self._completed_rollouts_v0.items():
             final_reward = self._fillna_reward(rollout)
             if not rollout.triplets:
                 print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
                 sample_stat_list.append({"reward": final_reward})
                 continue
             response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
+            if "data_source" in self._task_id_to_original_sample[rollout_id]:
+                # When a test sample includes a 'data_source' field, record per-source statistics for test results.
+                data_source = self._task_id_to_original_sample[rollout_id]["data_source"]
+                sample_stat_list_by_source[data_source].append(
+                    {
+                        "sum_response_length": np.sum(response_length_list),
+                        "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
+                        "turn_count": len(rollout.triplets),
+                        "reward": final_reward,
+                    }
+                )
             sample_stat_list.append(
                 {
                     "sum_response_length": np.sum(response_length_list),
@@ -570,18 +591,45 @@ def get_test_metrics(self):
                     "reward": final_reward,
                 }
             )
+        metric_dict: Dict[str, Any] = {}
 
         stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
-        return {
-            "val/n_rollouts": len(sample_stat_list),
-            "val/n_rollouts_w_trace": len(stats_w_trace),
-            "val/reward": np.mean(
-                [stat["reward"] for stat in sample_stat_list]
-            ),  # each rollout must have a reward (fillna if missing)
-            "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
-            "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
-            "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
+        stats_w_trace_by_source = {
+            data_source: [stat for stat in sample_stats if "sum_response_length" in stat]
+            for data_source, sample_stats in sample_stat_list_by_source.items()
         }
+        for data_source, sample_stats in sample_stat_list_by_source.items():
+            metric_dict.update(
+                {
+                    f"val/{data_source}/n_rollouts": len(sample_stats),
+                    f"val/{data_source}/n_rollouts_w_trace": len(stats_w_trace_by_source[data_source]),
+                    f"val/{data_source}/reward": np.mean(
+                        [stat["reward"] for stat in sample_stats]
+                    ),  # each rollout must have a reward (fillna if missing)
+                    f"val/{data_source}/mean_response_length": np.mean(
+                        [stat["mean_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/sum_response_length": np.mean(
+                        [stat["sum_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/turn_count": np.mean(
+                        [stat["turn_count"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                }
+            )
+        metric_dict.update(
+            {
+                "val/n_rollouts": len(sample_stat_list),
+                "val/n_rollouts_w_trace": len(stats_w_trace),
+                "val/reward": np.mean(
+                    [stat["reward"] for stat in sample_stat_list]
+                ),  # each rollout must have a reward (fillna if missing)
+                "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
+                "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
+                "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
+            }
+        )
+        return metric_dict
 
     def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, device: torch.device):
         """
@@ -684,19 +732,32 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, device: torch.device):
         for rollout_id, sample_info in finished_id_to_sample_info.items():
             merged_trace_idx: List[List[int]] = []
             current_merged_trace_idx: List[int] = []
-            current_context: List[int] = []
+            current_context: str = ""
             for turn_index, trace in enumerate(sample_info["trace_list"]):
-                if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context:
-                    current_context = trace["prompt_ids"] + trace["response_ids"]
+                # print('~' * 20)
+                # print((trace["prompt_ids"] + trace["response_ids"]))
+                # print(current_context)
+                # print(f'|START|{_decode((trace["prompt_ids"] + trace["response_ids"]))}|END|')
+                # print(f'|START|{_decode(current_context)}|END|')
+
+                temp_combined = _decode(trace["prompt_ids"] + trace["response_ids"])
+                if temp_combined[: len(current_context)] == current_context:
+                    # if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context:
+                    current_context = temp_combined
                     current_merged_trace_idx.append(turn_index)
                 else:
                     # assert len(current_merged_trace_idx) > 0
                     merged_trace_idx.append(current_merged_trace_idx)
                     current_merged_trace_idx = [turn_index]
-                    current_context = trace["prompt_ids"] + trace["response_ids"]
+                    current_context = temp_combined
             if current_merged_trace_idx not in merged_trace_idx:
                 merged_trace_idx.append(current_merged_trace_idx)
 
+            print('-' * 20)
+            print(merged_trace_idx)
+            # assert len(merged_trace_idx) == 1
+            # assert sum(len(x) for x in merged_trace_idx) == len(sample_info["trace_list"])
+
             for current_merged_trace_idx in merged_trace_idx:
                 prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]
                 response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]
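
The rewritten merging loop groups consecutive turns whose decoded prompt-plus-response extends the running context as a string prefix; a failed prefix check closes the current group and opens a new one. A standalone sketch of that grouping on plain strings (the turn texts are hypothetical; the real code decodes `prompt_ids + response_ids` with `_decode`):

```python
from typing import List

# Hypothetical decoded turns: turns 0-1 share a growing context,
# turn 2 starts a fresh context, so two groups come out.
trace_texts = [
    "Q: capital of France? A: Paris",
    "Q: capital of France? A: Paris Q: and Italy? A: Rome",
    "Q: largest ocean? A: Pacific",
]

merged_trace_idx: List[List[int]] = []
current_group: List[int] = []
current_context = ""

for turn_index, text in enumerate(trace_texts):
    if text[: len(current_context)] == current_context:
        # This turn extends the running context: keep it in the current group.
        current_context = text
        current_group.append(turn_index)
    else:
        # Prefix check failed: close this group and start a new one.
        merged_trace_idx.append(current_group)
        current_group = [turn_index]
        current_context = text

if current_group not in merged_trace_idx:
    merged_trace_idx.append(current_group)

print(merged_trace_idx)  # [[0, 1], [2]]
```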
