1- import copy
21import glob
32import os
43import random
@@ -27,8 +26,10 @@ def icl_episodes(self, task):
2726 demos_dir = Path (self .original_cwd ) / self .config .eval .icl_dataset / self .env_name / task
2827 return list (sorted (glob .glob (os .path .join (demos_dir , "**/*.npz" ), recursive = True ), key = natural_sort_key ))
2928
30- def check_seed (self , demo_path ):
31- return int (demo_path .stem .split ("seed_" )[1 ])
def extract_seed(self, demo_path):
    """Return the integer seed encoded in a demo record's filename.

    Record names are dash-separated fields, for example
    ``20241201T225823-seed13-rew1.00-len47.npz``; the seed is the field
    made of the literal prefix ``seed`` followed by decimal digits.

    Args:
        demo_path: Path (``str`` or ``Path``) of the recorded episode.

    Returns:
        The seed as an ``int``.

    Raises:
        ValueError: If no ``seed<digits>`` field is present in the name.
    """
    prefix = "seed"
    stem = Path(demo_path).stem
    for field in stem.split("-"):
        # Require the field to *start* with the prefix: a plain substring
        # test would also select fields such as "reseeded" and then fail
        # in int() even though a valid seed field follows.
        if not field.startswith(prefix):
            continue
        digits = field[len(prefix):]
        if digits.isdecimal():
            return int(digits)
    raise ValueError(f"no 'seed<N>' field found in demo filename: {stem!r}")
3233
3334 def demo_task (self , task ):
3435 # use different task - avoid the case where we put the solution into the context
@@ -37,24 +38,10 @@ def demo_task(self, task):
3738
3839 return task
3940
40- def demo_path (self , i , task , demo_config ):
41+ def demo_path (self , i , task ):
4142 icl_episodes = self .icl_episodes (task )
4243 demo_path = icl_episodes [i % len (icl_episodes )]
4344
44- # use the same role
45- if self .env_name == "nle" :
46- from balrog .environments .nle import Role
47-
48- character = demo_config .envs .nle_kwargs .character
49- if character != "@" :
50- for part in character .split ("-" ):
51- # check if there is specified role
52- if part .lower () in [e .value for e in Role ]:
53- # check if we have games played with this role
54- new_demo_paths = [path for path in icl_episodes if part .lower () in path .stem .lower ()]
55- if new_demo_paths :
56- demo_path = random .choice (new_demo_paths )
57-
5845 # use different seed - avoid the case where we put the solution into the context
5946 if self .env_name == "textworld" :
6047 from balrog .environments .textworld import global_textworld_context
@@ -63,7 +50,7 @@ def demo_path(self, i, task, demo_config):
6350 tasks = self .config .tasks .textworld_tasks , ** self .config .envs .textworld_kwargs
6451 )
6552 next_seed = textworld_context .count [task ]
66- demo_seed = self .check_seed (demo_path )
53+ demo_seed = self .extract_seed (demo_path )
6754 if next_seed == demo_seed :
6855 demo_path = self .icl_episodes (task )[i + 1 ]
6956
@@ -77,9 +64,8 @@ def load_episode(self, filename):
7764 return episode
7865
7966 def load_in_context_learning_episode (self , i , task , agent ):
80- demo_config = copy .deepcopy (self .config )
8167 demo_task = self .demo_task (task )
82- demo_path = self .demo_path (i , demo_task , demo_config )
68+ demo_path = self .demo_path (i , demo_task )
8369 episode = self .load_episode (demo_path )
8470
8571 actions = episode .pop ("action" )
0 commit comments