Skip to content

Commit

Permalink
more tool fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Domiii committed Jan 24, 2025
1 parent 5e320cb commit 954e435
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 16 deletions.
2 changes: 1 addition & 1 deletion openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def update_tools(self, phase: ReplayPhase) -> None:
replay_phase=phase,
)
logger.info(
f'[REPLAY] update_tools: {json.dumps([t["function"]['name'] for t in self.tools], indent=2)}'
f'[REPLAY] update_tools for phase {phase}: {json.dumps([t["function"]['name'] for t in self.tools], indent=2)}'
)

def step(self, state: State) -> Action:
Expand Down
5 changes: 5 additions & 0 deletions openhands/replay/replay_initial_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.replay import ReplayPhase
from openhands.events.action.action import Action
from openhands.events.action.message import MessageAction
from openhands.events.action.replay import ReplayInternalCmdRunAction
Expand Down Expand Up @@ -44,6 +45,10 @@ def start_initial_analysis(


def replay_enhance_action(state: State, is_workspace_repo: bool) -> Action | None:
if state.replay_phase != ReplayPhase.Normal:
# We currently only enhance prompts in the Normal phase.
return None

if state.replay_enhance_prompt_id == -1:
# 1. Get current user prompt.
latest_user_message = state.get_last_user_message()
Expand Down
2 changes: 1 addition & 1 deletion openhands/replay/replay_phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def get_phase_prompt(obs) -> str:

def update_phase(new_phase: ReplayPhase, state: State, agent: Agent):
    """Apply phase update side effects.

    Stores ``new_phase`` on ``state.replay_phase``, asks ``agent`` to refresh
    its tool list for that phase via ``agent.update_tools``, and logs the
    transition.

    Args:
        new_phase: The phase being switched to.
        state: The state object whose ``replay_phase`` is overwritten.
        agent: The agent whose tools are updated for ``new_phase``.
    """
    state.replay_phase = new_phase
    agent.update_tools(new_phase)
    # Log exactly once, and only after both side effects went through, so the
    # message is not emitted for an update that raised part-way.
    logger.info(f'[REPLAY] update_phase (replay_phase): {new_phase}')


# ###########################################################################
Expand Down
17 changes: 6 additions & 11 deletions openhands/replay/replay_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ def replay_phase_tool(
# Bookkeeping + utilities.
# ###########################################################################

replay_analysis_tools: list[ReplayTool] = [
replay_analysis_tools: tuple[ReplayTool, ...] = (
ReplayInspectDataTool,
ReplayInspectPointTool,
]
)

replay_tools: list[ReplayTool] = [
replay_tools: tuple[ReplayTool, ...] = (
*replay_analysis_tools,
*replay_phase_transition_tools,
]
)
replay_tool_names: set[str] = set([t['function']['name'] for t in replay_tools])
replay_replay_tool_type_by_name = {
t['function']['name']: t.get('replay_tool_type', None) for t in replay_tools
Expand Down Expand Up @@ -203,23 +203,18 @@ def get_replay_tools(
replay_phase: ReplayPhase, default_tools: list[ChatCompletionToolParam]
) -> list[ChatCompletionToolParam]:
if replay_phase == ReplayPhase.Normal:
# Use the default tools when not in a Replay-specific phase.
tools = default_tools
elif replay_phase == ReplayPhase.Analysis:
# Only allow analysis in this phase.
tools = replay_analysis_tools
tools = list(replay_analysis_tools)
elif replay_phase == ReplayPhase.Edit:
# Combine default and analysis tools.
tools = default_tools + replay_analysis_tools
tools = default_tools + list(replay_analysis_tools)
else:
raise ValueError(f'Unhandled ReplayPhase in get_tools: {replay_phase}')

# Add tools to allow transitioning to next phase.
next_phase_tool = get_replay_transition_tool_for_current_phase(replay_phase)
if next_phase_tool:
tools.append(next_phase_tool)

# Return all tools.
return tools


Expand Down
41 changes: 38 additions & 3 deletions tests/unit/replay/test_replay_tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,49 @@
import pytest

from openhands.core.schema.replay import ReplayPhase
from openhands.replay.replay_phases import get_next_agent_replay_phase
from openhands.replay.replay_tools import (
get_replay_tools,
get_replay_transition_tool_for_current_phase,
replay_analysis_tools,
)


def test_get_replay_transition_tools_analysis():
    """The Analysis phase has a next phase and a 'submit' tool that leads to it."""
    next_phase = get_next_agent_replay_phase(ReplayPhase.Analysis)
    assert next_phase is not None

    tool = get_replay_transition_tool_for_current_phase(
        ReplayPhase.Analysis, 'submit'
    )
    # Truthiness covers both "not None" and "not an empty mapping".
    assert tool
    assert tool['function']['name'] == 'submit'
    assert tool['new_phase'] is not None
    assert tool['new_phase'] == next_phase


def test_get_replay_transition_tools_edit():
    """The Edit phase has no next phase, so no transition tool is returned."""
    next_phase = get_next_agent_replay_phase(ReplayPhase.Edit)
    assert next_phase is None

    transition_tool = get_replay_transition_tool_for_current_phase(
        ReplayPhase.Edit, 'submit'
    )
    assert not transition_tool


def test_get_tools():
    """Check the tool list returned by get_replay_tools for each phase."""
    default_tools = []

    # Normal phase: expect as many tools as the defaults, all of them present.
    normal_tools = get_replay_tools(ReplayPhase.Normal, default_tools)
    assert len(normal_tools) == len(default_tools)
    for default_tool in default_tools:
        assert default_tool in normal_tools

    # Analysis phase: the analysis tools plus one trailing transition tool.
    analysis_tools = get_replay_tools(ReplayPhase.Analysis, default_tools)
    expected_analysis_count = len(replay_analysis_tools) + 1  # +1 for transition tool
    assert len(analysis_tools) == expected_analysis_count
    for analysis_tool in replay_analysis_tools:
        assert analysis_tool in analysis_tools
    last_tool = analysis_tools[-1]
    assert last_tool['function']['name'] == 'submit'

    # Edit phase: defaults and analysis tools combined, nothing extra.
    edit_tools = get_replay_tools(ReplayPhase.Edit, default_tools)
    expected_edit_count = len(default_tools) + len(replay_analysis_tools)
    assert len(edit_tools) == expected_edit_count
    for default_tool in default_tools:
        assert default_tool in edit_tools
    for analysis_tool in replay_analysis_tools:
        assert analysis_tool in edit_tools

    # Anything that is not a known phase must be rejected.
    with pytest.raises(ValueError):
        get_replay_tools('invalid_phase', default_tools)

0 comments on commit 954e435

Please sign in to comment.