Auto PR from main to live #23

Merged
merged 1 commit into live from main on Mar 19, 2024
4 changes: 3 additions & 1 deletion app/config.py
@@ -1 +1,3 @@
# To be developed
EVALUATION_CYCLES_COUNT=3
PARSE_ERROR_TRIES = 5
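
For context, a minimal sketch of how these two new settings drive the parsing loop introduced in app/libs/tools_handler.py below: PARSE_ERROR_TRIES bounds the JSON-parse retries per generation, and EVALUATION_CYCLES_COUNT sets how many candidate generations are collected for the vote. The collect_candidates helper and the generate callable are illustrative, not part of this PR.

```python
import json

from config import EVALUATION_CYCLES_COUNT, PARSE_ERROR_TRIES


def collect_candidates(generate):
    """Hypothetical helper: generate() is any callable returning a raw model response string."""
    candidates = []
    for _ in range(EVALUATION_CYCLES_COUNT):   # one candidate per evaluation cycle
        for _ in range(PARSE_ERROR_TRIES):     # retry while the JSON cannot be parsed
            try:
                candidates.append(json.loads(generate()))
                break
            except json.JSONDecodeError:
                continue
    return candidates
```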
148 changes: 89 additions & 59 deletions app/libs/tools_handler.py
@@ -1,12 +1,13 @@
import json
import math
import uuid
from fastapi.responses import JSONResponse
from prompts import *
from utils import get_tool_call_response, describe
from .base_handler import Handler, Context
from .context import Context
from utils import get_tool_call_response, create_logger, describe

from config import PARSE_ERROR_TRIES, EVALUATION_CYCLES_COUNT
missed_tool_logger = create_logger(
"chain.missed_tools", ".logs/empty_tool_tool_response.log"
)
@@ -84,11 +85,12 @@ async def handle(self, context: Context):
new_messages = [
{"role": "system", "content": system_message},
{
"role": "system",
"content": f"# Conversation History:\n{messages_flatten}\n\n# Available Tools: \n{tools_json}\n\n{suffix}",
"role": "user",
"content": f"# Conversation History:\n{messages_flatten}\n\n# Available Tools: \n{tools_json}\n\n{suffix}\n{context.messages[-1]['content']}",
},
]


completion, tool_calls = await self.process_tool_calls(
context, new_messages
)
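
The hunk above switches the second prompt message from a system turn to a user turn and appends the latest user message to its content. Schematically, the payload sent to the parser model now looks roughly like this (placeholder strings stand in for the real prompt pieces):

```python
# Schematic only: placeholders stand in for the flattened history, tools JSON, and suffix.
system_message = "<parser system prompt>"

new_messages = [
    {"role": "system", "content": system_message},
    {
        "role": "user",
        "content": (
            "# Conversation History:\n<flattened messages>\n\n"
            "# Available Tools: \n<tools json>\n\n"
            "<suffix>\n"
            "<latest user message content>"
        ),
    },
]
```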
@@ -144,56 +146,87 @@ async def handle(self, context: Context):

async def process_tool_calls(self, context, new_messages):
try:
tries = 5
tool_calls = []
while tries > 0:
try:
# Assuming the context has an instantiated client according to the selected provider
completion = context.client.route(
model=context.client.parser_model,
messages=new_messages,
temperature=0,
max_tokens=1024,
top_p=1,
stream=False,
)

response = completion.choices[0].message.content
if "```json" in response:
response = response.split("```json")[1].split("```")[0]

try:
tool_response = json.loads(response)
if isinstance(tool_response, list):
tool_response = {"tool_calls": tool_response}
except json.JSONDecodeError as e:
print(
f"Error parsing the tool response: {e}, tries left: {tries}"
)
new_messages.append(
{
"role": "user",
"content": f"Error: {e}.\n\n{CLEAN_UP_MESSAGE}",
}

evaluation_cycles_count = EVALUATION_CYCLES_COUNT
try:
# Assuming the context has an instantiated client according to the selected provider
candidate_responses = []
result_is_confirmed = False
for generation in range(evaluation_cycles_count):
tries = PARSE_ERROR_TRIES
tool_calls = []
while tries > 0:
completion = context.client.route(
model=context.client.parser_model,
messages=new_messages,
temperature=0,
max_tokens=512,
top_p=1,
stream=False,
# response_format = {"type": "json_object"}
)
tries -= 1
continue

response = completion.choices[0].message.content
response = response.replace("\_", "_")
if TOOLS_OPEN_TOKEN in response:
response = response.split(TOOLS_OPEN_TOKEN)[1].split(TOOLS_CLOSE_TOKEN)[0]
if "```json" in response:
response = response.split("```json")[1].split("```")[0]

try:
tool_response = json.loads(response)
if isinstance(tool_response, list):
tool_response = {"tool_calls": tool_response}
break
except json.JSONDecodeError as e:
print(
f"Error parsing the tool response: {e}, tries left: {tries}"
)
new_messages.append(
{
"role": "user",
"content": f"Error: {e}.\n\n{CLEAN_UP_MESSAGE}",
}
)
tries -= 1
continue
candidate_responses.append(tool_response)

# Check the candidates so far: if every detected tool has been seen exactly twice, treat the result as confirmed and stop early
tool_calls_count = {}
for tool_response in candidate_responses:
for func in tool_response.get("tool_calls", []):
tool_calls_count[func["name"]] = tool_calls_count.get(func["name"], 0) + 1

if all([v == 2 for v in tool_calls_count.values()]):
result_is_confirmed = True
break

# Go through the candidates and count how many times each tool is called
tool_calls_count = {}
for tool_response in candidate_responses:
for func in tool_response.get("tool_calls", []):
tool_calls.append(
{
"id": f"call_{func['name']}_{str(uuid.uuid4())}",
"type": "function",
"function": {
"name": func["name"],
"arguments": json.dumps(func["arguments"]),
},
}
)
tool_calls_count[func["name"]] = tool_calls_count.get(func["name"], 0) + 1

break
except Exception as e:
raise e
pickup_threshold = math.floor(len(candidate_responses) * 0.7) if not result_is_confirmed else 0
# Select only the tools whose call frequency exceeds the pickup threshold
tool_calls = []
for tool_response in candidate_responses:
for func in tool_response.get("tool_calls", []):
if tool_calls_count[func["name"]] > pickup_threshold:
tool_calls.append(
{
"id": f"call_{func['name']}_{str(uuid.uuid4())}",
"type": "function",
"function": {
"name": func["name"],
"arguments": json.dumps(func["arguments"]),
},
}
)

except Exception as e:
raise e

if tries == 0:
tool_calls = []
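
Taken together, the rewritten process_tool_calls amounts to a self-consistency vote over several candidate generations. A standalone sketch of the selection step, mirroring the counting logic and the 0.7 pickup threshold from the hunk above (the vote_tool_calls name and the sample data are illustrative):

```python
import math


def vote_tool_calls(candidate_responses, result_is_confirmed=False):
    """Count how often each tool appears across candidates, then keep the
    occurrences whose tool cleared the pickup threshold (as in the diff,
    one entry is kept per surviving occurrence)."""
    counts = {}
    for response in candidate_responses:
        for func in response.get("tool_calls", []):
            counts[func["name"]] = counts.get(func["name"], 0) + 1

    # If the early-exit check already confirmed agreement, the threshold drops to 0.
    threshold = 0 if result_is_confirmed else math.floor(len(candidate_responses) * 0.7)
    return [
        func
        for response in candidate_responses
        for func in response.get("tool_calls", [])
        if counts[func["name"]] > threshold
    ]


# All three candidates propose get_weather; only one also proposes send_email.
candidates = [
    {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]},
    {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}},
                    {"name": "send_email", "arguments": {}}]},
    {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]},
]
print([f["name"] for f in vote_tool_calls(candidates)])
# ['get_weather', 'get_weather', 'get_weather']
```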
@@ -215,23 +248,20 @@ async def handle(self, context: Context):
message["role"] = "user"
message["content"] = get_func_result_guide(message["content"])

messages[-1]["role"] = "user"
# Assuming get_func_result_guide is a function that formats the tool response
messages[-1]["content"] = get_func_result_guide(messages[-1]["content"])

try:
params = {
'temperature' : 0.5,
'max_tokens' : 1024,
}
params = {**params, **context.params}

completion = context.client.route(
messages=messages,
**context.client.clean_params(context.params),
**context.client.clean_params(params),
)
context.response = completion.model_dump()
return JSONResponse(content=context.response, status_code=200)
except Exception as e:
# Log the exception or handle it as needed
# context.response = {
# "error": "An error occurred processing the tool response"
# }
# return JSONResponse(content=context.response, status_code=500)
raise e

return await super().handle(context)
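
The last hunk also introduces default generation parameters and merges the caller's context.params over them before routing the follow-up completion. A quick illustration of that merge precedence (the caller_params value is made up):

```python
# Defaults from the hunk above; caller-supplied values win on conflict.
defaults = {"temperature": 0.5, "max_tokens": 1024}
caller_params = {"temperature": 0.2}  # illustrative stand-in for context.params

params = {**defaults, **caller_params}
print(params)  # {'temperature': 0.2, 'max_tokens': 1024}
```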