Merged
4 changes: 2 additions & 2 deletions .gitignore
@@ -165,5 +165,5 @@ cython_debug/
/outputs
/tw_games
tw-games.zip
/demos
/demos.zip
/records
/records.zip
1 change: 1 addition & 0 deletions README.md
@@ -72,6 +72,7 @@ python eval.py \
## Documentation
- [Evaluation Guide](https://github.com/balrog-ai/BALROG/blob/main/docs/evaluation.md) - Detailed instructions for various evaluation scenarios
- [Agent Development](https://github.com/balrog-ai/BALROG/blob/main/docs/agents.md) - Tutorial on creating custom agents
- [Few Shot Learning](https://github.com/balrog-ai/BALROG/blob/main/docs/few_shot_learning.md) - Instructions for running few-shot learning evaluations

We welcome contributions! Please see our [Contributing Guidelines](https://github.com/balrog-ai/BALROG/blob/main/docs/contribution.md) for details.

3 changes: 3 additions & 0 deletions balrog/agents/__init__.py
@@ -4,6 +4,7 @@
from .chain_of_thought import ChainOfThoughtAgent
from .custom import CustomAgent
from .dummy import DummyAgent
from .few_shot import FewShotAgent
from .naive import NaiveAgent


@@ -47,6 +48,8 @@ def create_agent(self):
return DummyAgent(client_factory, prompt_builder)
elif self.config.agent.type == "custom":
return CustomAgent(client_factory, prompt_builder)
elif self.config.agent.type == "few_shot":
return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)

else:
raise ValueError(f"Unknown agent type: {self.config.agent}")
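
For context, the new branch is driven by the agent config; a minimal sketch of the relevant keys, mirroring balrog/config/config.yaml below:

agent:
  type: few_shot
  max_icl_history: 1000
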
153 changes: 153 additions & 0 deletions balrog/agents/few_shot.py
@@ -0,0 +1,153 @@
import copy
import re
from typing import List, Optional

from balrog.agents.base import BaseAgent


class Message:
def __init__(self, role: str, content: str, attachment: Optional[object] = None):
self.role = role # 'system', 'user', 'assistant'
self.content = content # String content of the message
self.attachment = attachment

def __repr__(self):
return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})"


class FewShotAgent(BaseAgent):
def __init__(self, client_factory, prompt_builder, max_icl_history):
"""Initialize the FewShotAgent with a client and prompt builder."""
super().__init__(client_factory, prompt_builder)
self.client = client_factory()
self.icl_episodes = []
self.icl_events = []
self.max_icl_history = max_icl_history
self.cached_icl = False

def update_icl_observation(self, obs: dict):
long_term_context = obs["text"].get("long_term_context", "")
self.icl_events.append(
{
"type": "icl_observation",
"text": long_term_context,
}
)

def update_icl_action(self, action: str):
self.icl_events.append(
{
"type": "icl_action",
"action": action,
}
)

def cache_icl(self):
self.client.cache_icl_demo(self.get_icl_prompt())
self.cached_icl = True

def wrap_episode(self):
icl_episode = []
icl_episode.append(
Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
)
for event in self.icl_events:
if event["type"] == "icl_observation":
content = "Obesrvation:\n" + event["text"]
message = Message(role="user", content=content)
elif event["type"] == "icl_action":
content = event["action"]
message = Message(role="assistant", content=content)
icl_episode.append(message)
icl_episode.append(
Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
)

self.icl_episodes.append(icl_episode)
self.icl_events = []

def get_icl_prompt(self) -> List[Message]:
icl_instruction = Message(
role="user",
content=self.prompt_builder.system_prompt.replace(
"PLAY",
"First, observe the demonstrations provided and learn from them!",
),
)

# unroll the wrapped ICL episode messages
icl_messages = [icl_instruction]
i = 0
for icl_episode in self.icl_episodes:
episode_steps = len(icl_episode) - 2 # don't count the start and end marker messages
if i + episode_steps <= self.max_icl_history:
icl_messages.extend(icl_episode)
i += episode_steps
else:
icl_episode = icl_episode[: self.max_icl_history - i + 1] + [
icl_episode[-1]
] # +1 keeps the start marker; the end marker is re-appended
icl_messages.extend(icl_episode)
i += len(icl_episode) - 2 # don't count the start and end marker messages
break

end_demo_message = Message(
role="user",
content="****** Now it's your turn to play the game! ******",
)
icl_messages.append(end_demo_message)

return icl_messages

def act(self, obs, prev_action=None):
"""Generate the next action based on the observation and previous action.

Args:
obs (dict): The current observation in the environment.
prev_action (str, optional): The previous action taken.

Returns:
str: The selected action from the LLM response.
"""
if prev_action:
self.prompt_builder.update_action(prev_action)

self.prompt_builder.update_observation(obs)

if not self.cached_icl:
messages = self.get_icl_prompt()
else:
messages = []

messages.extend(self.prompt_builder.get_prompt(icl_episodes=True))

naive_instruction = """
You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates.
""".strip()

if messages and messages[-1].role == "user":
messages[-1].content += "\n\n" + naive_instruction

response = self.client.generate(messages)

final_answer = self._extract_final_answer(response)

return final_answer

def _extract_final_answer(self, answer):
"""Sanitize the final answer, keeping only alphabetic characters.

Args:
answer (LLMResponse): The response from the LLM.

Returns:
LLMResponse: The sanitized response.
"""

def filter_letters(input_string):
return re.sub(r"[^a-zA-Z\s:]", "", input_string)

final_answer = copy.deepcopy(answer)
final_answer = final_answer._replace(completion=filter_letters(final_answer.completion))

return final_answer
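
To make the flow above concrete, here is a minimal sketch (not part of the PR) that feeds one demonstration step into a FewShotAgent and prints the assembled ICL prompt. The stub prompt builder and no-op client factory are hypothetical stand-ins, assuming BaseAgent merely stores its arguments:

# Minimal sketch under the assumptions stated above.
class StubPromptBuilder:
    system_prompt = "PLAY the game as well as you can."

agent = FewShotAgent(lambda: None, StubPromptBuilder(), max_icl_history=1000)
agent.update_icl_observation({"text": {"long_term_context": "You see a closed door."}})
agent.update_icl_action("open door")
agent.update_icl_observation({"text": {"long_term_context": "The door is open."}})
agent.wrap_episode()

# Prints the demonstration wrapped in START/END markers, followed by
# the "Now it's your turn to play the game!" hand-off message.
for message in agent.get_icl_prompt():
    print(f"{message.role}: {message.content}")
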
5 changes: 4 additions & 1 deletion balrog/config/config.yaml
@@ -4,6 +4,8 @@ agent:
max_history: 16 # Maximum number of previous turns to keep in the dialogue history
max_image_history: 0 # Maximum number of images to keep in the history
max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent)
max_icl_history: 1000 # Maximum number of ICL steps to keep in history (if using 'few_shot' type of agent)
cache_icl: False # Whether to cache the ICL demonstration prompt (applies to the Gemini client)

eval:
output_dir: "results" # Directory where evaluation results will be saved
@@ -19,7 +21,8 @@ eval:
max_steps_per_episode: null # Max steps per episode; null uses the environment default
save_trajectories: True # Whether to save agent trajectories (text only)
save_images: False # Whether to save images from the environment

icl_episodes: 1 # Number of demonstration episodes to include in the context (if using 'few_shot' type of agent)
icl_dataset: records # Directory containing the recorded demonstration episodes

client:
client_name: openai # LLM client to use (e.g., 'openai', 'gemini', 'claude')
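
Assuming the OmegaConf-style overrides used elsewhere in the README, a few-shot evaluation might then be launched as follows (client and model values are placeholders):

python eval.py \
  agent.type=few_shot \
  agent.max_icl_history=1000 \
  eval.icl_episodes=2 \
  eval.icl_dataset=records \
  client.client_name=openai \
  client.model_id=gpt-4o-mini
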
102 changes: 102 additions & 0 deletions balrog/dataset.py
@@ -0,0 +1,102 @@
import glob
import logging
import os
import random
import re
from pathlib import Path

import numpy as np


def natural_sort_key(s):
return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))]


def choice_excluding(lst, excluded_element):
possible_choices = [item for item in lst if item != excluded_element]
return random.choice(possible_choices)


class InContextDataset:
def __init__(self, config, env_name, original_cwd) -> None:
self.config = config
self.env_name = env_name
self.original_cwd = original_cwd

def icl_episodes(self, task):
demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key))

def extract_seed(self, demo_path):
# extract the seed from the record filename, e.g. `20241201T225823-seed13-rew1.00-len47.npz`
seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part]
return int(seed[0])

def demo_task(self, task):
# use a different task, to avoid putting the solution to the current task into the context
if self.env_name == "babaisai":
task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task)

return task

def demo_path(self, i, task):
icl_episodes = self.icl_episodes(task)
demo_path = icl_episodes[i % len(icl_episodes)]

# use a different seed, to avoid putting the solution to the current episode into the context
if self.env_name == "textworld":
from balrog.environments.textworld import global_textworld_context

textworld_context = global_textworld_context(
tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
)
next_seed = textworld_context.count[task]
demo_seed = self.extract_seed(demo_path)
if next_seed == demo_seed:
demo_path = self.icl_episodes(task)[i + 1]

return demo_path

def load_episode(self, filename):
# Load the compressed NPZ file
with np.load(filename, allow_pickle=True) as data:
# Convert the NPZ archive into a plain dictionary
episode = {k: data[k] for k in data.files}
return episode

def load_in_context_learning_episodes(self, num_episodes, task, agent):
demo_task = self.demo_task(task)
demo_paths = [self.demo_path(i, demo_task) for i in range(len(self.icl_episodes(task)))]
random.shuffle(demo_paths)
demo_paths = demo_paths[:num_episodes]

for demo_path in demo_paths:
self.load_in_context_learning_episode(demo_path, agent)

def load_in_context_learning_episode(self, demo_path, agent):
episode = self.load_episode(demo_path)

actions = episode.pop("action").tolist()
rewards = episode.pop("reward").tolist()
terminated = episode.pop("terminated")
truncated = episode.pop("truncated")
dones = np.any([terminated, truncated], axis=0).tolist()
observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())]

# the first transition only contains an observation (like env.reset())
observation, action, reward, done = observations.pop(0), actions.pop(0), rewards.pop(0), dones.pop(0)
agent.update_icl_observation(observation)

for observation, action, reward, done in zip(observations, actions, rewards, dones):
action = str(action)

agent.update_icl_action(action)
agent.update_icl_observation(observation)

if done:
break

if not done:
logging.info("icl trajectory ended without done")

agent.wrap_episode()
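
For reference, load_in_context_learning_episode implies records made of parallel per-step arrays, where the first entry holds the reset observation. A hypothetical writer producing a compatible file could look like this (the real recorder may differ; the text key simply mirrors what update_icl_observation reads):

import numpy as np

# Hypothetical two-step toy episode; action[0]/reward[0] pair with the
# reset entry and are consumed without being replayed.
steps = [
    {"long_term_context": "You see a closed door."},
    {"long_term_context": "The door is open."},
]
np.savez_compressed(
    "records/env/task/20241201T225823-seed13-rew1.00-len2.npz",
    text=np.array(steps, dtype=object),  # becomes observation["text"] per step
    action=np.array(["noop", "open door"]),
    reward=np.array([0.0, 1.0]),
    terminated=np.array([False, True]),
    truncated=np.array([False, False]),
)
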
15 changes: 13 additions & 2 deletions balrog/evaluator.py
@@ -14,6 +14,8 @@
from omegaconf import OmegaConf
from tqdm import tqdm

from balrog.agents.few_shot import FewShotAgent
from balrog.dataset import InContextDataset
from balrog.environments import make_env
from balrog.utils import get_unique_seed

@@ -43,7 +45,7 @@ def __init__(self, config, original_cwd="", output_dir="."):
self.env_evaluators = {}
self.tasks = []
for env_name in self.env_names:
evaluator = Evaluator(env_name, config, output_dir=self.output_dir)
evaluator = Evaluator(env_name, config, original_cwd=original_cwd, output_dir=self.output_dir)
self.env_evaluators[env_name] = evaluator
for task in evaluator.tasks:
for episode_idx in range(evaluator.num_episodes):
@@ -219,7 +221,7 @@ class Evaluator:
including loading in-context learning episodes and running episodes with the agent.
"""

def __init__(self, env_name, config, output_dir="."):
def __init__(self, env_name, config, original_cwd="", output_dir="."):
"""Initialize the Evaluator.

Args:
@@ -237,6 +239,8 @@ def __init__(self, env_name, config, output_dir="."):
self.num_workers = config.eval.num_workers
self.max_steps_per_episode = config.eval.max_steps_per_episode

self.dataset = InContextDataset(self.config, self.env_name, original_cwd=original_cwd)

def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
"""Run a single evaluation episode.

@@ -284,6 +288,13 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
csv_writer = csv.writer(csv_file, escapechar="˘", quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(["Step", "Action", "Reasoning", "Observation", "Reward", "Done"])

# If the agent is a FewShotAgent, load the in-context learning episodes
if isinstance(agent, FewShotAgent):
self.dataset.load_in_context_learning_episodes(self.config.eval.icl_episodes, task, agent)

if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
agent.cache_icl()

pbar_desc = f"Task: {task}, Proc: {process_num}"
pbar = tqdm(
total=max_steps_per_episode,
5 changes: 4 additions & 1 deletion balrog/prompt_builder/history.py
@@ -75,7 +75,7 @@ def reset(self):
"""Clear the event history."""
self._events.clear()

def get_prompt(self) -> List[Message]:
def get_prompt(self, icl_episodes=False) -> List[Message]:
"""Generate a list of Message objects representing the prompt.

Returns:
@@ -85,6 +85,9 @@
if self.system_prompt:
messages.append(Message(role="user", content=self.system_prompt))

if self.system_prompt and not icl_episodes:
messages.append(Message(role="user", content=self.system_prompt))

# Determine which images to include
images_needed = self.max_image_history
for event in reversed(self._events):