Skip to content

Commit a4b540c

Browse files
committed
WIP: tool call fixes
1 parent cd2a715 commit a4b540c

File tree

2 files changed

+73
-39
lines changed

2 files changed

+73
-39
lines changed

openhands/replay/replay_tools.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
class ReplayToolType(Enum):
24-
Analysis = ('analysis',)
24+
Analysis = 'analysis'
2525
PhaseTransition = 'phase_transition'
2626

2727

@@ -33,11 +33,36 @@ class ReplayAnalysisTool(ReplayTool):
3333
replay_tool_type = ReplayToolType.Analysis
3434

3535

36-
def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
37-
f = ChatCompletionToolParamFunctionChunk(
38-
name=name, description=description, parameters=parameters
36+
def replay_analysis_tool(name: str, description: str, parameters: dict) -> ReplayTool:
37+
tool = ReplayAnalysisTool(
38+
replay_tool_type=ReplayToolType.Analysis,
39+
type='function',
40+
function=ChatCompletionToolParamFunctionChunk(
41+
name=name, description=description, parameters=parameters
42+
),
3943
)
40-
return ReplayAnalysisTool(type='function', function=f)
44+
return tool
45+
46+
47+
class ReplayPhaseTransitionTool(ReplayTool):
48+
replay_tool_type = ReplayToolType.PhaseTransition
49+
new_phase: ReplayPhase
50+
51+
52+
def replay_phase_tool(
53+
new_phase: ReplayPhase, name: str, description: str, parameters: dict
54+
):
55+
tool = ReplayPhaseTransitionTool(
56+
replay_tool_type=ReplayToolType.PhaseTransition,
57+
new_phase=new_phase,
58+
type='function',
59+
function=ChatCompletionToolParamFunctionChunk(
60+
name=name,
61+
description=description,
62+
parameters=parameters,
63+
),
64+
)
65+
return tool
4166

4267

4368
# ###########################################################################
@@ -50,7 +75,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
5075
IMPORTANT: Prefer using inspect-data over inspect-point.
5176
"""
5277

53-
ReplayInspectDataTool = replay_tool(
78+
ReplayInspectDataTool = replay_analysis_tool(
5479
name='inspect-data',
5580
description=_REPLAY_INSPECT_DATA_DESCRIPTION.strip(),
5681
parameters={
@@ -82,7 +107,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
82107
Use this tool instead of `inspect-data` only when you don't have a specific data point to investigate.
83108
"""
84109

85-
ReplayInspectPointTool = replay_tool(
110+
ReplayInspectPointTool = replay_analysis_tool(
86111
name='inspect-point',
87112
description=_REPLAY_INSPECT_POINT_DESCRIPTION.strip(),
88113
parameters={
@@ -100,23 +125,6 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
100125
# ###########################################################################
101126

102127

103-
class ReplayPhaseTransitionTool(ReplayTool):
104-
replay_tool_type = ReplayToolType.PhaseTransition
105-
new_phase: ReplayPhase
106-
107-
108-
def replay_phase_tool(
109-
new_phase: ReplayPhase, name: str, description: str, parameters: dict
110-
):
111-
return ReplayPhaseTransitionTool(
112-
new_phase=new_phase,
113-
type='function',
114-
function=ChatCompletionToolParamFunctionChunk(
115-
name=name, description=description, parameters=parameters
116-
),
117-
)
118-
119-
120128
replay_phase_transition_tools: list[ReplayPhaseTransitionTool] = [
121129
replay_phase_tool(
122130
ReplayPhase.Edit,
@@ -155,8 +163,7 @@ def replay_phase_tool(
155163
]
156164
replay_tool_names: set[str] = set([t['function']['name'] for t in replay_tools])
157165
replay_replay_tool_type_by_name = {
158-
t['function']['name']: t['function'].get('replay_tool_type', None)
159-
for t in replay_tools
166+
t['function']['name']: t.get('replay_tool_type', None) for t in replay_tools
160167
}
161168

162169

@@ -174,6 +181,24 @@ def is_replay_tool(
174181
# ###########################################################################
175182

176183

184+
def get_replay_transition_tool_for_current_phase(
185+
current_phase: ReplayPhase, name: str | None = None
186+
) -> ReplayTool | None:
187+
next_phase = get_replay_child_phase(current_phase)
188+
if next_phase:
189+
transition_tools = [
190+
t
191+
for t in replay_phase_transition_tools
192+
if t['new_phase'] == next_phase
193+
and (not name or t['function']['name'] == name)
194+
]
195+
assert len(
196+
transition_tools
197+
), f'replay_phase_transition_tools is missing tools for new_phase: {next_phase}'
198+
return transition_tools[0]
199+
return None
200+
201+
177202
def get_replay_tools(
178203
replay_phase: ReplayPhase, default_tools: list[ChatCompletionToolParam]
179204
) -> list[ChatCompletionToolParam]:
@@ -190,15 +215,9 @@ def get_replay_tools(
190215
raise ValueError(f'Unhandled ReplayPhase in get_tools: {replay_phase}')
191216

192217
# Add tools to allow transitioning to next phase.
193-
next_phase = get_replay_child_phase(replay_phase)
194-
if next_phase:
195-
transition_tools = [
196-
t for t in replay_phase_transition_tools if t['new_phase'] == next_phase
197-
]
198-
assert len(
199-
transition_tools
200-
), f'replay_phase_transition_tools is missing tools for new_phase: {next_phase}'
201-
tools += transition_tools
218+
next_phase_tool = get_replay_transition_tool_for_current_phase(replay_phase)
219+
if next_phase_tool:
220+
tools.append(next_phase_tool)
202221

203222
# Return all tools.
204223
return tools
@@ -234,8 +253,18 @@ def handle_replay_tool_call(
234253
)
235254
elif is_replay_tool(name, ReplayToolType.PhaseTransition):
236255
# Request a phase change.
256+
tool = get_replay_transition_tool_for_current_phase(state.replay_phase, name)
257+
assert tool, f'Missing ReplayPhaseTransitionTool for {state.replay_phase} in Replay tool_call: {tool_call.function.name}'
258+
new_phase = tool['new_phase']
259+
assert (
260+
new_phase
261+
), f'Missing new_phase in Replay tool_call: {tool_call.function.name}'
262+
assert (
263+
new_phase
264+
), f'Missing new_phase in Replay tool_call: {tool_call.function.name}'
265+
del arguments['new_phase']
237266
action = ReplayPhaseUpdateAction(
238-
new_phase=tool_call['new_phase'], info=json.dumps(arguments)
267+
new_phase=new_phase, info=json.dumps(arguments)
239268
)
240269
assert action, f'Unhandled Replay tool_call: {tool_call.function.name}'
241270
return action

replay_benchmarks/bolt/run-bolt.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ INSTANCE_ID=$1
99
PROMPT_NAME="$2"
1010

1111
THIS_DIR="$(dirname "$0")"
12-
OH_ROOT="$THIS_DIR/.."
12+
OH_ROOT="$THIS_DIR/../.."
1313
OH_ROOT="$(node -e 'console.log(require("path").resolve(process.argv[1]))' $OH_ROOT)"
1414
if [[ -z "$TMP_DIR" ]]; then
1515
TMP_DIR="/tmp"
@@ -65,11 +65,14 @@ else
6565
fi
6666

6767
# Config overrides + sanity checks.
68+
set -x
6869
export DEBUG=1
6970
# export REPLAY_DEV_MODE=1
7071
export REPLAY_ENABLE_TOOL_CACHE=1
7172
export WORKSPACE_BASE="$WORKSPACE_ROOT"
7273
export LLM_MODEL="anthropic/claude-3-5-sonnet-20241022"
74+
set +x
75+
7376
if [[ -z "$LLM_API_KEY" ]]; then
7477
if [[ -z "$ANTHROPIC_API_KEY" ]]; then
7578
echo "LLM_API_KEY or ANTHROPIC_API_KEY environment variable must be set."
@@ -84,9 +87,11 @@ echo "WORKSPACE_ROOT: \"$WORKSPACE_ROOT\""
8487
echo "Logging to \"$LOG_FILE\"..."
8588

8689
# GO.
90+
PROMPT_ONELINE=$(echo "$PROMPT" | tr '\n' " ")
8791
cd $OH_ROOT
88-
poetry run python -m openhands.core.main -t "$PROMPT" \
89-
> "$LOG_FILE" 2>&1
92+
set -x
93+
poetry run python -m openhands.core.main -t "$PROMPT_ONELINE" >"${LOG_FILE}" 2>&1
94+
set +x
9095

9196

9297
# Log the relevant diff.

0 commit comments

Comments
 (0)