Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion balrog/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from balrog.client import create_llm_client

from ..prompt_builder import create_prompt_builder
from .battleships_naive import NaiveAgent as BattleshipsNaive
from .chain_of_thought import ChainOfThoughtAgent
from .custom import CustomAgent
from .dummy import DummyAgent
from .few_shot import FewShotAgent
from .naive import NaiveAgent
from .robust_naive import RobustNaiveAgent
from .robust_cot import RobustCoTAgent
from .robust_naive import RobustNaiveAgent


class AgentFactory:
Expand Down Expand Up @@ -44,6 +45,8 @@ def create_agent(self):

if self.config.agent.type == "naive":
return NaiveAgent(client_factory, prompt_builder)
if self.config.agent.type == "battleships_naive":
return BattleshipsNaive(client_factory, prompt_builder)
elif self.config.agent.type == "cot":
return ChainOfThoughtAgent(client_factory, prompt_builder, config=self.config)
elif self.config.agent.type == "dummy":
Expand Down
61 changes: 61 additions & 0 deletions balrog/agents/battleships_naive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import copy
import re

from balrog.agents.base import BaseAgent


class NaiveAgent(BaseAgent):
    """Battleships agent that asks the LLM directly for the next coordinate.

    No intermediate reasoning step is requested: the current prompt is sent
    to the client with a short turn instruction appended, and the returned
    completion is sanitized before being handed back to the caller.
    """

    def __init__(self, client_factory, prompt_builder):
        """Store the prompt builder via the base class and create the LLM client.

        Args:
            client_factory: Zero-argument callable returning an LLM client.
            prompt_builder: Object used to accumulate observations/actions
                and to produce the message list sent to the client.
        """
        super().__init__(client_factory, prompt_builder)
        self.client = client_factory()

    def act(self, obs, prev_action=None):
        """Query the LLM for the next move given the latest observation.

        Args:
            obs (dict): The current observation in the environment.
            prev_action (str, optional): The previous action taken. When
                truthy it is recorded in the prompt builder before the new
                observation is added.

        Returns:
            The client's response with its ``completion`` field sanitized
            by ``_extract_final_answer``.
        """
        builder = self.prompt_builder
        if prev_action:
            builder.update_action(prev_action)
        builder.update_observation(obs)

        messages = builder.get_prompt()

        turn_instruction = """
It's your turn. What coordinate would you like to output?
""".strip()

        # Only extend the prompt when the conversation ends on a user turn.
        if messages:
            last_message = messages[-1]
            if last_message.role == "user":
                last_message.content += "\n\n" + turn_instruction

        return self._extract_final_answer(self.client.generate(messages))

    def _extract_final_answer(self, answer):
        """Return a copy of ``answer`` with a sanitized ``completion``.

        Every character that is not an ASCII letter, a digit, whitespace or
        a colon is removed from the completion text.

        Args:
            answer (LLMResponse): The response from the LLM.

        Returns:
            LLMResponse: A deep copy of ``answer`` whose ``completion`` has
            been filtered; the input object is left untouched.
        """
        sanitized_text = re.sub(r"[^a-zA-Z0-9\s:]", "", answer.completion)
        return copy.deepcopy(answer)._replace(completion=sanitized_text)
12 changes: 12 additions & 0 deletions balrog/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ eval:
nle: 5 # Number of episodes for the 'nle' environment
minihack: 5 # Number of episodes for each 'minihack' task
babyai: 10 # Number of episodes for each 'babyai' task
battleships: 10 # Number of episodes for each 'battleships' task
crafter: 10 # Number of episodes for the 'crafter' environment
babaisai: 3 # Number of episodes for each 'babaisai' task
textworld: 10 # Number of episodes for each 'textworld' task
Expand Down Expand Up @@ -58,6 +59,14 @@ envs:
save_ttyrec_every: 0
autopickup: False
skip_more: True
battleships_kwargs:
episode_steps: 50
board_size: [10, 10]
ship_sizes:
5: 1
4: 1
3: 2
2: 1
babyai_kwargs:
num_dists: 0
crafter_kwargs:
Expand Down Expand Up @@ -102,6 +111,9 @@ tasks:
- "BabyAI-MixedTrainLocal-v0/putnext"
- "BabyAI-MixedTrainLocal-v0/pick_up_seq_go_to"

battleships_tasks:
- "Battleship-v0"

textworld_tasks:
- "treasure_hunter"
- "the_cooking_game"
Expand Down
4 changes: 4 additions & 0 deletions balrog/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def make_env(env_name, task, config, render_mode=None):
from balrog.environments.babaisai.babaisai_env import make_babaisai_env

base_env = make_babaisai_env(env_name, task, config, render_mode=render_mode)
elif env_name == "battleships":
from balrog.environments.battleships.battleships_env import make_battleships_env

base_env = make_battleships_env(env_name, task, config, render_mode=render_mode)
else:
raise ValueError(f"Unknown environment: {env_name}")
return EnvWrapper(base_env, env_name, task)
Expand Down
8 changes: 8 additions & 0 deletions balrog/environments/battleships/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
## Battleships

### Installation

```
pip install jupyter scipy
pip install git+https://github.com/thomashirtz/gym-battleship#egg=gym-battleship
```
50 changes: 50 additions & 0 deletions balrog/environments/battleships/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np

from balrog.environments.battleships.base import BattleshipsWrapper


def get_instruction_prompt(env, instruction):
    """Build the system instruction prompt for a Battleships episode.

    Args:
        env: Environment object. This function reads ``env.ship_sizes``
            (mapping of ship size -> number of ships of that size),
            ``env.board.shape`` (unpacked as rows, columns) and
            ``env.language_action_space`` (its first and last entries are
            shown as the coordinate range).
        instruction: Unused here; kept so the signature stays compatible
            with existing callers.

    Returns:
        str: The fully formatted instruction prompt, stripped of leading
        and trailing whitespace.
    """
    ship_names = {
        5: "Carrier",
        4: "Battleship",
        3: "Cruiser",
        2: "Destroyer",
    }

    # Ship sizes are configurable, so fall back to a generic name instead of
    # raising KeyError for a size that has no classic name. The " each"
    # suffix carries its own leading space so single-ship lines do not end
    # with trailing whitespace.
    ships_strings = "\n".join(
        f"{number} {ship_names.get(ship_size, 'Ship')} {ship_size} cells"
        f"{' each' if number > 1 else ''}"
        for ship_size, number in env.ship_sizes.items()
    )

    num_rows, num_columns = env.board.shape

    # NOTE(review): "try to sunk the ship" in the Tips section below is a typo
    # for "sink"; the template text is left unchanged so prompts stay
    # reproducible across runs.
    instruction_prompt = f"""
You are an AI agent playing a Battleships game on a {num_rows}x{num_columns} grid. Your mission is to strategically locate and sink all enemy ships hidden on the board.

Game Rules:
- The board is a {num_rows}x{num_columns} grid with coordinates from {env.language_action_space[0]} to {env.language_action_space[-1]}
- Ships are placed horizontally or vertically, never diagonally
- Ships cannot be adjacent to each other (not even diagonally)
- A hit will be reported when you successfully strike a ship
- A miss will be reported when you strike empty water

The enemy has the following ships:
{ships_strings}

In a moment I will present you an observation grid. This grid represents the current state of a Battleship game. The format uses the following notation:
- O: Water (missed shot)
- X: Hit (part of a ship that has been hit)
- Z: Sunk (indicates that the entire ship has been sunk)

Tips:
- When you get a hit, try to sunk the ship as you get more reward for that.
- Avoid targeting cells adjacent to sunken ships.

IMPORTANT: Your response must be EXACTLY one coordinate in the format of a letter followed by a number (e.g., "E5", "A1", "J10"). Do not provide any explanation or reasoning in your response.

PLAY
""".strip()

    return instruction_prompt
Loading