Skip to content

Commit a4e6627

Browse files
Robust CoT agent (#27)
* feat: robust cot agent * fix: default message when action not parsed * feat: improve cot message
1 parent 1f31a28 commit a4e6627

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

balrog/agents/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .few_shot import FewShotAgent
88
from .naive import NaiveAgent
99
from .robust_naive import RobustNaiveAgent
10+
from .robust_cot import RobustCoTAgent
1011

1112

1213
class AgentFactory:
@@ -53,6 +54,8 @@ def create_agent(self):
5354
return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)
5455
elif self.config.agent.type == "robust_naive":
5556
return RobustNaiveAgent(client_factory, prompt_builder)
57+
elif self.config.agent.type == "robust_cot":
58+
return RobustCoTAgent(client_factory, prompt_builder, config=self.config)
5659

5760
else:
5861
raise ValueError(f"Unknown agent type: {self.config.agent}")

balrog/agents/robust_cot.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import copy
2+
import re
3+
4+
from balrog.agents.base import BaseAgent
5+
from balrog.client import LLMClientWrapper
6+
7+
8+
class RobustCoTAgent(BaseAgent):
    """Chain-of-thought agent that requires a strictly delimited final action.

    The model is asked to reason freely and then emit its choice as
    ``<|ACTION|>...<|END|>``. Only the text between those tags is returned as
    the action; the full raw completion is kept in the response's
    ``reasoning`` field.
    """

    # Matches the first ``<|ACTION|>...<|END|>`` span. DOTALL lets the action
    # span line breaks; the non-greedy group stops at the nearest ``<|END|>``.
    _ACTION_PATTERN = re.compile(r"<\|ACTION\|>(.*?)<\|END\|>", re.DOTALL)

    # Returned as the action when no usable action could be parsed.
    _PARSE_FAILURE_MESSAGE = "Failed to obtain a valid action from the reasoning."

    # Instructions appended to the last prompt message on every step.
    _COT_INSTRUCTIONS = (
        "First, think about the best course of action.\n"
        "Then, you must choose exactly one of the listed actions and output it strictly in the following format:\n"
        "\n"
        "<|ACTION|>YOUR_CHOSEN_ACTION<|END|>\n"
        "\n"
        "Replace YOUR_CHOSEN_ACTION with the chosen action."
    )

    def __init__(self, client_factory: LLMClientWrapper, prompt_builder, config):
        """Initialize the RobustCoTAgent with a client, prompt builder, and configuration.

        Args:
            client_factory (LLMClientWrapper): A factory for creating the LLM client instance.
            prompt_builder (PromptBuilder): Object to build prompts for the agent.
            config: Configuration object; ``config.agent.remember_cot`` is read here.
        """
        super().__init__(client_factory, prompt_builder)
        # NOTE(review): this flag is stored but never read anywhere in this
        # class - confirm whether the reasoning was meant to be fed back into
        # the prompt history when it is set.
        self.remember_cot = config.agent.remember_cot

    def act(self, obs, prev_action=None):
        """Generate the next action using chain-of-thought reasoning based on the current observation.

        Args:
            obs (dict): The current observation in the environment.
            prev_action (str, optional): The previous action taken. It is
                recorded in the prompt builder only when truthy.

        Returns:
            LLMResponse: The response whose ``completion`` is the extracted
                action (or the parse-failure message) and whose ``reasoning``
                is the raw model output.
        """
        if prev_action:
            self.prompt_builder.update_action(prev_action)

        self.prompt_builder.update_observation(obs)

        messages = self.prompt_builder.get_prompt()

        # The instructions are appended in place to the last message.
        # NOTE(review): this assumes get_prompt() returns fresh message objects
        # on each call; otherwise the instructions would pile up in the stored
        # history - verify against the prompt builder.
        messages[-1].content += "\n\n" + self._COT_INSTRUCTIONS

        # self.client is presumably set up by BaseAgent.__init__ - TODO confirm.
        cot_reasoning = self.client.generate(messages)

        return self._extract_final_answer(cot_reasoning)

    def _extract_final_answer(self, reasoning):
        """Extract the final action from the chain-of-thought reasoning response.

        Args:
            reasoning (LLMResponse): The response containing CoT reasoning and action.

        Returns:
            LLMResponse: A copy of the response with the extracted action in
                `completion` and the entire raw completion in `reasoning`.
                When no ``<|ACTION|>...<|END|>`` span is found, the completion
                is missing, or the span is blank, `completion` is a fixed
                failure message rather than the raw text.
        """
        raw_completion = reasoning.completion

        # Treat a missing completion as empty text so parsing falls through to
        # the failure message instead of raising TypeError inside the regex.
        match = self._ACTION_PATTERN.search(raw_completion or "")
        extracted_action = match.group(1).strip() if match else ""

        # A blank action between the tags is as unusable as no tags at all.
        if not extracted_action:
            extracted_action = self._PARSE_FAILURE_MESSAGE

        # Work on a deep copy so the caller's response object is left untouched.
        final_answer = copy.deepcopy(reasoning)
        return final_answer._replace(reasoning=raw_completion, completion=extracted_action)

0 commit comments

Comments
 (0)