21
21
22
22
23
23
class ReplayToolType (Enum ):
24
- Analysis = ( 'analysis' ,)
24
+ Analysis = 'analysis'
25
25
PhaseTransition = 'phase_transition'
26
26
27
27
@@ -33,11 +33,36 @@ class ReplayAnalysisTool(ReplayTool):
33
33
replay_tool_type = ReplayToolType .Analysis
34
34
35
35
36
- def replay_tool (name : str , description : str , parameters : dict ) -> ReplayTool :
37
- f = ChatCompletionToolParamFunctionChunk (
38
- name = name , description = description , parameters = parameters
36
+ def replay_analysis_tool (name : str , description : str , parameters : dict ) -> ReplayTool :
37
+ tool = ReplayAnalysisTool (
38
+ replay_tool_type = ReplayToolType .Analysis ,
39
+ type = 'function' ,
40
+ function = ChatCompletionToolParamFunctionChunk (
41
+ name = name , description = description , parameters = parameters
42
+ ),
39
43
)
40
- return ReplayAnalysisTool (type = 'function' , function = f )
44
+ return tool
45
+
46
+
47
+ class ReplayPhaseTransitionTool (ReplayTool ):
48
+ replay_tool_type = ReplayToolType .PhaseTransition
49
+ new_phase : ReplayPhase
50
+
51
+
52
+ def replay_phase_tool (
53
+ new_phase : ReplayPhase , name : str , description : str , parameters : dict
54
+ ):
55
+ tool = ReplayPhaseTransitionTool (
56
+ replay_tool_type = ReplayToolType .PhaseTransition ,
57
+ new_phase = new_phase ,
58
+ type = 'function' ,
59
+ function = ChatCompletionToolParamFunctionChunk (
60
+ name = name ,
61
+ description = description ,
62
+ parameters = parameters ,
63
+ ),
64
+ )
65
+ return tool
41
66
42
67
43
68
# ###########################################################################
@@ -50,7 +75,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
50
75
IMPORTANT: Prefer using inspect-data over inspect-point.
51
76
"""
52
77
53
- ReplayInspectDataTool = replay_tool (
78
+ ReplayInspectDataTool = replay_analysis_tool (
54
79
name = 'inspect-data' ,
55
80
description = _REPLAY_INSPECT_DATA_DESCRIPTION .strip (),
56
81
parameters = {
@@ -82,7 +107,7 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
82
107
Use this tool instead of `inspect-data` only when you don't have a specific data point to investigate.
83
108
"""
84
109
85
- ReplayInspectPointTool = replay_tool (
110
+ ReplayInspectPointTool = replay_analysis_tool (
86
111
name = 'inspect-point' ,
87
112
description = _REPLAY_INSPECT_POINT_DESCRIPTION .strip (),
88
113
parameters = {
@@ -100,23 +125,6 @@ def replay_tool(name: str, description: str, parameters: dict) -> ReplayTool:
100
125
# ###########################################################################
101
126
102
127
103
- class ReplayPhaseTransitionTool (ReplayTool ):
104
- replay_tool_type = ReplayToolType .PhaseTransition
105
- new_phase : ReplayPhase
106
-
107
-
108
- def replay_phase_tool (
109
- new_phase : ReplayPhase , name : str , description : str , parameters : dict
110
- ):
111
- return ReplayPhaseTransitionTool (
112
- new_phase = new_phase ,
113
- type = 'function' ,
114
- function = ChatCompletionToolParamFunctionChunk (
115
- name = name , description = description , parameters = parameters
116
- ),
117
- )
118
-
119
-
120
128
replay_phase_transition_tools : list [ReplayPhaseTransitionTool ] = [
121
129
replay_phase_tool (
122
130
ReplayPhase .Edit ,
@@ -155,8 +163,7 @@ def replay_phase_tool(
155
163
]
156
164
replay_tool_names : set [str ] = set ([t ['function' ]['name' ] for t in replay_tools ])
157
165
replay_replay_tool_type_by_name = {
158
- t ['function' ]['name' ]: t ['function' ].get ('replay_tool_type' , None )
159
- for t in replay_tools
166
+ t ['function' ]['name' ]: t .get ('replay_tool_type' , None ) for t in replay_tools
160
167
}
161
168
162
169
@@ -174,6 +181,24 @@ def is_replay_tool(
174
181
# ###########################################################################
175
182
176
183
184
+ def get_replay_transition_tool_for_current_phase (
185
+ current_phase : ReplayPhase , name : str | None = None
186
+ ) -> ReplayTool | None :
187
+ next_phase = get_replay_child_phase (current_phase )
188
+ if next_phase :
189
+ transition_tools = [
190
+ t
191
+ for t in replay_phase_transition_tools
192
+ if t ['new_phase' ] == next_phase
193
+ and (not name or t ['function' ]['name' ] == name )
194
+ ]
195
+ assert len (
196
+ transition_tools
197
+ ), f'replay_phase_transition_tools is missing tools for new_phase: { next_phase } '
198
+ return transition_tools [0 ]
199
+ return None
200
+
201
+
177
202
def get_replay_tools (
178
203
replay_phase : ReplayPhase , default_tools : list [ChatCompletionToolParam ]
179
204
) -> list [ChatCompletionToolParam ]:
@@ -190,15 +215,9 @@ def get_replay_tools(
190
215
raise ValueError (f'Unhandled ReplayPhase in get_tools: { replay_phase } ' )
191
216
192
217
# Add tools to allow transitioning to next phase.
193
- next_phase = get_replay_child_phase (replay_phase )
194
- if next_phase :
195
- transition_tools = [
196
- t for t in replay_phase_transition_tools if t ['new_phase' ] == next_phase
197
- ]
198
- assert len (
199
- transition_tools
200
- ), f'replay_phase_transition_tools is missing tools for new_phase: { next_phase } '
201
- tools += transition_tools
218
+ next_phase_tool = get_replay_transition_tool_for_current_phase (replay_phase )
219
+ if next_phase_tool :
220
+ tools .append (next_phase_tool )
202
221
203
222
# Return all tools.
204
223
return tools
@@ -234,8 +253,18 @@ def handle_replay_tool_call(
234
253
)
235
254
elif is_replay_tool (name , ReplayToolType .PhaseTransition ):
236
255
# Request a phase change.
256
+ tool = get_replay_transition_tool_for_current_phase (state .replay_phase , name )
257
+ assert tool , f'Missing ReplayPhaseTransitionTool for { state .replay_phase } in Replay tool_call: { tool_call .function .name } '
258
+ new_phase = tool ['new_phase' ]
259
+ assert (
260
+ new_phase
261
+ ), f'Missing new_phase in Replay tool_call: { tool_call .function .name } '
262
+ assert (
263
+ new_phase
264
+ ), f'Missing new_phase in Replay tool_call: { tool_call .function .name } '
265
+ del arguments ['new_phase' ]
237
266
action = ReplayPhaseUpdateAction (
238
- new_phase = tool_call [ ' new_phase' ] , info = json .dumps (arguments )
267
+ new_phase = new_phase , info = json .dumps (arguments )
239
268
)
240
269
assert action , f'Unhandled Replay tool_call: { tool_call .function .name } '
241
270
return action
0 commit comments