Skip to content

Commit 1d20a56

Browse files
feat: robust naive agent
1 parent 7fb3dbd commit 1d20a56

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

balrog/agents/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .dummy import DummyAgent
77
from .few_shot import FewShotAgent
88
from .naive import NaiveAgent
9+
from .robust_naive import RobustNaiveAgent
910

1011

1112
class AgentFactory:
@@ -50,6 +51,8 @@ def create_agent(self):
5051
return CustomAgent(client_factory, prompt_builder)
5152
elif self.config.agent.type == "few_shot":
5253
return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)
54+
elif self.config.agent.type == "robust_naive":
55+
return RobustNaiveAgent(client_factory, prompt_builder)
5356

5457
else:
5558
raise ValueError(f"Unknown agent type: {self.config.agent}")

balrog/agents/robust_naive.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import copy
2+
import re
3+
4+
from balrog.agents.base import BaseAgent
5+
6+
7+
class RobustNaiveAgent(BaseAgent):
    """Naive agent that requests a strictly tagged action and parses it out.

    The last user message of the prompt is extended with an instruction to
    answer only as ``<|ACTION|>...</|ACTION|>``. The text between the tags is
    then extracted from the LLM completion, so any extra text the model emits
    around the tags does not leak into the returned action.
    """

    # Format instruction appended to the last user message of every prompt.
    # Built from explicit pieces so source indentation can never leak into it.
    _FORMAT_INSTRUCTION = (
        "You must choose exactly one of the listed actions and output it strictly in the following format:"
        "\n\n"
        "<|ACTION|>YOUR_CHOSEN_ACTION</|ACTION|>"
        "\n\n"
        "You must not output any other text before or after these tags. "
        "No explanation, no reasoning, just the action within these tags."
    )

    # Captures everything after the first opening tag, up to the first closing
    # tag or, when the closing tag is missing (e.g. a truncated completion),
    # up to the end of the text. DOTALL lets the action span line breaks.
    _ACTION_PATTERN = re.compile(r"<\|ACTION\|>(.*?)(?:</\|ACTION\|>|\Z)", re.DOTALL)

    def __init__(self, client_factory, prompt_builder):
        """Initialize the RobustNaiveAgent with a client and prompt builder.

        Args:
            client_factory (callable): Zero-argument callable returning the LLM client.
            prompt_builder: Object that accumulates actions/observations and
                builds the message list sent to the LLM.
        """
        super().__init__(client_factory, prompt_builder)
        self.client = client_factory()

    def act(self, obs, prev_action=None):
        """Generate the next action based on the observation and previous action.

        Args:
            obs (dict): The current observation in the environment.
            prev_action (str, optional): The previous action taken.

        Returns:
            LLMResponse: The LLM response whose ``completion`` has been reduced
            to the extracted action text.
        """
        if prev_action:
            self.prompt_builder.update_action(prev_action)

        self.prompt_builder.update_observation(obs)

        messages = self.prompt_builder.get_prompt()

        # The format instruction is only attached when the prompt ends with a
        # user message; otherwise the prompt is sent unchanged.
        if messages and messages[-1].role == "user":
            messages[-1].content += "\n\n" + self._FORMAT_INSTRUCTION

        response = self.client.generate(messages)
        return self._extract_final_answer(response)

    def _extract_final_answer(self, answer):
        """Extract the action from the completion's <|ACTION|> ... </|ACTION|> tags.

        Resolution order:
          1. Text between the first opening tag and the first closing tag.
          2. If the closing tag is missing, the text after the opening tag.
          3. If there is no opening tag at all, the whole completion.
        The result is stripped of surrounding whitespace in every case. A
        missing (``None``) completion is treated as an empty string.

        Args:
            answer (LLMResponse): The response from the LLM.

        Returns:
            LLMResponse: A copy of ``answer`` whose ``completion`` holds just
            the extracted action.
        """
        # Guard against clients that return no text at all.
        completion_text = answer.completion or ""

        match = self._ACTION_PATTERN.search(completion_text)
        extracted_action = (match.group(1) if match else completion_text).strip()

        # Deep-copy first so the returned response shares nothing with the input.
        return copy.deepcopy(answer)._replace(completion=extracted_action)

0 commit comments

Comments
 (0)