Skip to content

Commit 18f35c9

Browse files
committed
Updates:
- chains.py: Supporting "tool_choice". - Update cookbook examples. - Improve promopts, for "force" mode.
1 parent d828608 commit 18f35c9

5 files changed

+201
-11
lines changed

app/libs/chains.py

+27-8
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from importlib import import_module
44
import json
55
import uuid
6+
import traceback
67
from fastapi import Request
78
from fastapi.responses import JSONResponse
89
from providers import BaseProvider
9-
from prompts import SYSTEM_MESSAGE, SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide
10+
from prompts import SYSTEM_MESSAGE, ENFORCED_SYSTAME_MESSAE, SUFFIX, FORCE_CALL_SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide, get_forced_tool_suffix
1011
from providers import GroqProvider
1112
import importlib
1213
from utils import get_tool_call_response, create_logger
@@ -19,8 +20,10 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]):
1920
self.provider = provider
2021
self.body = body
2122
self.response = None
23+
2224
# extract all keys from body except messages and tools and set in params
2325
self.params = {k: v for k, v in body.items() if k not in ["messages", "tools"]}
26+
2427
# self.no_tool_behaviour = self.params.get("no_tool_behaviour", "return")
2528
self.no_tool_behaviour = self.params.get("no_tool_behaviour", "forward")
2629
self.params.pop("no_tool_behaviour", None)
@@ -50,8 +53,6 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]):
5053
bt['extra'] = self.params.get("extra", {})
5154
self.params.pop("extra", None)
5255

53-
54-
5556
self.client : BaseProvider = None
5657

5758
@property
@@ -60,7 +61,7 @@ def last_message(self):
6061

6162
@property
6263
def is_tool_call(self):
63-
return bool(self.last_message["role"] == "user" and self.tools)
64+
return bool(self.last_message["role"] == "user" and self.tools and self.params.get("tool_choice", "none") != "none")
6465

6566
@property
6667
def is_tool_response(self):
@@ -88,6 +89,7 @@ async def handle(self, context: Context):
8889
return await self._next_handler.handle(context)
8990
except Exception as e:
9091
_exception_handler: "Handler" = ExceptionHandler()
92+
# Extract the stack trace and log the exception
9193
return await _exception_handler.handle(context, e)
9294

9395

@@ -130,19 +132,35 @@ class ToolExtractionHandler(Handler):
130132
async def handle(self, context: Context):
131133
body = context.body
132134
if context.is_tool_call:
135+
136+
# Prepare the messages and tools for the tool extraction
133137
messages = [
134138
f"{m['role'].title()}: {m['content']}"
135139
for m in context.messages
136140
if m["role"] != "system"
137141
]
138-
139142
tools_json = json.dumps([t["function"] for t in context.tools], indent=4)
140143

144+
# Process the tool_choice
145+
tool_choice = context.params.get("tool_choice", "auto")
146+
forced_mode = False
147+
if type(tool_choice) == dict and tool_choice.get("type", None) == "function":
148+
tool_choice = tool_choice["function"].get("name", None)
149+
if not tool_choice:
150+
raise ValueError("Invalid tool choice. 'tool_choice' is set to a dictionary with 'type' as 'function', but 'function' does not have a 'name' key.")
151+
forced_mode = True
152+
153+
# Regenerate the string tool_json and keep only the forced tool
154+
tools_json = json.dumps([t["function"] for t in context.tools if t["function"]["name"] == tool_choice], indent=4)
155+
156+
system_message = SYSTEM_MESSAGE if not forced_mode else ENFORCED_SYSTAME_MESSAE
157+
suffix = SUFFIX if not forced_mode else get_forced_tool_suffix(tool_choice)
158+
141159
new_messages = [
142-
{"role": "system", "content": SYSTEM_MESSAGE},
160+
{"role": "system", "content": system_message},
143161
{
144162
"role": "system",
145-
"content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{SUFFIX}",
163+
"content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{suffix}",
146164
},
147165
]
148166

@@ -309,4 +327,5 @@ async def handle(self, context: Context):
309327
class ExceptionHandler(Handler):
310328
async def handle(self, context: Context, exception: Exception):
311329
print(f"Error processing the request: {exception}")
312-
return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500)
330+
print(traceback.format_exc())
331+
return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500)

app/prompts.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,44 @@
3333
3434
** If no tools are required, then return an empty list for "tool_calls". **
3535
36-
**Wrap the JSON response between ```json and ```**.
36+
**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**.
3737
3838
**Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response, do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user.**"""
3939

40+
41+
ENFORCED_SYSTAME_MESSAE = """A history of conversations between an AI assistant and the user, plus the last user's message, is given to you.
42+
43+
You have access to a specific tool that the AI assistant must use to provide a proper answer. The tool is a function that requires a set of parameters, which are provided in a JSON schema to explain what parameters the tool needs. Your task is to extract the values for these parameters from the user's last message and the conversation history.
44+
45+
Your job is to closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response.
46+
47+
You should think step by step, provide your reasoning for your response, then add the JSON response at the end following the below schema:
48+
49+
50+
{
51+
"tool_calls": [{
52+
"name": "function_name",
53+
"arguments": {
54+
"arg1": "value1",
55+
"arg2": "value2",
56+
...
57+
}]
58+
}
59+
}
60+
61+
62+
**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**.
63+
64+
Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response. Do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user."""
65+
4066
CLEAN_UP_MESSAGE = "When I tried to extract the content between ```json and ``` and parse the content to valid JSON object, I faced with the abovr error. Remember, you are supposed to wrap the schema between ```json and ```, and do this only one time. First find out what went wrong, that I couldn't extract the JSON between ```json and ```, and also faced error when trying to parse it, then regenerate the your last message and fix the issue."
67+
4168
SUFFIX = """Think step by step and justify your response. Make sure to not miss in case to answer user query we need multiple tools, in that case detect all that we need, then generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME."""
4269

70+
FORCE_CALL_SUFFIX = """For this task, you HAVE to choose the tool (function) {tool_name}, and ignore other rools. Therefore think step by step and justify your response, then closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response. Finally generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME."""
71+
72+
def get_forced_tool_suffix(tool_name : str) -> str:
73+
return FORCE_CALL_SUFFIX.format(tool_name=tool_name)
4374

4475
def get_func_result_guide(function_call_result : str) -> str:
4576
return f"SYSTEM MESSAGE: \n```json\n{function_call_result}\n```\n\nThe above is the result after functions are called. Use the result to answer the user's last question.\n\n"
+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
2+
from duckduckgo_search import DDGS
3+
import requests, os
4+
import json
5+
6+
api_key=os.environ["GROQ_API_KEY"]
7+
header = {
8+
"Authorization": f"Bearer {api_key}",
9+
"Content-Type": "application/json"
10+
}
11+
proxy_url = "https://groqcall.ai/proxy/groq/v1/chat/completions"
12+
13+
# or "http://localhost:8000/proxy/groq/v1/chat/completions" if running locally
14+
# proxy_url = "http://localhost:8000/proxy/groq/v1/chat/completions"
15+
16+
17+
def duckduckgo_search(query, max_results=None):
18+
"""
19+
Use this function to search DuckDuckGo for a query.
20+
"""
21+
with DDGS() as ddgs:
22+
return [r for r in ddgs.text(query, safesearch='off', max_results=max_results)]
23+
24+
def duckduckgo_news(query, max_results=None):
25+
"""
26+
Use this function to get the latest news from DuckDuckGo.
27+
"""
28+
with DDGS() as ddgs:
29+
return [r for r in ddgs.news(query, safesearch='off', max_results=max_results)]
30+
31+
function_map = {
32+
"duckduckgo_search": duckduckgo_search,
33+
"duckduckgo_news": duckduckgo_news,
34+
}
35+
36+
request = {
37+
"messages": [
38+
{
39+
"role": "system",
40+
"content": "YOU MUST FOLLOW THESE INSTRUCTIONS CAREFULLY.\n<instructions>\n1. Use markdown to format your answers.\n</instructions>"
41+
},
42+
{
43+
"role": "user",
44+
"content": "Whats happening in France? Summarize top stories with sources, very short and concise."
45+
}
46+
],
47+
"model": "mixtral-8x7b-32768",
48+
# "tool_choice": "auto",
49+
# "tool_choice": "none",
50+
"tool_choice": {"type": "function", "function": {"name": "duckduckgo_search"}},
51+
"tools": [
52+
{
53+
"type": "function",
54+
"function": {
55+
"name": "duckduckgo_search",
56+
"description": "Use this function to search DuckDuckGo for a query.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The result from DuckDuckGo.",
57+
"parameters": {
58+
"type": "object",
59+
"properties": {
60+
"query": {
61+
"type": "string"
62+
},
63+
"max_results": {
64+
"type": [
65+
"number",
66+
"null"
67+
]
68+
}
69+
}
70+
}
71+
}
72+
},
73+
{
74+
"type": "function",
75+
"function": {
76+
"name": "duckduckgo_news",
77+
"description": "Use this function to get the latest news from DuckDuckGo.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The latest news from DuckDuckGo.",
78+
"parameters": {
79+
"type": "object",
80+
"properties": {
81+
"query": {
82+
"type": "string"
83+
},
84+
"max_results": {
85+
"type": [
86+
"number",
87+
"null"
88+
]
89+
}
90+
}
91+
}
92+
}
93+
}
94+
]
95+
}
96+
97+
response = requests.post(
98+
proxy_url,
99+
headers=header,
100+
json=request
101+
)
102+
# Check if the request was successful
103+
if response.status_code == 200:
104+
# Process the response data (if needed)
105+
res = response.json()
106+
message = res['choices'][0]['message']
107+
tools_response_messages = []
108+
if not message['content'] and 'tool_calls' in message:
109+
for tool_call in message['tool_calls']:
110+
tool_name = tool_call['function']['name']
111+
tool_args = tool_call['function']['arguments']
112+
tool_args = json.loads(tool_args)
113+
if tool_name not in function_map:
114+
print(f"Error: {tool_name} is not a valid function name.")
115+
continue
116+
tool_func = function_map[tool_name]
117+
tool_response = tool_func(**tool_args)
118+
tools_response_messages.append({
119+
"role": "tool", "content": json.dumps(tool_response)
120+
})
121+
122+
if tools_response_messages:
123+
request['messages'] += tools_response_messages
124+
response = requests.post(
125+
proxy_url,
126+
headers=header,
127+
json=request
128+
)
129+
if response.status_code == 200:
130+
res = response.json()
131+
print(res['choices'][0]['message']['content'])
132+
else:
133+
print("Error:", response.status_code, response.text)
134+
else:
135+
print(message['content'])
136+
else:
137+
print("Error:", response.status_code, response.text)

cookbook/function_call_with_schema.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11

22
from duckduckgo_search import DDGS
33
import requests, os
4-
api_key=os.environ["GROQ_API_KEY"]
54
import json
5+
6+
api_key=os.environ["GROQ_API_KEY"]
67
header = {
78
"Authorization": f"Bearer {api_key}",
89
"Content-Type": "application/json"

cookbook/function_call_without_schema.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import requests
2+
import json
3+
import os
24

3-
api_key = "YOUR_GROQ_API_KEY"
5+
api_key=os.environ["GROQ_API_KEY"],
46
header = {
57
"Authorization": f"Bearer {api_key}",
68
"Content-Type": "application/json"

0 commit comments

Comments
 (0)