1
1
import json
2
+ import math
2
3
import uuid
3
4
from fastapi .responses import JSONResponse
4
5
from prompts import *
5
6
from utils import get_tool_call_response , describe
6
7
from .base_handler import Handler , Context
7
8
from .context import Context
8
9
from utils import get_tool_call_response , create_logger , describe
9
-
10
+ from config import PARSE_ERROR_TRIES , EVALUATION_CYCLES_COUNT
10
11
missed_tool_logger = create_logger (
11
12
"chain.missed_tools" , ".logs/empty_tool_tool_response.log"
12
13
)
@@ -84,11 +85,12 @@ async def handle(self, context: Context):
84
85
new_messages = [
85
86
{"role" : "system" , "content" : system_message },
86
87
{
87
- "role" : "system " ,
88
- "content" : f"# Conversation History:\n { messages_flatten } \n \n # Available Tools: \n { tools_json } \n \n { suffix } " ,
88
+ "role" : "user " ,
89
+ "content" : f"# Conversation History:\n { messages_flatten } \n \n # Available Tools: \n { tools_json } \n \n { suffix } \n { context . messages [ - 1 ][ 'content' ] } " ,
89
90
},
90
91
]
91
92
93
+
92
94
completion , tool_calls = await self .process_tool_calls (
93
95
context , new_messages
94
96
)
@@ -144,56 +146,87 @@ async def handle(self, context: Context):
144
146
145
147
async def process_tool_calls (self , context , new_messages ):
146
148
try :
147
- tries = 5
148
- tool_calls = []
149
- while tries > 0 :
150
- try :
151
- # Assuming the context has an instantiated client according to the selected provider
152
- completion = context .client .route (
153
- model = context .client .parser_model ,
154
- messages = new_messages ,
155
- temperature = 0 ,
156
- max_tokens = 1024 ,
157
- top_p = 1 ,
158
- stream = False ,
159
- )
160
-
161
- response = completion .choices [0 ].message .content
162
- if "```json" in response :
163
- response = response .split ("```json" )[1 ].split ("```" )[0 ]
164
-
165
- try :
166
- tool_response = json .loads (response )
167
- if isinstance (tool_response , list ):
168
- tool_response = {"tool_calls" : tool_response }
169
- except json .JSONDecodeError as e :
170
- print (
171
- f"Error parsing the tool response: { e } , tries left: { tries } "
172
- )
173
- new_messages .append (
174
- {
175
- "role" : "user" ,
176
- "content" : f"Error: { e } .\n \n { CLEAN_UP_MESSAGE } " ,
177
- }
149
+
150
+ evaluation_cycles_count = EVALUATION_CYCLES_COUNT
151
+ try :
152
+ # Assuming the context has an instantiated client according to the selected provider
153
+ cnadidate_responses = []
154
+ result_is_confirmed = False
155
+ for generation in range (evaluation_cycles_count ):
156
+ tries = PARSE_ERROR_TRIES
157
+ tool_calls = []
158
+ while tries > 0 :
159
+ completion = context .client .route (
160
+ model = context .client .parser_model ,
161
+ messages = new_messages ,
162
+ temperature = 0 ,
163
+ max_tokens = 512 ,
164
+ top_p = 1 ,
165
+ stream = False ,
166
+ # response_format = {"type": "json_object"}
178
167
)
179
- tries -= 1
180
- continue
181
168
169
+ response = completion .choices [0 ].message .content
170
+ response = response .replace ("\_" , "_" )
171
+ if TOOLS_OPEN_TOKEN in response :
172
+ response = response .split (TOOLS_OPEN_TOKEN )[1 ].split (TOOLS_CLOSE_TOKEN )[0 ]
173
+ if "```json" in response :
174
+ response = response .split ("```json" )[1 ].split ("```" )[0 ]
175
+
176
+ try :
177
+ tool_response = json .loads (response )
178
+ if isinstance (tool_response , list ):
179
+ tool_response = {"tool_calls" : tool_response }
180
+ break
181
+ except json .JSONDecodeError as e :
182
+ print (
183
+ f"Error parsing the tool response: { e } , tries left: { tries } "
184
+ )
185
+ new_messages .append (
186
+ {
187
+ "role" : "user" ,
188
+ "content" : f"Error: { e } .\n \n { CLEAN_UP_MESSAGE } " ,
189
+ }
190
+ )
191
+ tries -= 1
192
+ continue
193
+ cnadidate_responses .append (tool_response )
194
+
195
+ # Go through candidate and see if all detected tools count is 2 then break
196
+ tool_calls_count = {}
197
+ for tool_response in cnadidate_responses :
198
+ for func in tool_response .get ("tool_calls" , []):
199
+ tool_calls_count [func ["name" ]] = tool_calls_count .get (func ["name" ], 0 ) + 1
200
+
201
+ if all ([v == 2 for v in tool_calls_count .values ()]):
202
+ result_is_confirmed = True
203
+ break
204
+
205
+ # Go through candiudtae and count the number of each tolls is called
206
+ tool_calls_count = {}
207
+ for tool_response in cnadidate_responses :
182
208
for func in tool_response .get ("tool_calls" , []):
183
- tool_calls .append (
184
- {
185
- "id" : f"call_{ func ['name' ]} _{ str (uuid .uuid4 ())} " ,
186
- "type" : "function" ,
187
- "function" : {
188
- "name" : func ["name" ],
189
- "arguments" : json .dumps (func ["arguments" ]),
190
- },
191
- }
192
- )
209
+ tool_calls_count [func ["name" ]] = tool_calls_count .get (func ["name" ], 0 ) + 1
193
210
194
- break
195
- except Exception as e :
196
- raise e
211
+ pickup_threshold = math .floor (len (cnadidate_responses ) * 0.7 ) if not result_is_confirmed else 0
212
+ # Select any tools with frq more than 2
213
+ tool_calls = []
214
+ for tool_response in cnadidate_responses :
215
+ for func in tool_response .get ("tool_calls" , []):
216
+ if tool_calls_count [func ["name" ]] > pickup_threshold :
217
+ tool_calls .append (
218
+ {
219
+ "id" : f"call_{ func ['name' ]} _{ str (uuid .uuid4 ())} " ,
220
+ "type" : "function" ,
221
+ "function" : {
222
+ "name" : func ["name" ],
223
+ "arguments" : json .dumps (func ["arguments" ]),
224
+ },
225
+ }
226
+ )
227
+
228
+ except Exception as e :
229
+ raise e
197
230
198
231
if tries == 0 :
199
232
tool_calls = []
@@ -215,23 +248,20 @@ async def handle(self, context: Context):
215
248
message ["role" ] = "user"
216
249
message ["content" ] = get_func_result_guide (message ["content" ])
217
250
218
- messages [- 1 ]["role" ] = "user"
219
- # Assuming get_func_result_guide is a function that formats the tool response
220
- messages [- 1 ]["content" ] = get_func_result_guide (messages [- 1 ]["content" ])
221
-
222
251
try :
252
+ params = {
253
+ 'temperature' : 0.5 ,
254
+ 'max_tokens' : 1024 ,
255
+ }
256
+ params = {** params , ** context .params }
257
+
223
258
completion = context .client .route (
224
259
messages = messages ,
225
- ** context .client .clean_params (context . params ),
260
+ ** context .client .clean_params (params ),
226
261
)
227
262
context .response = completion .model_dump ()
228
263
return JSONResponse (content = context .response , status_code = 200 )
229
264
except Exception as e :
230
- # Log the exception or handle it as needed
231
- # context.response = {
232
- # "error": "An error occurred processing the tool response"
233
- # }
234
- # return JSONResponse(content=context.response, status_code=500)
235
265
raise e
236
266
237
267
return await super ().handle (context )
0 commit comments