11import json
2+ import math
23import uuid
34from fastapi .responses import JSONResponse
45from prompts import *
56from utils import get_tool_call_response , describe
67from .base_handler import Handler , Context
78from .context import Context
89from utils import get_tool_call_response , create_logger , describe
9-
10+ from config import PARSE_ERROR_TRIES , EVALUATION_CYCLES_COUNT
1011missed_tool_logger = create_logger (
1112 "chain.missed_tools" , ".logs/empty_tool_tool_response.log"
1213)
@@ -84,11 +85,12 @@ async def handle(self, context: Context):
8485 new_messages = [
8586 {"role" : "system" , "content" : system_message },
8687 {
87- "role" : "system " ,
88- "content" : f"# Conversation History:\n { messages_flatten } \n \n # Available Tools: \n { tools_json } \n \n { suffix } " ,
88+ "role" : "user " ,
89+ "content" : f"# Conversation History:\n { messages_flatten } \n \n # Available Tools: \n { tools_json } \n \n { suffix } \n { context . messages [ - 1 ][ 'content' ] } " ,
8990 },
9091 ]
9192
93+
9294 completion , tool_calls = await self .process_tool_calls (
9395 context , new_messages
9496 )
@@ -144,56 +146,87 @@ async def handle(self, context: Context):
144146
145147 async def process_tool_calls (self , context , new_messages ):
146148 try :
147- tries = 5
148- tool_calls = []
149- while tries > 0 :
150- try :
151- # Assuming the context has an instantiated client according to the selected provider
152- completion = context .client .route (
153- model = context .client .parser_model ,
154- messages = new_messages ,
155- temperature = 0 ,
156- max_tokens = 1024 ,
157- top_p = 1 ,
158- stream = False ,
159- )
160-
161- response = completion .choices [0 ].message .content
162- if "```json" in response :
163- response = response .split ("```json" )[1 ].split ("```" )[0 ]
164-
165- try :
166- tool_response = json .loads (response )
167- if isinstance (tool_response , list ):
168- tool_response = {"tool_calls" : tool_response }
169- except json .JSONDecodeError as e :
170- print (
171- f"Error parsing the tool response: { e } , tries left: { tries } "
172- )
173- new_messages .append (
174- {
175- "role" : "user" ,
176- "content" : f"Error: { e } .\n \n { CLEAN_UP_MESSAGE } " ,
177- }
149+
150+ evaluation_cycles_count = EVALUATION_CYCLES_COUNT
151+ try :
152+ # Assuming the context has an instantiated client according to the selected provider
153+ cnadidate_responses = []
154+ result_is_confirmed = False
155+ for generation in range (evaluation_cycles_count ):
156+ tries = PARSE_ERROR_TRIES
157+ tool_calls = []
158+ while tries > 0 :
159+ completion = context .client .route (
160+ model = context .client .parser_model ,
161+ messages = new_messages ,
162+ temperature = 0 ,
163+ max_tokens = 512 ,
164+ top_p = 1 ,
165+ stream = False ,
166+ # response_format = {"type": "json_object"}
178167 )
179- tries -= 1
180- continue
181168
169+ response = completion .choices [0 ].message .content
170+ response = response .replace ("\_" , "_" )
171+ if TOOLS_OPEN_TOKEN in response :
172+ response = response .split (TOOLS_OPEN_TOKEN )[1 ].split (TOOLS_CLOSE_TOKEN )[0 ]
173+ if "```json" in response :
174+ response = response .split ("```json" )[1 ].split ("```" )[0 ]
175+
176+ try :
177+ tool_response = json .loads (response )
178+ if isinstance (tool_response , list ):
179+ tool_response = {"tool_calls" : tool_response }
180+ break
181+ except json .JSONDecodeError as e :
182+ print (
183+ f"Error parsing the tool response: { e } , tries left: { tries } "
184+ )
185+ new_messages .append (
186+ {
187+ "role" : "user" ,
188+ "content" : f"Error: { e } .\n \n { CLEAN_UP_MESSAGE } " ,
189+ }
190+ )
191+ tries -= 1
192+ continue
193+ cnadidate_responses .append (tool_response )
194+
195+ # Go through candidate and see if all detected tools count is 2 then break
196+ tool_calls_count = {}
197+ for tool_response in cnadidate_responses :
198+ for func in tool_response .get ("tool_calls" , []):
199+ tool_calls_count [func ["name" ]] = tool_calls_count .get (func ["name" ], 0 ) + 1
200+
201+ if all ([v == 2 for v in tool_calls_count .values ()]):
202+ result_is_confirmed = True
203+ break
204+
205+ # Go through candiudtae and count the number of each tolls is called
206+ tool_calls_count = {}
207+ for tool_response in cnadidate_responses :
182208 for func in tool_response .get ("tool_calls" , []):
183- tool_calls .append (
184- {
185- "id" : f"call_{ func ['name' ]} _{ str (uuid .uuid4 ())} " ,
186- "type" : "function" ,
187- "function" : {
188- "name" : func ["name" ],
189- "arguments" : json .dumps (func ["arguments" ]),
190- },
191- }
192- )
209+ tool_calls_count [func ["name" ]] = tool_calls_count .get (func ["name" ], 0 ) + 1
193210
194- break
195- except Exception as e :
196- raise e
211+ pickup_threshold = math .floor (len (cnadidate_responses ) * 0.7 ) if not result_is_confirmed else 0
212+ # Select any tools with frq more than 2
213+ tool_calls = []
214+ for tool_response in cnadidate_responses :
215+ for func in tool_response .get ("tool_calls" , []):
216+ if tool_calls_count [func ["name" ]] > pickup_threshold :
217+ tool_calls .append (
218+ {
219+ "id" : f"call_{ func ['name' ]} _{ str (uuid .uuid4 ())} " ,
220+ "type" : "function" ,
221+ "function" : {
222+ "name" : func ["name" ],
223+ "arguments" : json .dumps (func ["arguments" ]),
224+ },
225+ }
226+ )
227+
228+ except Exception as e :
229+ raise e
197230
198231 if tries == 0 :
199232 tool_calls = []
@@ -215,23 +248,20 @@ async def handle(self, context: Context):
215248 message ["role" ] = "user"
216249 message ["content" ] = get_func_result_guide (message ["content" ])
217250
218- messages [- 1 ]["role" ] = "user"
219- # Assuming get_func_result_guide is a function that formats the tool response
220- messages [- 1 ]["content" ] = get_func_result_guide (messages [- 1 ]["content" ])
221-
222251 try :
252+ params = {
253+ 'temperature' : 0.5 ,
254+ 'max_tokens' : 1024 ,
255+ }
256+ params = {** params , ** context .params }
257+
223258 completion = context .client .route (
224259 messages = messages ,
225- ** context .client .clean_params (context . params ),
260+ ** context .client .clean_params (params ),
226261 )
227262 context .response = completion .model_dump ()
228263 return JSONResponse (content = context .response , status_code = 200 )
229264 except Exception as e :
230- # Log the exception or handle it as needed
231- # context.response = {
232- # "error": "An error occurred processing the tool response"
233- # }
234- # return JSONResponse(content=context.response, status_code=500)
235265 raise e
236266
237267 return await super ().handle (context )
0 commit comments