diff --git a/.gitignore b/.gitignore
index bb1da426..540d76e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,5 +165,5 @@ cython_debug/
 /outputs
 /tw_games
 tw-games.zip
-/demos
-/demos.zip
+/records
+/records.zip
diff --git a/README.md b/README.md
index 9c28ecfc..d5d8bd2f 100644
--- a/README.md
+++ b/README.md
@@ -72,6 +72,7 @@ python eval.py \
 
 ## Documentation
 - [Evaluation Guide](https://github.com/balrog-ai/BALROG/blob/main/docs/evaluation.md) - Detailed instructions for various evaluation scenarios
 - [Agent Development](https://github.com/balrog-ai/BALROG/blob/main/docs/agents.md) - Tutorial on creating custom agents
+- [Few-Shot Learning](https://github.com/balrog-ai/BALROG/blob/main/docs/few_shot_learning.md) - Instructions for running evaluations with expert demonstrations in the agent's context
 
 We welcome contributions! Please see our [Contributing Guidelines](https://github.com/balrog-ai/BALROG/blob/main/docs/contribution.md) for details.
diff --git a/balrog/agents/__init__.py b/balrog/agents/__init__.py
index 4ef4ed76..b4c00332 100644
--- a/balrog/agents/__init__.py
+++ b/balrog/agents/__init__.py
@@ -4,6 +4,7 @@
 from .chain_of_thought import ChainOfThoughtAgent
 from .custom import CustomAgent
 from .dummy import DummyAgent
+from .few_shot import FewShotAgent
 from .naive import NaiveAgent
 
 
@@ -47,6 +48,8 @@ def create_agent(self):
             return DummyAgent(client_factory, prompt_builder)
         elif self.config.agent.type == "custom":
             return CustomAgent(client_factory, prompt_builder)
+        elif self.config.agent.type == "few_shot":
+            return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)
         else:
             raise ValueError(f"Unknown agent type: {self.config.agent}")
diff --git a/balrog/agents/few_shot.py b/balrog/agents/few_shot.py
new file mode 100644
index 00000000..6db9540b
--- /dev/null
+++ b/balrog/agents/few_shot.py
@@ -0,0 +1,153 @@
+import copy
+import re
+from typing import List, Optional
+
+from balrog.agents.base import BaseAgent
+
+
+class Message:
+    def __init__(self, role: str, content: str, attachment: Optional[object] = None):
+        self.role = role  # 'system', 'user', 'assistant'
+        self.content = content  # String content of the message
+        self.attachment = attachment
+
+    def __repr__(self):
+        return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})"
+
+
+class FewShotAgent(BaseAgent):
+    def __init__(self, client_factory, prompt_builder, max_icl_history):
+        """Initialize the FewShotAgent with a client and prompt builder."""
+        super().__init__(client_factory, prompt_builder)
+        self.client = client_factory()
+        self.icl_episodes = []
+        self.icl_events = []
+        self.max_icl_history = max_icl_history
+        self.cached_icl = False
+
+    def update_icl_observation(self, obs: dict):
+        long_term_context = obs["text"].get("long_term_context", "")
+        self.icl_events.append(
+            {
+                "type": "icl_observation",
+                "text": long_term_context,
+            }
+        )
+
+    def update_icl_action(self, action: str):
+        self.icl_events.append(
+            {
+                "type": "icl_action",
+                "action": action,
+            }
+        )
+
+    def cache_icl(self):
+        self.client.cache_icl_demo(self.get_icl_prompt())
+        self.cached_icl = True
+
+    def wrap_episode(self):
+        icl_episode = []
+        icl_episode.append(
+            Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
+        )
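+        # Replay the buffered events in order: observations become user messages,
+        # demonstrator actions become assistant messages.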
Message(role="assistant", content=content) + icl_episode.append(message) + icl_episode.append( + Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******") + ) + + self.icl_episodes.append(icl_episode) + self.icl_events = [] + + def get_icl_prompt(self) -> List[Message]: + icl_instruction = Message( + role="user", + content=self.prompt_builder.system_prompt.replace( + "PLAY", + "First, observe the demonstrations provided and learn from them!", + ), + ) + + # unroll the wrapped icl episodes messages + icl_messages = [icl_instruction] + i = 0 + for icl_episode in self.icl_episodes: + episode_steps = len(icl_episode) - 2 # not count start and end messages + if i + episode_steps <= self.max_icl_history: + icl_messages.extend(icl_episode) + i += episode_steps + else: + icl_episode = icl_episode[: self.max_icl_history - i + 1] + [ + icl_episode[-1] + ] # +1 for start message -1 for end message + icl_messages.extend(icl_episode) + i += len(icl_episode) - 2 # not count start and end messages + break + + end_demo_message = Message( + role="user", + content="****** Now it's your turn to play the game! ******", + ) + icl_messages.append(end_demo_message) + + return icl_messages + + def act(self, obs, prev_action=None): + """Generate the next action based on the observation and previous action. + + Args: + obs (dict): The current observation in the environment. + prev_action (str, optional): The previous action taken. + + Returns: + str: The selected action from the LLM response. + """ + if prev_action: + self.prompt_builder.update_action(prev_action) + + self.prompt_builder.update_observation(obs) + + if not self.cached_icl: + messages = self.get_icl_prompt() + else: + messages = [] + + messages.extend(self.prompt_builder.get_prompt(icl_episodes=True)) + + naive_instruction = """ +You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates. + """.strip() + + if messages and messages[-1].role == "user": + messages[-1].content += "\n\n" + naive_instruction + + response = self.client.generate(messages) + + final_answer = self._extract_final_answer(response) + + return final_answer + + def _extract_final_answer(self, answer): + """Sanitize the final answer, keeping only alphabetic characters. + + Args: + answer (LLMResponse): The response from the LLM. + + Returns: + LLMResponse: The sanitized response. 
+ """ + + def filter_letters(input_string): + return re.sub(r"[^a-zA-Z\s:]", "", input_string) + + final_answer = copy.deepcopy(answer) + final_answer = final_answer._replace(completion=filter_letters(final_answer.completion)) + + return final_answer diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml index a3ea60e0..2b53786f 100644 --- a/balrog/config/config.yaml +++ b/balrog/config/config.yaml @@ -4,6 +4,8 @@ agent: max_history: 16 # Maximum number of previous turns to keep in the dialogue history max_image_history: 0 # Maximum number of images to keep in the history max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent) + max_icl_history: 1000 # Maximum number of ICL steps to keep in history (if using 'few_shot' type of agent) + cache_icl: False eval: output_dir: "results" # Directory where evaluation results will be saved @@ -19,7 +21,8 @@ eval: max_steps_per_episode: null # Max steps per episode; null uses the environment default save_trajectories: True # Whether to save agent trajectories (text only) save_images: False # Whether to save images from the environment - + icl_episodes: 1 + icl_dataset: records client: client_name: openai # LLM client to use (e.g., 'openai', 'gemini', 'claude') diff --git a/balrog/dataset.py b/balrog/dataset.py new file mode 100644 index 00000000..dafb6f2e --- /dev/null +++ b/balrog/dataset.py @@ -0,0 +1,102 @@ +import glob +import logging +import os +import random +import re +from pathlib import Path + +import numpy as np + + +def natural_sort_key(s): + return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))] + + +def choice_excluding(lst, excluded_element): + possible_choices = [item for item in lst if item != excluded_element] + return random.choice(possible_choices) + + +class InContextDataset: + def __init__(self, config, env_name, original_cwd) -> None: + self.config = config + self.env_name = env_name + self.original_cwd = original_cwd + + def icl_episodes(self, task): + demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task + return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key)) + + def extract_seed(self, demo_path): + # extract seed from record, example format: `20241201T225823-seed13-rew1.00-len47.npz` + seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part] + return int(seed[0]) + + def demo_task(self, task): + # use different task - avoid the case where we put the solution into the context + if self.env_name == "babaisai": + task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task) + + return task + + def demo_path(self, i, task): + icl_episodes = self.icl_episodes(task) + demo_path = icl_episodes[i % len(icl_episodes)] + + # use different seed - avoid the case where we put the solution into the context + if self.env_name == "textworld": + from balrog.environments.textworld import global_textworld_context + + textworld_context = global_textworld_context( + tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs + ) + next_seed = textworld_context.count[task] + demo_seed = self.extract_seed(demo_path) + if next_seed == demo_seed: + demo_path = self.icl_episodes(task)[i + 1] + + return demo_path + + def load_episode(self, filename): + # Load the compressed NPZ file + with np.load(filename, allow_pickle=True) as data: + # Convert to dictionary if you want + episode = {k: data[k] for k in 
+        return episode
+
+    def load_in_context_learning_episodes(self, num_episodes, task, agent):
+        demo_task = self.demo_task(task)
+        demo_paths = [self.demo_path(i, demo_task) for i in range(len(self.icl_episodes(task)))]
+        random.shuffle(demo_paths)
+        demo_paths = demo_paths[:num_episodes]
+
+        for demo_path in demo_paths:
+            self.load_in_context_learning_episode(demo_path, agent)
+
+    def load_in_context_learning_episode(self, demo_path, agent):
+        episode = self.load_episode(demo_path)
+
+        actions = episode.pop("action").tolist()
+        rewards = episode.pop("reward").tolist()
+        terminated = episode.pop("terminated")
+        truncated = episode.pop("truncated")
+        dones = np.any([terminated, truncated], axis=0).tolist()
+        observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())]
+
+        # The first transition only contains an observation (like env.reset())
+        observation, action, reward, done = observations.pop(0), actions.pop(0), rewards.pop(0), dones.pop(0)
+        agent.update_icl_observation(observation)
+
+        for observation, action, reward, done in zip(observations, actions, rewards, dones):
+            action = str(action)
+
+            agent.update_icl_action(action)
+            agent.update_icl_observation(observation)
+
+            if done:
+                break
+
+        if not done:
+            logging.info("ICL trajectory ended without reaching a terminal state")
+
+        agent.wrap_episode()
diff --git a/balrog/evaluator.py b/balrog/evaluator.py
index 4c392eca..a5687821 100644
--- a/balrog/evaluator.py
+++ b/balrog/evaluator.py
@@ -14,6 +14,8 @@
 from omegaconf import OmegaConf
 from tqdm import tqdm
 
+from balrog.agents.few_shot import FewShotAgent
+from balrog.dataset import InContextDataset
 from balrog.environments import make_env
 from balrog.utils import get_unique_seed
 
@@ -43,7 +45,7 @@ def __init__(self, config, original_cwd="", output_dir="."):
         self.env_evaluators = {}
         self.tasks = []
         for env_name in self.env_names:
-            evaluator = Evaluator(env_name, config, output_dir=self.output_dir)
+            evaluator = Evaluator(env_name, config, original_cwd=original_cwd, output_dir=self.output_dir)
             self.env_evaluators[env_name] = evaluator
             for task in evaluator.tasks:
                 for episode_idx in range(evaluator.num_episodes):
@@ -219,7 +221,7 @@ class Evaluator:
     including loading in-context learning episodes and running episodes with the agent.
     """
 
-    def __init__(self, env_name, config, output_dir="."):
+    def __init__(self, env_name, config, original_cwd="", output_dir="."):
         """Initialize the Evaluator.
 
         Args:
@@ -237,6 +239,8 @@ def __init__(self, env_name, config, output_dir="."):
         self.num_workers = config.eval.num_workers
         self.max_steps_per_episode = config.eval.max_steps_per_episode
 
+        self.dataset = InContextDataset(self.config, self.env_name, original_cwd=original_cwd)
+
     def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
         """Run a single evaluation episode.
@@ -284,6 +288,13 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
             csv_writer = csv.writer(csv_file, escapechar="˘", quoting=csv.QUOTE_MINIMAL)
             csv_writer.writerow(["Step", "Action", "Reasoning", "Observation", "Reward", "Done"])
 
+            # If the agent is a FewShotAgent, load the in-context learning episodes
+            if isinstance(agent, FewShotAgent):
+                self.dataset.load_in_context_learning_episodes(self.config.eval.icl_episodes, task, agent)
+
+                if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
+                    agent.cache_icl()
+
             pbar_desc = f"Task: {task}, Proc: {process_num}"
             pbar = tqdm(
                 total=max_steps_per_episode,
diff --git a/balrog/prompt_builder/history.py b/balrog/prompt_builder/history.py
index 3076741e..a847c1d1 100644
--- a/balrog/prompt_builder/history.py
+++ b/balrog/prompt_builder/history.py
@@ -75,7 +75,7 @@ def reset(self):
         """Clear the event history."""
         self._events.clear()
 
-    def get_prompt(self) -> List[Message]:
+    def get_prompt(self, icl_episodes=False) -> List[Message]:
         """Generate a list of Message objects representing the prompt.
 
         Returns:
@@ -85,6 +85,6 @@ def get_prompt(self, icl_episodes=False) -> List[Message]:
-        if self.system_prompt:
+        if self.system_prompt and not icl_episodes:
             messages.append(Message(role="user", content=self.system_prompt))
 
         # Determine which images to include
         images_needed = self.max_image_history
         for event in reversed(self._events):
diff --git a/docs/few_shot_learning.md b/docs/few_shot_learning.md
new file mode 100644
index 00000000..5e2b1f9d
--- /dev/null
+++ b/docs/few_shot_learning.md
@@ -0,0 +1,44 @@
+# Few-Shot Learning
+Few-shot learning improves model performance by providing examples of expert gameplay in the agent's context.
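+
+The few-shot behaviour is controlled by a handful of options added to `balrog/config/config.yaml` (defaults shown below):
+```yaml
+agent:
+  max_icl_history: 1000  # Maximum number of ICL steps to keep in history
+  cache_icl: False       # Cache the demonstration prompt (Gemini client only)
+
+eval:
+  icl_episodes: 1        # Number of demonstration episodes to load into the context
+  icl_dataset: records   # Directory containing the expert demonstration records
+```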
+
+### Installation
+Download and unzip the expert demonstrations:
+```bash
+pip install gdown
+gdown 1TQbrqMSC5K_SNx9tta1Tlhtg8flSIGaJ
+unzip records.zip
+```
+
+### Usage
+To run an evaluation with few-shot learning:
+```bash
+python -m eval agent.type=few_shot eval.icl_episodes=5
+```
+
+### Features
+- Each demonstration has a corresponding MP4 file, which allows for quick inspection.
+- `FewShotAgent` supports context caching, which can be enabled with `agent.cache_icl=True` (currently only for the Gemini client).
+
+### Additional Notes
+- Expert demonstrations are formatted as conversation sequences.
+- All demonstration trajectories are loaded into the context, which can increase the cost of evaluation, especially for long environments like NetHack.
+- For TextWorld the demonstration uses a different seed, and for Baba Is AI a different task, so the solution of the evaluated episode never ends up in the context.
+- In principle, a similar strategy could be applied to other environments; for example, in NLE we could load trajectories that play the same character.
+
+### Prompt formatting
+Example prompt for an agent starting to play the game with `eval.icl_episodes=1`:
+```
+00 = Message(role=user, content=System Prompt: [], attachment=None)
+01 = Message(role=user, content=****** START OF DEMONSTRATION EPISODE 1 ******, attachment=None)
+02 = Message(role=user, content=Observation: [], attachment=None)
+03 = Message(role=assistant, content=None, attachment=None)
+04 = Message(role=user, content=Observation: [], attachment=None)
+05 = Message(role=assistant, content=go forward, attachment=None)
+06 = Message(role=user, content=Observation: [], attachment=None)
+07 = Message(role=assistant, content=go forward, attachment=None)
+08 = Message(role=user, content=Observation: [], attachment=None)
+09 = Message(role=assistant, content=turn left, attachment=None)
+10 = Message(role=user, content=****** END OF DEMONSTRATION EPISODE 1 ******, attachment=None)
+11 = Message(role=user, content=****** Now it's your turn to play the game! ******, attachment=None)
+12 = Message(role=user, content=Current Observation: [], attachment=None)
+```
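+
+### Inspecting a record
+Each record is a compressed NPZ file with per-step arrays; `InContextDataset` reads the `action`, `reward`, `terminated`, and `truncated` keys and treats the remaining keys as observations. Below is a minimal sketch for inspecting a downloaded record (the path is hypothetical; pick any `.npz` file from the unzipped `records` directory):
+```python
+import numpy as np
+
+path = "records/babyai/goto/20241201T225823-seed13-rew1.00-len47.npz"  # hypothetical example path
+with np.load(path, allow_pickle=True) as data:
+    episode = {k: data[k] for k in data.files}
+
+print(sorted(episode.keys()))   # action, reward, terminated, truncated, plus environment-specific observation keys
+print(len(episode["action"]))   # number of steps in the demonstration
+```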