Skip to content

Commit

Permalink
WIP: tool call fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Domiii committed Jan 21, 2025
1 parent cd2a715 commit a4b540c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 39 deletions.
101 changes: 65 additions & 36 deletions openhands/replay/replay_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class ReplayToolType(Enum):
Analysis = ('analysis',)
Analysis = 'analysis'
PhaseTransition = 'phase_transition'


Expand All @@ -33,11 +33,36 @@ class ReplayAnalysisTool(ReplayTool):
replay_tool_type = ReplayToolType.Analysis


def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
f = ChatCompletionToolParamFunctionChunk(
name=name, description=description, parameters=parameters
def replay_analysis_tool(name: str, description: str, parameters: dict) -> ReplayTool:
tool = ReplayAnalysisTool(
replay_tool_type=ReplayToolType.Analysis,
type='function',
function=ChatCompletionToolParamFunctionChunk(
name=name, description=description, parameters=parameters
),
)
return ReplayAnalysisTool(type='function', function=f)
return tool


class ReplayPhaseTransitionTool(ReplayTool):
replay_tool_type = ReplayToolType.PhaseTransition
new_phase: ReplayPhase


def replay_phase_tool(
new_phase: ReplayPhase, name: str, description: str, parameters: dict
):
tool = ReplayPhaseTransitionTool(
replay_tool_type=ReplayToolType.PhaseTransition,
new_phase=new_phase,
type='function',
function=ChatCompletionToolParamFunctionChunk(
name=name,
description=description,
parameters=parameters,
),
)
return tool


# ###########################################################################
Expand All @@ -50,7 +75,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
IMPORTANT: Prefer using inspect-data over inspect-point.
"""

ReplayInspectDataTool = replay_tool(
ReplayInspectDataTool = replay_analysis_tool(
name='inspect-data',
description=_REPLAY_INSPECT_DATA_DESCRIPTION.strip(),
parameters={
Expand Down Expand Up @@ -82,7 +107,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
Use this tool instead of `inspect-data` only when you don't have a specific data point to investigate.
"""

ReplayInspectPointTool = replay_tool(
ReplayInspectPointTool = replay_analysis_tool(
name='inspect-point',
description=_REPLAY_INSPECT_POINT_DESCRIPTION.strip(),
parameters={
Expand All @@ -100,23 +125,6 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
# ###########################################################################


class ReplayPhaseTransitionTool(ReplayTool):
replay_tool_type = ReplayToolType.PhaseTransition
new_phase: ReplayPhase


def replay_phase_tool(
new_phase: ReplayPhase, name: str, description: str, parameters: dict
):
return ReplayPhaseTransitionTool(
new_phase=new_phase,
type='function',
function=ChatCompletionToolParamFunctionChunk(
name=name, description=description, parameters=parameters
),
)


replay_phase_transition_tools: list[ReplayPhaseTransitionTool] = [
replay_phase_tool(
ReplayPhase.Edit,
Expand Down Expand Up @@ -155,8 +163,7 @@ def replay_phase_tool(
]
replay_tool_names: set[str] = set([t['function']['name'] for t in replay_tools])
replay_replay_tool_type_by_name = {
t['function']['name']: t['function'].get('replay_tool_type', None)
for t in replay_tools
t['function']['name']: t.get('replay_tool_type', None) for t in replay_tools
}


Expand All @@ -174,6 +181,24 @@ def is_replay_tool(
# ###########################################################################


def get_replay_transition_tool_for_current_phase(
current_phase: ReplayPhase, name: str | None = None
) -> ReplayTool | None:
next_phase = get_replay_child_phase(current_phase)
if next_phase:
transition_tools = [
t
for t in replay_phase_transition_tools
if t['new_phase'] == next_phase
and (not name or t['function']['name'] == name)
]
assert len(
transition_tools
), f'replay_phase_transition_tools is missing tools for new_phase: {next_phase}'
return transition_tools[0]
return None


def get_replay_tools(
replay_phase: ReplayPhase, default_tools: list[ChatCompletionToolParam]
) -> list[ChatCompletionToolParam]:
Expand All @@ -190,15 +215,9 @@ def get_replay_tools(
raise ValueError(f'Unhandled ReplayPhase in get_tools: {replay_phase}')

# Add tools to allow transitioning to next phase.
next_phase = get_replay_child_phase(replay_phase)
if next_phase:
transition_tools = [
t for t in replay_phase_transition_tools if t['new_phase'] == next_phase
]
assert len(
transition_tools
), f'replay_phase_transition_tools is missing tools for new_phase: {next_phase}'
tools += transition_tools
next_phase_tool = get_replay_transition_tool_for_current_phase(replay_phase)
if next_phase_tool:
tools.append(next_phase_tool)

# Return all tools.
return tools
Expand Down Expand Up @@ -234,8 +253,18 @@ def handle_replay_tool_call(
)
elif is_replay_tool(name, ReplayToolType.PhaseTransition):
# Request a phase change.
tool = get_replay_transition_tool_for_current_phase(state.replay_phase, name)
assert tool, f'Missing ReplayPhaseTransitionTool for {state.replay_phase} in Replay tool_call: {tool_call.function.name}'
new_phase = tool['new_phase']
assert (
new_phase
), f'Missing new_phase in Replay tool_call: {tool_call.function.name}'
assert (
new_phase
), f'Missing new_phase in Replay tool_call: {tool_call.function.name}'
del arguments['new_phase']
action = ReplayPhaseUpdateAction(
new_phase=tool_call['new_phase'], info=json.dumps(arguments)
new_phase=new_phase, info=json.dumps(arguments)
)
assert action, f'Unhandled Replay tool_call: {tool_call.function.name}'
return action
11 changes: 8 additions & 3 deletions replay_benchmarks/bolt/run-bolt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ INSTANCE_ID=$1
PROMPT_NAME="$2"

THIS_DIR="$(dirname "$0")"
OH_ROOT="$THIS_DIR/.."
OH_ROOT="$THIS_DIR/../.."
OH_ROOT="$(node -e 'console.log(require("path").resolve(process.argv[1]))' $OH_ROOT)"
if [[ -z "$TMP_DIR" ]]; then
TMP_DIR="/tmp"
Expand Down Expand Up @@ -65,11 +65,14 @@ else
fi

# Config overrides + sanity checks.
set -x
export DEBUG=1
# export REPLAY_DEV_MODE=1
export REPLAY_ENABLE_TOOL_CACHE=1
export WORKSPACE_BASE="$WORKSPACE_ROOT"
export LLM_MODEL="anthropic/claude-3-5-sonnet-20241022"
set +x

if [[ -z "$LLM_API_KEY" ]]; then
if [[ -z "$ANTHROPIC_API_KEY" ]]; then
echo "LLM_API_KEY or ANTHROPIC_API_KEY environment variable must be set."
Expand All @@ -84,9 +87,11 @@ echo "WORKSPACE_ROOT: \"$WORKSPACE_ROOT\""
echo "Logging to \"$LOG_FILE\"..."

# GO.
PROMPT_ONELINE=$(echo "$PROMPT" | tr '\n' " ")
cd $OH_ROOT
poetry run python -m openhands.core.main -t "$PROMPT" \
> "$LOG_FILE" 2>&1
set -x
poetry run python -m openhands.core.main -t "$PROMPT_ONELINE" >"${LOG_FILE}" 2>&1
set +x


# Log the relevant diff.
Expand Down

0 comments on commit a4b540c

Please sign in to comment.