Commit 67a8d26

Add Few-Shot Learning Support to Balrog (#4)

* Few Shot Learning
* turn off dummy actions
* simplify dataset
* add docs for few shot learning
* update docs
* fix download link
* fix loading dataset
* add parameter to limit the size of icl context
* sample demonstrations randomly
* quick fix
* set default max_icl_history to 1000

1 parent 395ca9b commit 67a8d26

9 files changed: +326 −6 lines

.gitignore

Lines changed: 2 additions & 2 deletions

@@ -165,5 +165,5 @@ cython_debug/
 /outputs
 /tw_games
 tw-games.zip
-/demos
-/demos.zip
+/records
+/records.zip

README.md

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ python eval.py \
 ## Documentation
 - [Evaluation Guide](https://github.com/balrog-ai/BALROG/blob/main/docs/evaluation.md) - Detailed instructions for various evaluation scenarios
 - [Agent Development](https://github.com/balrog-ai/BALROG/blob/main/docs/agents.md) - Tutorial on creating custom agents
+- [Few Shot Learning](https://github.com/balrog-ai/BALROG/blob/main/docs/few_shot_learning.md) - Instructions on how to run Few Shot Learning

 We welcome contributions! Please see our [Contributing Guidelines](https://github.com/balrog-ai/BALROG/blob/main/docs/contribution.md) for details.

balrog/agents/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -4,6 +4,7 @@
 from .chain_of_thought import ChainOfThoughtAgent
 from .custom import CustomAgent
 from .dummy import DummyAgent
+from .few_shot import FewShotAgent
 from .naive import NaiveAgent


@@ -47,6 +48,8 @@ def create_agent(self):
             return DummyAgent(client_factory, prompt_builder)
         elif self.config.agent.type == "custom":
             return CustomAgent(client_factory, prompt_builder)
+        elif self.config.agent.type == "few_shot":
+            return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)

         else:
             raise ValueError(f"Unknown agent type: {self.config.agent}")

balrog/agents/few_shot.py

Lines changed: 153 additions & 0 deletions (new file)

@@ -0,0 +1,153 @@
import copy
import re
from typing import List, Optional

from balrog.agents.base import BaseAgent


class Message:
    def __init__(self, role: str, content: str, attachment: Optional[object] = None):
        self.role = role  # 'system', 'user', 'assistant'
        self.content = content  # String content of the message
        self.attachment = attachment

    def __repr__(self):
        return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})"


class FewShotAgent(BaseAgent):
    def __init__(self, client_factory, prompt_builder, max_icl_history):
        """Initialize the FewShotAgent with a client and prompt builder."""
        super().__init__(client_factory, prompt_builder)
        self.client = client_factory()
        self.icl_episodes = []
        self.icl_events = []
        self.max_icl_history = max_icl_history
        self.cached_icl = False

    def update_icl_observation(self, obs: dict):
        long_term_context = obs["text"].get("long_term_context", "")
        self.icl_events.append(
            {
                "type": "icl_observation",
                "text": long_term_context,
            }
        )

    def update_icl_action(self, action: str):
        self.icl_events.append(
            {
                "type": "icl_action",
                "action": action,
            }
        )

    def cache_icl(self):
        self.client.cache_icl_demo(self.get_icl_prompt())
        self.cached_icl = True

    def wrap_episode(self):
        icl_episode = []
        icl_episode.append(
            Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
        )
        for event in self.icl_events:
            if event["type"] == "icl_observation":
                content = "Observation:\n" + event["text"]
                message = Message(role="user", content=content)
            elif event["type"] == "icl_action":
                content = event["action"]
                message = Message(role="assistant", content=content)
            icl_episode.append(message)
        icl_episode.append(
            Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
        )

        self.icl_episodes.append(icl_episode)
        self.icl_events = []

    def get_icl_prompt(self) -> List[Message]:
        icl_instruction = Message(
            role="user",
            content=self.prompt_builder.system_prompt.replace(
                "PLAY",
                "First, observe the demonstrations provided and learn from them!",
            ),
        )

        # unroll the wrapped icl episode messages
        icl_messages = [icl_instruction]
        i = 0
        for icl_episode in self.icl_episodes:
            episode_steps = len(icl_episode) - 2  # don't count start and end messages
            if i + episode_steps <= self.max_icl_history:
                icl_messages.extend(icl_episode)
                i += episode_steps
            else:
                icl_episode = icl_episode[: self.max_icl_history - i + 1] + [
                    icl_episode[-1]
                ]  # +1 for start message, -1 for end message
                icl_messages.extend(icl_episode)
                i += len(icl_episode) - 2  # don't count start and end messages
                break

        end_demo_message = Message(
            role="user",
            content="****** Now it's your turn to play the game! ******",
        )
        icl_messages.append(end_demo_message)

        return icl_messages

    def act(self, obs, prev_action=None):
        """Generate the next action based on the observation and previous action.

        Args:
            obs (dict): The current observation in the environment.
            prev_action (str, optional): The previous action taken.

        Returns:
            str: The selected action from the LLM response.
        """
        if prev_action:
            self.prompt_builder.update_action(prev_action)

        self.prompt_builder.update_observation(obs)

        if not self.cached_icl:
            messages = self.get_icl_prompt()
        else:
            messages = []

        messages.extend(self.prompt_builder.get_prompt(icl_episodes=True))

        naive_instruction = """
You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates.
""".strip()

        if messages and messages[-1].role == "user":
            messages[-1].content += "\n\n" + naive_instruction

        response = self.client.generate(messages)

        final_answer = self._extract_final_answer(response)

        return final_answer

    def _extract_final_answer(self, answer):
        """Sanitize the final answer, keeping only alphabetic characters.

        Args:
            answer (LLMResponse): The response from the LLM.

        Returns:
            LLMResponse: The sanitized response.
        """

        def filter_letters(input_string):
            return re.sub(r"[^a-zA-Z\s:]", "", input_string)

        final_answer = copy.deepcopy(answer)
        final_answer = final_answer._replace(completion=filter_letters(final_answer.completion))

        return final_answer
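To make the ICL bookkeeping above concrete, here is a minimal smoke test of the demonstration-replay methods. The stub prompt builder and the `lambda: None` client factory are hypothetical stand-ins (this assumes `BaseAgent.__init__` only stores its arguments), not part of this commit:

    # Hypothetical smoke test: exercise update_icl_observation / update_icl_action /
    # wrap_episode / get_icl_prompt without a real LLM client.
    from balrog.agents.few_shot import FewShotAgent


    class StubPromptBuilder:
        # "PLAY" is the token get_icl_prompt() replaces with the demo instruction
        system_prompt = "PLAY the game and win."


    agent = FewShotAgent(client_factory=lambda: None,
                         prompt_builder=StubPromptBuilder(),
                         max_icl_history=1000)

    # Replay one demonstration step: observation, action, next observation.
    agent.update_icl_observation({"text": {"long_term_context": "You see a locked door."}})
    agent.update_icl_action("open door")
    agent.update_icl_observation({"text": {"long_term_context": "The door is open."}})
    agent.wrap_episode()

    for msg in agent.get_icl_prompt():
        print(f"{msg.role}: {msg.content[:60]!r}")

Running this should print the instruction message, the START/END episode banners, the "Observation:"-prefixed user messages, and the assistant action, in that order.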

balrog/config/config.yaml

Lines changed: 4 additions & 1 deletion

@@ -4,6 +4,8 @@ agent:
   max_history: 16 # Maximum number of previous turns to keep in the dialogue history
   max_image_history: 0 # Maximum number of images to keep in the history
   max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent)
+  max_icl_history: 1000 # Maximum number of ICL steps to keep in history (if using 'few_shot' type of agent)
+  cache_icl: False

 eval:
   output_dir: "results" # Directory where evaluation results will be saved
@@ -19,7 +21,8 @@ eval:
   max_steps_per_episode: null # Max steps per episode; null uses the environment default
   save_trajectories: True # Whether to save agent trajectories (text only)
   save_images: False # Whether to save images from the environment
-
+  icl_episodes: 1
+  icl_dataset: records

 client:
   client_name: openai # LLM client to use (e.g., 'openai', 'gemini', 'claude')
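Assuming the `eval.py` entry point shown in the README accepts dotted config overrides for these keys, a run exercising the new options might look like this (model and client flags omitted; defaults shown explicitly for illustration):

    python eval.py \
      agent.type=few_shot \
      agent.max_icl_history=1000 \
      agent.cache_icl=False \
      eval.icl_episodes=1 \
      eval.icl_dataset=records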

balrog/dataset.py

Lines changed: 102 additions & 0 deletions (new file)

@@ -0,0 +1,102 @@
import glob
import logging
import os
import random
import re
from pathlib import Path

import numpy as np


def natural_sort_key(s):
    return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))]


def choice_excluding(lst, excluded_element):
    possible_choices = [item for item in lst if item != excluded_element]
    return random.choice(possible_choices)


class InContextDataset:
    def __init__(self, config, env_name, original_cwd) -> None:
        self.config = config
        self.env_name = env_name
        self.original_cwd = original_cwd

    def icl_episodes(self, task):
        demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
        return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key))

    def extract_seed(self, demo_path):
        # extract seed from record, example format: `20241201T225823-seed13-rew1.00-len47.npz`
        seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part]
        return int(seed[0])

    def demo_task(self, task):
        # use a different task - avoid the case where we put the solution into the context
        if self.env_name == "babaisai":
            task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task)

        return task

    def demo_path(self, i, task):
        icl_episodes = self.icl_episodes(task)
        demo_path = icl_episodes[i % len(icl_episodes)]

        # use a different seed - avoid the case where we put the solution into the context
        if self.env_name == "textworld":
            from balrog.environments.textworld import global_textworld_context

            textworld_context = global_textworld_context(
                tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
            )
            next_seed = textworld_context.count[task]
            demo_seed = self.extract_seed(demo_path)
            if next_seed == demo_seed:
                demo_path = self.icl_episodes(task)[i + 1]

        return demo_path

    def load_episode(self, filename):
        # Load the compressed NPZ file and convert it to a dictionary
        with np.load(filename, allow_pickle=True) as data:
            episode = {k: data[k] for k in data.files}
        return episode

    def load_in_context_learning_episodes(self, num_episodes, task, agent):
        demo_task = self.demo_task(task)
        demo_paths = [self.demo_path(i, demo_task) for i in range(len(self.icl_episodes(task)))]
        random.shuffle(demo_paths)
        demo_paths = demo_paths[:num_episodes]

        for demo_path in demo_paths:
            self.load_in_context_learning_episode(demo_path, agent)

    def load_in_context_learning_episode(self, demo_path, agent):
        episode = self.load_episode(demo_path)

        actions = episode.pop("action").tolist()
        rewards = episode.pop("reward").tolist()
        terminated = episode.pop("terminated")
        truncated = episode.pop("truncated")
        dones = np.any([terminated, truncated], axis=0).tolist()
        observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())]

        # the first transition only contains an observation (like env.reset())
        observation, action, reward, done = observations.pop(0), actions.pop(0), rewards.pop(0), dones.pop(0)
        agent.update_icl_observation(observation)

        for observation, action, reward, done in zip(observations, actions, rewards, dones):
            action = str(action)

            agent.update_icl_action(action)
            agent.update_icl_observation(observation)

            if done:
                break

        if not done:
            logging.info("icl trajectory ended without done")

        agent.wrap_episode()
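For reference, a sketch of a record file this loader could consume, inferred from `load_in_context_learning_episode`: parallel arrays keyed by name, where `action`, `reward`, `terminated`, and `truncated` are consumed explicitly and every remaining key (here `text`) is zipped into per-step observation dicts. The env name, task name, and values below are hypothetical:

    # Hypothetical writer for a record compatible with load_episode(). The
    # filename follows the `20241201T225823-seed13-rew1.00-len47.npz` pattern
    # parsed by extract_seed(); directories mirror the layout icl_episodes()
    # globs over (<icl_dataset>/<env_name>/<task>/...).
    import os

    import numpy as np

    steps = [
        {"long_term_context": "You see a locked door."},  # reset observation
        {"long_term_context": "The door is open."},
        {"long_term_context": "You win!"},
    ]

    os.makedirs("records/babyai/goto", exist_ok=True)
    np.savez_compressed(
        "records/babyai/goto/20241201T120000-seed13-rew1.00-len2.npz",
        text=np.array(steps, dtype=object),
        # index 0 pairs with the reset observation; the loader discards it
        action=np.array(["", "open door", "go forward"], dtype=object),
        reward=np.array([0.0, 0.0, 1.0]),
        terminated=np.array([False, False, True]),
        truncated=np.array([False, False, False]),
    )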

balrog/evaluator.py

Lines changed: 13 additions & 2 deletions

@@ -14,6 +14,8 @@
 from omegaconf import OmegaConf
 from tqdm import tqdm

+from balrog.agents.few_shot import FewShotAgent
+from balrog.dataset import InContextDataset
 from balrog.environments import make_env
 from balrog.utils import get_unique_seed

@@ -43,7 +45,7 @@ def __init__(self, config, original_cwd="", output_dir="."):
         self.env_evaluators = {}
         self.tasks = []
         for env_name in self.env_names:
-            evaluator = Evaluator(env_name, config, output_dir=self.output_dir)
+            evaluator = Evaluator(env_name, config, original_cwd=original_cwd, output_dir=self.output_dir)
             self.env_evaluators[env_name] = evaluator
             for task in evaluator.tasks:
                 for episode_idx in range(evaluator.num_episodes):

@@ -219,7 +221,7 @@ class Evaluator:
     including loading in-context learning episodes and running episodes with the agent.
     """

-    def __init__(self, env_name, config, output_dir="."):
+    def __init__(self, env_name, config, original_cwd="", output_dir="."):
         """Initialize the Evaluator.

         Args:

@@ -237,6 +239,8 @@ def __init__(self, env_name, config, output_dir="."):
         self.num_workers = config.eval.num_workers
         self.max_steps_per_episode = config.eval.max_steps_per_episode

+        self.dataset = InContextDataset(self.config, self.env_name, original_cwd=original_cwd)
+
     def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
         """Run a single evaluation episode.

@@ -284,6 +288,13 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
         csv_writer = csv.writer(csv_file, escapechar="˘", quoting=csv.QUOTE_MINIMAL)
         csv_writer.writerow(["Step", "Action", "Reasoning", "Observation", "Reward", "Done"])

+        # If the agent is a FewShotAgent, load the in-context learning episodes
+        if isinstance(agent, FewShotAgent):
+            self.dataset.load_in_context_learning_episodes(self.config.eval.icl_episodes, task, agent)
+
+            if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
+                agent.cache_icl()
+
         pbar_desc = f"Task: {task}, Proc: {process_num}"
         pbar = tqdm(
             total=max_steps_per_episode,
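The `original_cwd` plumbing matters because the record paths in `InContextDataset` are resolved relative to the launch directory. Assuming the project uses a Hydra entry point (suggested by the `original_cwd` naming and the `config.yaml` layout; Hydra changes the working directory at run time), a hypothetical wiring might look like this; the env name is a placeholder:

    # Hypothetical entry-point wiring (not part of this commit): pass the
    # launch directory through so eval.icl_dataset resolves correctly.
    import hydra

    from balrog.evaluator import Evaluator


    @hydra.main(config_path="balrog/config", config_name="config", version_base=None)
    def main(config):
        evaluator = Evaluator(
            "babyai",  # placeholder env name
            config,
            original_cwd=hydra.utils.get_original_cwd(),
            output_dir="results",
        )


    main()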

balrog/prompt_builder/history.py

Lines changed: 4 additions & 1 deletion

@@ -75,7 +75,7 @@ def reset(self):
         """Clear the event history."""
         self._events.clear()

-    def get_prompt(self) -> List[Message]:
+    def get_prompt(self, icl_episodes=False) -> List[Message]:
         """Generate a list of Message objects representing the prompt.

         Returns:

@@ -85,6 +85,9 @@ def get_prompt(self) -> List[Message]:
         if self.system_prompt:
             messages.append(Message(role="user", content=self.system_prompt))

+        if self.system_prompt and not icl_episodes:
+            messages.append(Message(role="user", content=self.system_prompt))
+
         # Determine which images to include
         images_needed = self.max_image_history
         for event in reversed(self._events):

0 commit comments

Comments
 (0)