Skip to content

Commit 3702c0c

Browse files
committed
update dependencies to battleships
1 parent 87a738e commit 3702c0c

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

balrog/agents/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from balrog.client import create_llm_client
22

33
from ..prompt_builder import create_prompt_builder
4+
from .battleships_naive import NaiveAgent as BattleshipsNaive
45
from .chain_of_thought import ChainOfThoughtAgent
56
from .custom import CustomAgent
67
from .dummy import DummyAgent
78
from .few_shot import FewShotAgent
89
from .naive import NaiveAgent
9-
from .robust_naive import RobustNaiveAgent
1010
from .robust_cot import RobustCoTAgent
11+
from .robust_naive import RobustNaiveAgent
1112

1213

1314
class AgentFactory:
@@ -44,6 +45,8 @@ def create_agent(self):
4445

4546
if self.config.agent.type == "naive":
4647
return NaiveAgent(client_factory, prompt_builder)
48+
if self.config.agent.type == "battleships_naive":
49+
return BattleshipsNaive(client_factory, prompt_builder)
4750
elif self.config.agent.type == "cot":
4851
return ChainOfThoughtAgent(client_factory, prompt_builder, config=self.config)
4952
elif self.config.agent.type == "dummy":

balrog/agents/battleships_naive.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import copy
2+
import re
3+
4+
from balrog.agents.base import BaseAgent
5+
6+
7+
class NaiveAgent(BaseAgent):
8+
"""An agent that generates actions based on observations without complex reasoning."""
9+
10+
def __init__(self, client_factory, prompt_builder):
11+
"""Initialize the NaiveAgent with a client and prompt builder."""
12+
super().__init__(client_factory, prompt_builder)
13+
self.client = client_factory()
14+
15+
def act(self, obs, prev_action=None):
16+
"""Generate the next action based on the observation and previous action.
17+
18+
Args:
19+
obs (dict): The current observation in the environment.
20+
prev_action (str, optional): The previous action taken.
21+
22+
Returns:
23+
str: The selected action from the LLM response.
24+
"""
25+
if prev_action:
26+
self.prompt_builder.update_action(prev_action)
27+
28+
self.prompt_builder.update_observation(obs)
29+
30+
messages = self.prompt_builder.get_prompt()
31+
32+
naive_instruction = """
33+
It's your turn. What coordinate would you like to output?
34+
""".strip()
35+
36+
if messages and messages[-1].role == "user":
37+
messages[-1].content += "\n\n" + naive_instruction
38+
39+
response = self.client.generate(messages)
40+
41+
final_answer = self._extract_final_answer(response)
42+
43+
return final_answer
44+
45+
def _extract_final_answer(self, answer):
46+
"""Sanitize the final answer, keeping only alphabetic characters.
47+
48+
Args:
49+
answer (LLMResponse): The response from the LLM.
50+
51+
Returns:
52+
LLMResponse: The sanitized response.
53+
"""
54+
55+
def filter_letters(input_string):
56+
return re.sub(r"[^a-zA-Z0-9\s:]", "", input_string)
57+
58+
final_answer = copy.deepcopy(answer)
59+
final_answer = final_answer._replace(completion=filter_letters(final_answer.completion))
60+
61+
return final_answer

balrog/environments/battleships/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
### Installation
44

55
```
6+
pip install jupyter scipy
67
pip install git+https://github.com/thomashirtz/gym-battleship#egg=gym-battleship
78
```

0 commit comments

Comments
 (0)