Skip to content

Commit

Permalink
isolate state transition logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Domiii committed Jan 21, 2025
1 parent ebc4d21 commit 128a6db
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 82 deletions.
21 changes: 8 additions & 13 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from collections import deque

Expand Down Expand Up @@ -41,8 +42,8 @@
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.replay.replay_initial_analysis import replay_enhance_action
from openhands.replay.replay_state_machine import (
get_replay_observation_message,
from openhands.replay.replay_phases import (
on_agent_replay_observation,
)
from openhands.runtime.plugins import (
AgentSkillsRequirement,
Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(

# We're in normal mode by default (even if replay is not enabled).
# This will initialize the set of tools the agent has access to.
self.replay_phase_changed(ReplayPhase.Normal)
self.update_tools(ReplayPhase.Normal)

self.prompt_manager = PromptManager(
microagent_dir=os.path.join(os.path.dirname(__file__), 'micro')
Expand Down Expand Up @@ -256,7 +257,7 @@ def get_observation_message(
text += f'\n[Command finished with exit code {obs.exit_code}]'
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, ReplayObservation):
message = get_replay_observation_message(obs, max_message_chars)
message = on_agent_replay_observation(obs, max_message_chars)
elif isinstance(obs, IPythonRunCellObservation):
text = obs.content
# replace base64 images with a placeholder
Expand Down Expand Up @@ -316,22 +317,16 @@ def reset(self) -> None:
"""Resets the CodeAct Agent."""
super().reset()

def replay_phase_changed(self, phase: ReplayPhase) -> None:
"""Called whenever the phase of the replay debugging process changes.
We currently use this to give the agent access to different tools for the
different phases.
"""
def update_tools(self, phase: ReplayPhase) -> None:
self.tools = codeact_function_calling.get_tools(
codeact_enable_browsing=self.config.codeact_enable_browsing,
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
codeact_enable_replay=self.config.codeact_enable_replay,
replay_phase=phase,
)
logger.debug(
f'[REPLAY] CodeActAgent.replay_phase_changed({phase}).'
# f'New tools: {json.dumps(self.tools, indent=2)}'
logger.info(
f'[REPLAY] update_tools: {json.dumps([t.function['name'] for t in self.tools], indent=2)}'
)

def step(self, state: State) -> Action:
Expand Down
10 changes: 4 additions & 6 deletions openhands/agenthub/codeact_agent/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,9 @@ def get_tools(
codeact_enable_llm_editor,
codeact_enable_jupyter,
)
if not codeact_enable_replay or replay_phase == ReplayPhase.Normal:
# Use the default tools when not in a Replay-specific phase.
return default_tools

if codeact_enable_replay:
tools = get_replay_tools(replay_phase, default_tools)
# Handle Replay tool updates.
return get_replay_tools(replay_phase, default_tools)

return tools
# Just the default tools.
return default_tools
9 changes: 3 additions & 6 deletions openhands/controller/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,9 @@ def reset(self) -> None:
if self.llm:
self.llm.reset()

# the "noqa" below is so we can add an empty but not abstract method. otherwise we'd need an empty
# implementation in every subclass other than CodeActAgent (which does override it).
def replay_phase_changed(self, phase: ReplayPhase) -> None: # noqa: B027
"""Called when the phase of the replay debugging process changes. This method
can be used to update the agent's behavior based on the phase.
"""
# `noqa: B027` is necessary to have an empty method implementation.
def update_tools(self, phase: ReplayPhase) -> None: # noqa: B027
"""Agent tools might have changed due to some observed event."""
pass

@property
Expand Down
8 changes: 3 additions & 5 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from openhands.core.logger import LOG_ALL_EVENTS
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState, ReplayPhase
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
Action,
Expand Down Expand Up @@ -53,7 +53,7 @@
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.replay.replay_state_machine import on_replay_observation
from openhands.replay.replay_phases import on_controller_replay_observation
from openhands.utils.shutdown_listener import should_continue

# note: RESUME is only available on web GUI
Expand Down Expand Up @@ -145,8 +145,6 @@ def __init__(
self._stuck_detector = StuckDetector(self.state)
self.status_callback = status_callback

self.replay_phase = ReplayPhase.Normal

async def close(self) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
Expand Down Expand Up @@ -298,7 +296,7 @@ async def _handle_observation(self, observation: Observation) -> None:
if self._pending_action and self._pending_action.id == observation.cause:
self._pending_action = None
if isinstance(observation, ReplayObservation):
on_replay_observation(observation, self.state, self.agent)
on_controller_replay_observation(observation, self.state, self.agent)

if self.state.agent_state == AgentState.USER_CONFIRMED:
await self.set_agent_state_to(AgentState.RUNNING)
Expand Down
4 changes: 1 addition & 3 deletions openhands/replay/replay_initial_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def on_replay_internal_command_observation(
state: State, observation: ReplayInternalCmdOutputObservation
) -> AnalysisToolMetadata | None:
"""
Handle result for an internally sent command (not agent tool use or user action).
NOTE: Currently, the only internal command is the initial-analysis command.
Handle the result of an internal command (one sent automatically or triggered by the user).
Enhance the user prompt with the results of the initial analysis.
Returns the metadata needed for the agent to switch to analysis tools.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,83 @@
)
from openhands.replay.replay_prompts import replay_prompt_phase_edit

# ###########################################################################
# Phase events.
# ###########################################################################

def on_replay_observation(obs: ReplayObservation, state: State, agent: Agent) -> None:

def on_controller_replay_observation(
obs: ReplayObservation, state: State, agent: Agent
) -> None:
"""Handle the observation."""
new_phase: ReplayPhase | None = None
if isinstance(obs, ReplayInternalCmdOutputObservation):
# NOTE: Currently, the only internal command is the initial-analysis command.
analysis_tool_metadata = on_replay_internal_command_observation(state, obs)
if analysis_tool_metadata:
# Start analysis phase
state.replay_recording_id = analysis_tool_metadata['recordingId']
state.replay_phase = ReplayPhase.Analysis
agent.replay_phase_changed(ReplayPhase.Analysis)
new_phase = ReplayPhase.Analysis
elif isinstance(obs, ReplayPhaseUpdateObservation):
# Agent action triggered a phase change.
new_phase = obs.new_phase

if new_phase:
if state.replay_phase == new_phase:
logger.warning(
f'Unexpected ReplayPhaseUpdateAction. Already in phase. Observation:\n {repr(obs)}',
)
else:
state.replay_phase = new_phase
agent.replay_phase_changed(new_phase)
update_phase(new_phase, state, agent)


def get_replay_observation_message(
def on_agent_replay_observation(
obs: ReplayObservation, max_message_chars: int
) -> Message:
"""Create a message to explain the observation."""
text: str
if isinstance(obs, ReplayToolCmdOutputObservation):
# if it doesn't have tool call metadata, it was triggered by a user action
# Internal command result from an automatic or user-triggered replay command.
if obs.tool_call_metadata is None:
# If it doesn't have tool call metadata, it was triggered by a user action.
text = truncate_content(
f'\nObserved result of replay command executed by user:\n{obs.content}',
max_message_chars,
)
else:
text = obs.content
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, ReplayPhaseUpdateObservation):
new_phase = obs.new_phase
if new_phase == ReplayPhase.Edit:
text = replay_prompt_phase_edit(obs)
else:
raise NotImplementedError(f'Unhandled ReplayPhaseUpdateAction: {new_phase}')
message = Message(role='user', content=[TextContent(text=text)])
# Agent requested a phase update.
text = get_new_phase_prompt(obs)
else:
raise NotImplementedError(
f"Unhandled observation type: {obs.__class__.__name__} ({getattr(obs, 'observation', None)})"
)
return message
return Message(role='user', content=[TextContent(text=text)])


# ###########################################################################
# Prompts.
# ###########################################################################


def get_new_phase_prompt(obs: ReplayPhaseUpdateObservation) -> str:
    """Return the prompt text that introduces the phase named by `obs`.

    Only the Edit phase has a prompt; its text comes from
    `replay_prompt_phase_edit`.

    Raises:
        NotImplementedError: if `obs.new_phase` is any phase other than Edit.
    """
    new_phase = obs.new_phase
    # Guard clause: reject every phase we have no prompt for up front.
    if new_phase != ReplayPhase.Edit:
        raise NotImplementedError(f'Unhandled ReplayPhaseUpdateAction: {new_phase}')
    return replay_prompt_phase_edit(obs)


# ###########################################################################
# State machine transitions.
# ###########################################################################


def update_phase(new_phase: ReplayPhase, state: State, agent: Agent):
    """Apply the side effects of moving to `new_phase`.

    Stores the phase on `state.replay_phase`, calls `agent.update_tools` with
    the new phase, and logs the transition at INFO level.
    """
    # Record the phase on the state first, then let the agent react to it.
    state.replay_phase = new_phase
    agent.update_tools(new_phase)
    logger.info(f'[REPLAY] update_phase (replay_phase): {new_phase}')
67 changes: 33 additions & 34 deletions openhands/replay/replay_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,11 @@ def replay_prompt_phase_analysis(command_result: dict, prompt: str) -> str:
3. Then use the `inspect-*` tools to investigate.
4. Once found, `submit-hypothesis`.
# Initial Analysis
""" + json.dumps(command_result, indent=2)
return enhance_prompt(prompt, prefix, suffix)


def replay_prompt_phase_analysis_legacy(command_result: dict, prompt: str) -> str:
    """Build the analysis prompt for the old, annotation-based workflow.

    Old workflow: initial-analysis left hints in the form of source code
    annotations. This reads the hint fields from `command_result`, builds a
    prefix describing the needed change, and passes it to `enhance_prompt`
    together with `prompt` and an empty suffix.

    Args:
        command_result: Result dict of the initial-analysis command. The keys
            read are `annotatedRepo`, `commentText`, `reactComponentName`,
            `consoleError` and `startName`; each defaults to `''` when missing.
        prompt: The prompt to enhance.

    Returns:
        Whatever `enhance_prompt(prompt, prefix, suffix)` returns.
    """
    annotated_repo_path = command_result.get('annotatedRepo', '')
    comment_text = command_result.get('commentText', '')
    react_component_name = command_result.get('reactComponentName', '')
    console_error = command_result.get('consoleError', '')
    # start_location = result.get('startLocation', '')
    start_name = command_result.get('startName', '')

    # Start from an empty prefix: without this, a result carrying neither
    # `commentText` nor `consoleError` hit `prefix +=` below with `prefix`
    # unbound and raised UnboundLocalError.
    prefix = ''

    # TODO: Move this to a prompt template file.
    if comment_text:
        if react_component_name:
            prefix = (
                f'There is a change needed to the {react_component_name} component.\n'
            )
        else:
            prefix = f'There is a change needed in {annotated_repo_path}:\n'
        prefix += f'{comment_text}\n\n'
    elif console_error:
        prefix = f'There is a change needed in {annotated_repo_path} to fix a console error that has appeared unexpectedly:\n'
        prefix += f'{console_error}\n\n'

    prefix += '<IMPORTANT>\n'
    prefix += 'Information about a reproduction of the problem is available in source comments.\n'
    prefix += 'You must search for these comments and use them to get a better understanding of the problem.\n'
    prefix += f'The first reproduction comment to search for is named {start_name}. Start your investigation there.\n'
    prefix += '</IMPORTANT>\n'

    suffix = ''

    return enhance_prompt(prompt, prefix, suffix)


def replay_prompt_phase_edit(obs: ReplayPhaseUpdateObservation) -> str:
# Tell the agent to stop analyzing and start editing:
return """
Expand All @@ -73,3 +39,36 @@ def replay_prompt_phase_edit(obs: ReplayPhaseUpdateObservation) -> str:
3. Do the `editSuggestions` actually address the issue?
4. Rephrase the hypothesis so that it is consistent and correct.
"""


# def replay_prompt_phase_analysis_legacy(command_result: dict, prompt: str) -> str:
# # Old workflow: initial-analysis left hints in form of source code annotations.
# annotated_repo_path = command_result.get('annotatedRepo', '')
# comment_text = command_result.get('commentText', '')
# react_component_name = command_result.get('reactComponentName', '')
# console_error = command_result.get('consoleError', '')
# # start_location = result.get('startLocation', '')
# start_name = command_result.get('startName', '')

# # TODO: Move this to a prompt template file.
# if comment_text:
# if react_component_name:
# prefix = (
# f'There is a change needed to the {react_component_name} component.\n'
# )
# else:
# prefix = f'There is a change needed in {annotated_repo_path}:\n'
# prefix += f'{comment_text}\n\n'
# elif console_error:
# prefix = f'There is a change needed in {annotated_repo_path} to fix a console error that has appeared unexpectedly:\n'
# prefix += f'{console_error}\n\n'

# prefix += '<IMPORTANT>\n'
# prefix += 'Information about a reproduction of the problem is available in source comments.\n'
# prefix += 'You must search for these comments and use them to get a better understanding of the problem.\n'
# prefix += f'The first reproduction comment to search for is named {start_name}. Start your investigation there.\n'
# prefix += '</IMPORTANT>\n'

# suffix = ''

# return enhance_prompt(prompt, prefix, suffix)
4 changes: 4 additions & 0 deletions openhands/replay/replay_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def handle_replay_tool_call(
def get_replay_tools(
replay_phase: ReplayPhase, default_tools: list[ChatCompletionToolParam]
) -> list[ChatCompletionToolParam]:
if replay_phase == ReplayPhase.Normal:
# Use the default tools when not in a Replay-specific phase.
return default_tools

analysis_tools = [
ReplayInspectDataTool,
ReplayInspectPointTool,
Expand Down

0 comments on commit 128a6db

Please sign in to comment.