Skip to content

Commit 9b13523

Browse files
committed
Updates:
- Improve parser - Vote to majority for tools selection - Working on reasoning engine. - Update Config file to ser the avaulation cycle
1 parent bb91a4b commit 9b13523

File tree

7 files changed

+325
-96
lines changed

7 files changed

+325
-96
lines changed

app/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
# To be developed
1+
# To be developed
2+
EVALUATION_CYCLES_COUNT=3
3+
PARSE_ERROR_TRIES = 5

app/libs/tools_handler.py

+89-59
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
2+
import math
23
import uuid
34
from fastapi.responses import JSONResponse
45
from prompts import *
56
from utils import get_tool_call_response, describe
67
from .base_handler import Handler, Context
78
from .context import Context
89
from utils import get_tool_call_response, create_logger, describe
9-
10+
from config import PARSE_ERROR_TRIES, EVALUATION_CYCLES_COUNT
1011
missed_tool_logger = create_logger(
1112
"chain.missed_tools", ".logs/empty_tool_tool_response.log"
1213
)
@@ -84,11 +85,12 @@ async def handle(self, context: Context):
8485
new_messages = [
8586
{"role": "system", "content": system_message},
8687
{
87-
"role": "system",
88-
"content": f"# Conversation History:\n{messages_flatten}\n\n# Available Tools: \n{tools_json}\n\n{suffix}",
88+
"role": "user",
89+
"content": f"# Conversation History:\n{messages_flatten}\n\n# Available Tools: \n{tools_json}\n\n{suffix}\n{context.messages[-1]['content']}",
8990
},
9091
]
9192

93+
9294
completion, tool_calls = await self.process_tool_calls(
9395
context, new_messages
9496
)
@@ -144,56 +146,87 @@ async def handle(self, context: Context):
144146

145147
async def process_tool_calls(self, context, new_messages):
146148
try:
147-
tries = 5
148-
tool_calls = []
149-
while tries > 0:
150-
try:
151-
# Assuming the context has an instantiated client according to the selected provider
152-
completion = context.client.route(
153-
model=context.client.parser_model,
154-
messages=new_messages,
155-
temperature=0,
156-
max_tokens=1024,
157-
top_p=1,
158-
stream=False,
159-
)
160-
161-
response = completion.choices[0].message.content
162-
if "```json" in response:
163-
response = response.split("```json")[1].split("```")[0]
164-
165-
try:
166-
tool_response = json.loads(response)
167-
if isinstance(tool_response, list):
168-
tool_response = {"tool_calls": tool_response}
169-
except json.JSONDecodeError as e:
170-
print(
171-
f"Error parsing the tool response: {e}, tries left: {tries}"
172-
)
173-
new_messages.append(
174-
{
175-
"role": "user",
176-
"content": f"Error: {e}.\n\n{CLEAN_UP_MESSAGE}",
177-
}
149+
150+
evaluation_cycles_count = EVALUATION_CYCLES_COUNT
151+
try:
152+
# Assuming the context has an instantiated client according to the selected provider
153+
cnadidate_responses =[]
154+
result_is_confirmed = False
155+
for generation in range(evaluation_cycles_count):
156+
tries = PARSE_ERROR_TRIES
157+
tool_calls = []
158+
while tries > 0:
159+
completion = context.client.route(
160+
model=context.client.parser_model,
161+
messages=new_messages,
162+
temperature=0,
163+
max_tokens=512,
164+
top_p=1,
165+
stream=False,
166+
# response_format = {"type": "json_object"}
178167
)
179-
tries -= 1
180-
continue
181168

169+
response = completion.choices[0].message.content
170+
response = response.replace("\_", "_")
171+
if TOOLS_OPEN_TOKEN in response:
172+
response = response.split(TOOLS_OPEN_TOKEN)[1].split(TOOLS_CLOSE_TOKEN)[0]
173+
if "```json" in response:
174+
response = response.split("```json")[1].split("```")[0]
175+
176+
try:
177+
tool_response = json.loads(response)
178+
if isinstance(tool_response, list):
179+
tool_response = {"tool_calls": tool_response}
180+
break
181+
except json.JSONDecodeError as e:
182+
print(
183+
f"Error parsing the tool response: {e}, tries left: {tries}"
184+
)
185+
new_messages.append(
186+
{
187+
"role": "user",
188+
"content": f"Error: {e}.\n\n{CLEAN_UP_MESSAGE}",
189+
}
190+
)
191+
tries -= 1
192+
continue
193+
cnadidate_responses.append(tool_response)
194+
195+
# Go through candidate and see if all detected tools count is 2 then break
196+
tool_calls_count = {}
197+
for tool_response in cnadidate_responses:
198+
for func in tool_response.get("tool_calls", []):
199+
tool_calls_count[func["name"]] = tool_calls_count.get(func["name"], 0) + 1
200+
201+
if all([v == 2 for v in tool_calls_count.values()]):
202+
result_is_confirmed = True
203+
break
204+
205+
# Go through candiudtae and count the number of each tolls is called
206+
tool_calls_count = {}
207+
for tool_response in cnadidate_responses:
182208
for func in tool_response.get("tool_calls", []):
183-
tool_calls.append(
184-
{
185-
"id": f"call_{func['name']}_{str(uuid.uuid4())}",
186-
"type": "function",
187-
"function": {
188-
"name": func["name"],
189-
"arguments": json.dumps(func["arguments"]),
190-
},
191-
}
192-
)
209+
tool_calls_count[func["name"]] = tool_calls_count.get(func["name"], 0) + 1
193210

194-
break
195-
except Exception as e:
196-
raise e
211+
pickup_threshold = math.floor(len(cnadidate_responses) * 0.7) if not result_is_confirmed else 0
212+
# Select any tools with frq more than 2
213+
tool_calls = []
214+
for tool_response in cnadidate_responses:
215+
for func in tool_response.get("tool_calls", []):
216+
if tool_calls_count[func["name"]] > pickup_threshold:
217+
tool_calls.append(
218+
{
219+
"id": f"call_{func['name']}_{str(uuid.uuid4())}",
220+
"type": "function",
221+
"function": {
222+
"name": func["name"],
223+
"arguments": json.dumps(func["arguments"]),
224+
},
225+
}
226+
)
227+
228+
except Exception as e:
229+
raise e
197230

198231
if tries == 0:
199232
tool_calls = []
@@ -215,23 +248,20 @@ async def handle(self, context: Context):
215248
message["role"] = "user"
216249
message["content"] = get_func_result_guide(message["content"])
217250

218-
messages[-1]["role"] = "user"
219-
# Assuming get_func_result_guide is a function that formats the tool response
220-
messages[-1]["content"] = get_func_result_guide(messages[-1]["content"])
221-
222251
try:
252+
params = {
253+
'temperature' : 0.5,
254+
'max_tokens' : 1024,
255+
}
256+
params = {**params, **context.params}
257+
223258
completion = context.client.route(
224259
messages=messages,
225-
**context.client.clean_params(context.params),
260+
**context.client.clean_params(params),
226261
)
227262
context.response = completion.model_dump()
228263
return JSONResponse(content=context.response, status_code=200)
229264
except Exception as e:
230-
# Log the exception or handle it as needed
231-
# context.response = {
232-
# "error": "An error occurred processing the tool response"
233-
# }
234-
# return JSONResponse(content=context.response, status_code=500)
235265
raise e
236266

237267
return await super().handle(context)

0 commit comments

Comments
 (0)