Skip to content

Commit 954e435

Browse files
committed
more tool fixes
1 parent 5e320cb commit 954e435

File tree

5 files changed

+51
-16
lines changed

5 files changed

+51
-16
lines changed

openhands/agenthub/codeact_agent/codeact_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def update_tools(self, phase: ReplayPhase) -> None:
326326
replay_phase=phase,
327327
)
328328
logger.info(
329-
f'[REPLAY] update_tools: {json.dumps([t["function"]['name'] for t in self.tools], indent=2)}'
329+
f'[REPLAY] update_tools for phase {phase}: {json.dumps([t["function"]['name'] for t in self.tools], indent=2)}'
330330
)
331331

332332
def step(self, state: State) -> Action:

openhands/replay/replay_initial_analysis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from openhands.controller.state.state import State
66
from openhands.core.logger import openhands_logger as logger
7+
from openhands.core.schema.replay import ReplayPhase
78
from openhands.events.action.action import Action
89
from openhands.events.action.message import MessageAction
910
from openhands.events.action.replay import ReplayInternalCmdRunAction
@@ -44,6 +45,10 @@ def start_initial_analysis(
4445

4546

4647
def replay_enhance_action(state: State, is_workspace_repo: bool) -> Action | None:
48+
if state.replay_phase != ReplayPhase.Normal:
49+
# We currently only enhance prompts in the Normal phase.
50+
return None
51+
4752
if state.replay_enhance_prompt_id == -1:
4853
# 1. Get current user prompt.
4954
latest_user_message = state.get_last_user_message()

openhands/replay/replay_phases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def get_phase_prompt(obs) -> str:
8888

8989
def update_phase(new_phase: ReplayPhase, state: State, agent: Agent):
9090
"""Apply phase update side effects."""
91+
logger.info(f'[REPLAY] update_phase (replay_phase): {new_phase}')
9192
state.replay_phase = new_phase
9293
agent.update_tools(new_phase)
93-
logger.info(f'[REPLAY] update_phase (replay_phase): {new_phase}')
9494

9595

9696
# ###########################################################################

openhands/replay/replay_tools.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,15 @@ def replay_phase_tool(
152152
# Bookkeeping + utilities.
153153
# ###########################################################################
154154

155-
replay_analysis_tools: list[ReplayTool] = [
155+
replay_analysis_tools: tuple[ReplayTool, ...] = (
156156
ReplayInspectDataTool,
157157
ReplayInspectPointTool,
158-
]
158+
)
159159

160-
replay_tools: list[ReplayTool] = [
160+
replay_tools: tuple[ReplayTool, ...] = (
161161
*replay_analysis_tools,
162162
*replay_phase_transition_tools,
163-
]
163+
)
164164
replay_tool_names: set[str] = set([t['function']['name'] for t in replay_tools])
165165
replay_replay_tool_type_by_name = {
166166
t['function']['name']: t.get('replay_tool_type', None) for t in replay_tools
@@ -203,23 +203,18 @@ def get_replay_tools(
203203
replay_phase: ReplayPhase, default_tools: list[ChatCompletionToolParam]
204204
) -> list[ChatCompletionToolParam]:
205205
if replay_phase == ReplayPhase.Normal:
206-
# Use the default tools when not in a Replay-specific phase.
207206
tools = default_tools
208207
elif replay_phase == ReplayPhase.Analysis:
209-
# Only allow analysis in this phase.
210-
tools = replay_analysis_tools
208+
tools = list(replay_analysis_tools)
211209
elif replay_phase == ReplayPhase.Edit:
212-
# Combine default and analysis tools.
213-
tools = default_tools + replay_analysis_tools
210+
tools = default_tools + list(replay_analysis_tools)
214211
else:
215212
raise ValueError(f'Unhandled ReplayPhase in get_tools: {replay_phase}')
216213

217-
# Add tools to allow transitioning to next phase.
218214
next_phase_tool = get_replay_transition_tool_for_current_phase(replay_phase)
219215
if next_phase_tool:
220216
tools.append(next_phase_tool)
221217

222-
# Return all tools.
223218
return tools
224219

225220

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,49 @@
1+
import pytest
2+
13
from openhands.core.schema.replay import ReplayPhase
24
from openhands.replay.replay_phases import get_next_agent_replay_phase
35
from openhands.replay.replay_tools import (
6+
get_replay_tools,
47
get_replay_transition_tool_for_current_phase,
8+
replay_analysis_tools,
59
)
610

711

8-
def test_get_replay_transition_tool_for_analysis_phase():
12+
def test_get_replay_transition_tools_analysis():
13+
assert get_next_agent_replay_phase(ReplayPhase.Analysis) is not None
914
tool = get_replay_transition_tool_for_current_phase(ReplayPhase.Analysis, 'submit')
10-
assert tool is not None
15+
assert tool
1116
assert tool['function']['name'] == 'submit'
1217
assert tool['new_phase'] is not None
13-
assert get_next_agent_replay_phase(ReplayPhase.Analysis) is not None
1418
assert tool['new_phase'] == get_next_agent_replay_phase(ReplayPhase.Analysis)
19+
20+
21+
def test_get_replay_transition_tools_edit():
22+
assert get_next_agent_replay_phase(ReplayPhase.Edit) is None
23+
tool = get_replay_transition_tool_for_current_phase(ReplayPhase.Edit, 'submit')
24+
assert not tool
25+
26+
27+
def test_get_tools():
28+
default_tools = []
29+
30+
# Test Normal phase
31+
tools = get_replay_tools(ReplayPhase.Normal, default_tools)
32+
assert len(tools) == len(default_tools)
33+
assert all(t in tools for t in default_tools)
34+
35+
# Test Analysis phase
36+
tools = get_replay_tools(ReplayPhase.Analysis, default_tools)
37+
assert len(tools) == len(replay_analysis_tools) + 1 # +1 for transition tool
38+
assert all(t in tools for t in replay_analysis_tools)
39+
assert tools[-1]['function']['name'] == 'submit'
40+
41+
# Test Edit phase
42+
tools = get_replay_tools(ReplayPhase.Edit, default_tools)
43+
assert len(tools) == len(default_tools) + len(replay_analysis_tools)
44+
assert all(t in tools for t in default_tools)
45+
assert all(t in tools for t in replay_analysis_tools)
46+
47+
# Test invalid phase
48+
with pytest.raises(ValueError):
49+
get_replay_tools('invalid_phase', default_tools)

0 commit comments

Comments
 (0)