Skip to content

Commit cf6beff

Browse files
committed
simplify dataset
1 parent ec54ce5 commit cf6beff

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

balrog/dataset.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import glob
32
import os
43
import random
@@ -27,8 +26,10 @@ def icl_episodes(self, task):
2726
demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
2827
return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key))
2928

30-
def check_seed(self, demo_path):
31-
return int(demo_path.stem.split("seed_")[1])
29+
def extract_seed(self, demo_path):
30+
# extract seed from record, example format: `20241201T225823-seed13-rew1.00-len47.npz`
31+
seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part]
32+
return int(seed[0])
3233

3334
def demo_task(self, task):
3435
# use different task - avoid the case where we put the solution into the context
@@ -37,24 +38,10 @@ def demo_task(self, task):
3738

3839
return task
3940

40-
def demo_path(self, i, task, demo_config):
41+
def demo_path(self, i, task):
4142
icl_episodes = self.icl_episodes(task)
4243
demo_path = icl_episodes[i % len(icl_episodes)]
4344

44-
# use the same role
45-
if self.env_name == "nle":
46-
from balrog.environments.nle import Role
47-
48-
character = demo_config.envs.nle_kwargs.character
49-
if character != "@":
50-
for part in character.split("-"):
51-
# check if there is specified role
52-
if part.lower() in [e.value for e in Role]:
53-
# check if we have games played with this role
54-
new_demo_paths = [path for path in icl_episodes if part.lower() in path.stem.lower()]
55-
if new_demo_paths:
56-
demo_path = random.choice(new_demo_paths)
57-
5845
# use different seed - avoid the case where we put the solution into the context
5946
if self.env_name == "textworld":
6047
from balrog.environments.textworld import global_textworld_context
@@ -63,7 +50,7 @@ def demo_path(self, i, task, demo_config):
6350
tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
6451
)
6552
next_seed = textworld_context.count[task]
66-
demo_seed = self.check_seed(demo_path)
53+
demo_seed = self.extract_seed(demo_path)
6754
if next_seed == demo_seed:
6855
demo_path = self.icl_episodes(task)[i + 1]
6956

@@ -77,9 +64,8 @@ def load_episode(self, filename):
7764
return episode
7865

7966
def load_in_context_learning_episode(self, i, task, agent):
80-
demo_config = copy.deepcopy(self.config)
8167
demo_task = self.demo_task(task)
82-
demo_path = self.demo_path(i, demo_task, demo_config)
68+
demo_path = self.demo_path(i, demo_task)
8369
episode = self.load_episode(demo_path)
8470

8571
actions = episode.pop("action")

0 commit comments

Comments
 (0)