Skip to content

Commit 57aaf67

Browse files
committed
Few Shot Learning
1 parent 31df8ef commit 57aaf67

File tree

6 files changed

+274
-4
lines changed

6 files changed

+274
-4
lines changed

balrog/agents/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .chain_of_thought import ChainOfThoughtAgent
55
from .custom import CustomAgent
66
from .dummy import DummyAgent
7+
from .few_shot import FewShotAgent
78
from .naive import NaiveAgent
89

910

@@ -47,6 +48,8 @@ def create_agent(self):
4748
return DummyAgent(client_factory, prompt_builder)
4849
elif self.config.agent.type == "custom":
4950
return CustomAgent(client_factory, prompt_builder)
51+
elif self.config.agent.type == "few_shot":
52+
return FewShotAgent(client_factory, prompt_builder)
5053

5154
else:
5255
raise ValueError(f"Unknown agent type: {self.config.agent}")

balrog/agents/few_shot.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import copy
2+
import re
3+
from typing import List, Optional
4+
5+
from balrog.agents.base import BaseAgent
6+
7+
from .dummy import make_dummy_action
8+
9+
10+
class Message:
11+
def __init__(self, role: str, content: str, attachment: Optional[object] = None):
12+
self.role = role # 'system', 'user', 'assistant'
13+
self.content = content # String content of the message
14+
self.attachment = attachment
15+
16+
def __repr__(self):
17+
return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})"
18+
19+
20+
class FewShotAgent(BaseAgent):
21+
def __init__(self, client_factory, prompt_builder):
22+
"""Initialize the FewShotAgent with a client and prompt builder."""
23+
super().__init__(client_factory, prompt_builder)
24+
# self.client = client_factory()
25+
self.icl_episodes = []
26+
self.icl_events = []
27+
self.cached_icl = False
28+
29+
def update_icl_observation(self, obs: dict):
30+
long_term_context = obs["text"].get("long_term_context", "")
31+
self.icl_events.append(
32+
{
33+
"type": "icl_observation",
34+
"text": long_term_context,
35+
}
36+
)
37+
38+
def update_icl_action(self, action: str):
39+
self.icl_events.append(
40+
{
41+
"type": "icl_action",
42+
"action": action,
43+
}
44+
)
45+
46+
def cache_icl(self):
47+
self.client.cache_icl_demo(self.get_icl_prompt())
48+
self.cached_icl = True
49+
50+
def wrap_episode(self):
51+
icl_episode = []
52+
icl_episode.append(
53+
Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
54+
)
55+
for event in self.icl_events:
56+
if event["type"] == "icl_observation":
57+
content = "Obesrvation:\n" + event["text"]
58+
message = Message(role="user", content=content)
59+
elif event["type"] == "icl_action":
60+
content = event["action"]
61+
message = Message(role="assistant", content=content)
62+
icl_episode.append(message)
63+
icl_episode.append(
64+
Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
65+
)
66+
67+
self.icl_episodes.append(icl_episode)
68+
self.icl_events = []
69+
70+
def get_icl_prompt(self) -> List[Message]:
71+
icl_instruction = Message(
72+
role="user",
73+
content=self.prompt_builder.system_prompt.replace(
74+
"PLAY",
75+
"First, observe the demonstrations provided and learn from them!",
76+
),
77+
)
78+
79+
# unroll the wrapped icl episodes messages
80+
icl_messages = [icl_instruction]
81+
for icl_episode in self.icl_episodes:
82+
icl_messages.extend(icl_episode)
83+
84+
end_demo_message = Message(
85+
role="user",
86+
content="****** Now it's your turn to play the game! ******",
87+
)
88+
icl_messages.append(end_demo_message)
89+
90+
return icl_messages
91+
92+
def act(self, obs, prev_action=None):
93+
"""Generate the next action based on the observation and previous action.
94+
95+
Args:
96+
obs (dict): The current observation in the environment.
97+
prev_action (str, optional): The previous action taken.
98+
99+
Returns:
100+
str: The selected action from the LLM response.
101+
"""
102+
if prev_action:
103+
self.prompt_builder.update_action(prev_action)
104+
105+
self.prompt_builder.update_observation(obs)
106+
107+
if not self.cached_icl:
108+
messages = self.get_icl_prompt()
109+
else:
110+
messages = []
111+
112+
messages.extend(self.prompt_builder.get_prompt(icl_episodes=True))
113+
114+
naive_instruction = """
115+
You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates.
116+
""".strip()
117+
118+
if messages and messages[-1].role == "user":
119+
messages[-1].content += "\n\n" + naive_instruction
120+
121+
response = make_dummy_action(messages)
122+
# response = self.client.generate(messages)
123+
124+
final_answer = self._extract_final_answer(response)
125+
126+
return final_answer
127+
128+
def _extract_final_answer(self, answer):
129+
"""Sanitize the final answer, keeping only alphabetic characters.
130+
131+
Args:
132+
answer (LLMResponse): The response from the LLM.
133+
134+
Returns:
135+
LLMResponse: The sanitized response.
136+
"""
137+
138+
def filter_letters(input_string):
139+
return re.sub(r"[^a-zA-Z\s:]", "", input_string)
140+
141+
final_answer = copy.deepcopy(answer)
142+
final_answer = final_answer._replace(completion=filter_letters(final_answer.completion))
143+
144+
return final_answer

balrog/config/config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ agent:
44
max_history: 16 # Maximum number of previous turns to keep in the dialogue history
55
max_image_history: 0 # Maximum number of images to keep in the history
66
max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent)
7+
cache_icl: False
78

89
eval:
910
output_dir: "results" # Directory where evaluation results will be saved
@@ -19,7 +20,8 @@ eval:
1920
max_steps_per_episode: null # Max steps per episode; null uses the environment default
2021
save_trajectories: True # Whether to save agent trajectories (text only)
2122
save_images: False # Whether to save images from the environment
22-
23+
icl_episodes: 1
24+
icl_dataset: records
2325

2426
client:
2527
client_name: openai # LLM client to use (e.g., 'openai', 'gemini', 'claude')

balrog/dataset.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import copy
2+
import glob
3+
import os
4+
import random
5+
import re
6+
from pathlib import Path
7+
8+
import numpy as np
9+
10+
11+
def natural_sort_key(s):
12+
return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))]
13+
14+
15+
def choice_excluding(lst, excluded_element):
16+
possible_choices = [item for item in lst if item != excluded_element]
17+
return random.choice(possible_choices)
18+
19+
20+
class InContextDataset:
21+
def __init__(self, config, env_name, original_cwd) -> None:
22+
self.config = config
23+
self.env_name = env_name
24+
self.original_cwd = original_cwd
25+
26+
def icl_episodes(self, task):
27+
demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
28+
return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key))
29+
30+
def check_seed(self, demo_path):
31+
return int(demo_path.stem.split("seed_")[1])
32+
33+
def demo_task(self, task):
34+
# use different task - avoid the case where we put the solution into the context
35+
if self.env_name == "babaisai":
36+
task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task)
37+
38+
return task
39+
40+
def demo_path(self, i, task, demo_config):
41+
icl_episodes = self.icl_episodes(task)
42+
demo_path = icl_episodes[i % len(icl_episodes)]
43+
44+
# use the same role
45+
if self.env_name == "nle":
46+
from balrog.environments.nle import Role
47+
48+
character = demo_config.envs.nle_kwargs.character
49+
if character != "@":
50+
for part in character.split("-"):
51+
# check if there is specified role
52+
if part.lower() in [e.value for e in Role]:
53+
# check if we have games played with this role
54+
new_demo_paths = [path for path in icl_episodes if part.lower() in path.stem.lower()]
55+
if new_demo_paths:
56+
demo_path = random.choice(new_demo_paths)
57+
58+
# use different seed - avoid the case where we put the solution into the context
59+
if self.env_name == "textworld":
60+
from balrog.environments.textworld import global_textworld_context
61+
62+
textworld_context = global_textworld_context(
63+
tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
64+
)
65+
next_seed = textworld_context.count[task]
66+
demo_seed = self.check_seed(demo_path)
67+
if next_seed == demo_seed:
68+
demo_path = self.icl_episodes(task)[i + 1]
69+
70+
return demo_path
71+
72+
def load_episode(self, filename):
73+
# Load the compressed NPZ file
74+
with np.load(filename, allow_pickle=True) as data:
75+
# Convert to dictionary if you want
76+
episode = {k: data[k] for k in data.files}
77+
return episode
78+
79+
def load_in_context_learning_episode(self, i, task, agent):
80+
demo_config = copy.deepcopy(self.config)
81+
demo_task = self.demo_task(task)
82+
demo_path = self.demo_path(i, demo_task, demo_config)
83+
episode = self.load_episode(demo_path)
84+
85+
actions = episode.pop("action")
86+
rewards = episode.pop("reward")
87+
terminated = episode.pop("terminated")
88+
truncated = episode.pop("truncated")
89+
dones = np.any([terminated, truncated], axis=0)
90+
observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())]
91+
92+
for observation, action, reward, done in zip(observations, actions, rewards, dones):
93+
action = str(action)
94+
if action == "":
95+
action = None
96+
97+
agent.update_icl_observation(observation)
98+
agent.update_icl_action(action)
99+
100+
if done:
101+
break
102+
103+
if not done:
104+
print("warning: icl trajectory ended without done")
105+
106+
agent.wrap_episode()

balrog/evaluator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from omegaconf import OmegaConf
1515
from tqdm import tqdm
1616

17+
from balrog.agents.few_shot import FewShotAgent
18+
from balrog.dataset import InContextDataset
1719
from balrog.environments import make_env
1820
from balrog.utils import get_unique_seed
1921

@@ -43,7 +45,7 @@ def __init__(self, config, original_cwd="", output_dir="."):
4345
self.env_evaluators = {}
4446
self.tasks = []
4547
for env_name in self.env_names:
46-
evaluator = Evaluator(env_name, config, output_dir=self.output_dir)
48+
evaluator = Evaluator(env_name, config, original_cwd=original_cwd, output_dir=self.output_dir)
4749
self.env_evaluators[env_name] = evaluator
4850
for task in evaluator.tasks:
4951
for episode_idx in range(evaluator.num_episodes):
@@ -219,7 +221,7 @@ class Evaluator:
219221
including loading in-context learning episodes and running episodes with the agent.
220222
"""
221223

222-
def __init__(self, env_name, config, output_dir="."):
224+
def __init__(self, env_name, config, original_cwd="", output_dir="."):
223225
"""Initialize the Evaluator.
224226
225227
Args:
@@ -237,6 +239,8 @@ def __init__(self, env_name, config, output_dir="."):
237239
self.num_workers = config.eval.num_workers
238240
self.max_steps_per_episode = config.eval.max_steps_per_episode
239241

242+
self.dataset = InContextDataset(self.config, self.env_name, original_cwd=original_cwd)
243+
240244
def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
241245
"""Run a single evaluation episode.
242246
@@ -284,6 +288,14 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
284288
csv_writer = csv.writer(csv_file, escapechar="˘", quoting=csv.QUOTE_MINIMAL)
285289
csv_writer.writerow(["Step", "Action", "Reasoning", "Observation", "Reward", "Done"])
286290

291+
# If the agent is an FewShotAgent, load the in-context learning episode
292+
if isinstance(agent, FewShotAgent):
293+
for icl_episode in range(self.config.eval.icl_episodes):
294+
self.dataset.load_in_context_learning_episode(icl_episode, task, agent)
295+
296+
if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
297+
agent.cache_icl()
298+
287299
pbar_desc = f"Task: {task}, Proc: {process_num}"
288300
pbar = tqdm(
289301
total=max_steps_per_episode,

balrog/prompt_builder/history.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ def reset(self):
7575
"""Clear the event history."""
7676
self._events.clear()
7777

78-
def get_prompt(self) -> List[Message]:
78+
def get_prompt(self, icl_episodes=False) -> List[Message]:
7979
"""Generate a list of Message objects representing the prompt.
8080
8181
Returns:
8282
List[Message]: Messages constructed from the event history.
8383
"""
8484
messages = []
8585

86+
if self.system_prompt and not icl_episodes:
87+
messages.append(Message(role="user", content=self.system_prompt))
88+
8689
# Determine which images to include
8790
images_needed = self.max_image_history
8891
for event in reversed(self._events):

0 commit comments

Comments
 (0)