Skip to content

Commit 649f45e

Browse files
committed
sample demonstrations randomly
1 parent 542d224 commit 649f45e

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

balrog/dataset.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import glob
2+
import logging
23
import os
34
import random
45
import re
@@ -63,9 +64,15 @@ def load_episode(self, filename):
6364
episode = {k: data[k] for k in data.files}
6465
return episode
6566

66-
def load_in_context_learning_episode(self, i, task, agent):
67+
def load_in_context_learning_episodes(self, num_episodes, task, agent):
6768
demo_task = self.demo_task(task)
68-
demo_path = self.demo_path(i, demo_task)
69+
demo_paths = [self.demo_path(i, demo_task) for i in range(num_episodes)]
70+
random.shuffle(demo_paths)
71+
72+
for demo_path in demo_paths:
73+
self.load_in_context_learning_episode(demo_path, agent)
74+
75+
def load_in_context_learning_episode(self, demo_path, agent):
6976
episode = self.load_episode(demo_path)
7077

7178
actions = episode.pop("action").tolist()
@@ -89,6 +96,6 @@ def load_in_context_learning_episode(self, i, task, agent):
8996
break
9097

9198
if not done:
92-
print("warning: icl trajectory ended without done")
99+
logging.info("icl trajectory ended without done")
93100

94101
agent.wrap_episode()

balrog/evaluator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,7 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
290290

291291
# If the agent is a FewShotAgent, load the in-context learning episode
292292
if isinstance(agent, FewShotAgent):
293-
for icl_episode in range(self.config.eval.icl_episodes):
294-
self.dataset.load_in_context_learning_episode(icl_episode, task, agent)
293+
self.dataset.load_in_context_learning_episodes(self.config.eval.icl_episodes, task, agent)
295294

296295
if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
297296
agent.cache_icl()

0 commit comments

Comments
 (0)