Skip to content

Commit 51eced0

Browse files
committed
fix: OpenAI compatibility with Llama and Gemini
1 parent 5d5ea3e commit 51eced0

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

application/agents/llm_handler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True
7272
while True:
7373
tool_calls = {}
7474
for chunk in resp:
75-
if isinstance(chunk, str):
75+
if isinstance(chunk, str) and len(chunk) > 0:
7676
return
77-
else:
77+
elif hasattr(chunk, "delta"):
7878
chunk_delta = chunk.delta
7979

8080
if (
@@ -113,6 +113,8 @@ def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True
113113
tool_response, call_id = agent._execute_tool_action(
114114
tools_dict, call
115115
)
116+
if isinstance(call["function"]["arguments"], str):
117+
call["function"]["arguments"] = json.loads(call["function"]["arguments"])
116118

117119
function_call_dict = {
118120
"function_call": {
@@ -156,6 +158,8 @@ def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True
156158
and chunk.finish_reason == "stop"
157159
):
158160
return
161+
elif isinstance(chunk, str) and len(chunk) == 0:
162+
continue
159163

160164
resp = agent.llm.gen_stream(
161165
model=agent.gpt_model, messages=messages, tools=agent.tools
Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,132 @@
1-
from application.llm.base import BaseLLM
21
import json
3-
import requests
2+
import sys
3+
4+
from application.core.settings import settings
5+
from application.llm.base import BaseLLM
46

57

68
class DocsGPTAPILLM(BaseLLM):
    """LLM provider backed by the public hosted DocsGPT endpoint.

    Uses the ``openai`` client against the OpenAI-compatible gateway at
    ``https://oai.arc53.com`` with a shared public key. Internal message
    dicts (which may carry Gemini-style ``function_call`` /
    ``function_response`` content parts) are normalised into the OpenAI
    chat-completions schema before every request.
    """

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        # Imported lazily so this module can be loaded even when the
        # optional ``openai`` package is not installed.
        from openai import OpenAI

        super().__init__(*args, **kwargs)
        # NOTE(review): "sk-docsgpt-public" is the intentionally public key
        # for the hosted gateway, not a leaked secret.
        self.client = OpenAI(
            api_key="sk-docsgpt-public", base_url="https://oai.arc53.com"
        )
        self.user_api_key = user_api_key
        self.api_key = api_key

    def _clean_messages_openai(self, messages):
        """Convert internal message dicts to the OpenAI chat format.

        Handles three shapes inside list-valued ``content``: plain
        ``text`` parts, ``function_call`` items (mapped to an assistant
        message carrying ``tool_calls``) and ``function_response`` items
        (mapped to a ``tool`` role message). The ``model`` role is renamed
        to ``assistant``. Messages missing a role or content are skipped.

        Raises:
            ValueError: on an unrecognised content dict or content type.
        """
        cleaned_messages = []
        for message in messages:
            role = message.get("role")
            content = message.get("content")

            # Gemini-style transcripts say "model" where OpenAI says "assistant".
            if role == "model":
                role = "assistant"

            if role and content is not None:
                if isinstance(content, str):
                    cleaned_messages.append({"role": role, "content": content})
                elif isinstance(content, list):
                    for item in content:
                        if "text" in item:
                            cleaned_messages.append(
                                {"role": role, "content": item["text"]}
                            )
                        elif "function_call" in item:
                            # A prior tool invocation made by the assistant.
                            tool_call = {
                                "id": item["function_call"]["call_id"],
                                "type": "function",
                                "function": {
                                    "name": item["function_call"]["name"],
                                    # OpenAI requires arguments as a JSON string.
                                    "arguments": json.dumps(
                                        item["function_call"]["args"]
                                    ),
                                },
                            }
                            cleaned_messages.append(
                                {
                                    "role": "assistant",
                                    "content": None,
                                    "tool_calls": [tool_call],
                                }
                            )
                        elif "function_response" in item:
                            # The tool's result, echoed back to the model.
                            cleaned_messages.append(
                                {
                                    "role": "tool",
                                    "tool_call_id": item["function_response"][
                                        "call_id"
                                    ],
                                    "content": json.dumps(
                                        item["function_response"]["response"][
                                            "result"
                                        ]
                                    ),
                                }
                            )
                        else:
                            raise ValueError(
                                f"Unexpected content dictionary format: {item}"
                            )
                else:
                    raise ValueError(f"Unexpected content type: {type(content)}")

        return cleaned_messages

    def _raw_gen(
        self,
        baseself,
        model,
        messages,
        stream=False,
        tools=None,
        engine=settings.AZURE_DEPLOYMENT_NAME,  # unused; kept for signature parity
        **kwargs,
    ):
        """Run one completion request.

        Returns the raw first choice when ``tools`` are supplied (so the
        caller can inspect tool calls), otherwise just the message text.
        The ``model`` argument is ignored: the hosted gateway serves a
        single model named ``"docsgpt"``.
        """
        messages = self._clean_messages_openai(messages)
        # Build the request once instead of duplicating the create() call.
        request = dict(model="docsgpt", messages=messages, stream=stream, **kwargs)
        if tools:
            response = self.client.chat.completions.create(tools=tools, **request)
            return response.choices[0]
        response = self.client.chat.completions.create(**request)
        return response.choices[0].message.content

    def _raw_gen_stream(
        self,
        baseself,
        model,
        messages,
        stream=True,
        tools=None,
        engine=settings.AZURE_DEPLOYMENT_NAME,  # unused; kept for signature parity
        **kwargs,
    ):
        """Stream a completion, yielding text deltas as they arrive.

        Chunks with no choices are skipped (some backends emit keep-alive
        or usage-only chunks); chunks whose delta carries no text (e.g.
        tool-call deltas) are yielded as the raw choice object for the
        agent layer to interpret.
        """
        messages = self._clean_messages_openai(messages)
        request = dict(model="docsgpt", messages=messages, stream=stream, **kwargs)
        if tools:
            request["tools"] = tools
        response = self.client.chat.completions.create(**request)

        for chunk in response:
            if not chunk.choices:
                continue
            delta_text = chunk.choices[0].delta.content
            if delta_text:  # non-None and non-empty
                yield delta_text
            else:
                yield chunk.choices[0]

    def _supports_tools(self):
        """This provider supports OpenAI-style tool calling."""
        return True

application/llm/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ def _raw_gen_stream(
125125
)
126126

127127
for line in response:
128-
if line.choices[0].delta.content is not None:
128+
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
129129
yield line.choices[0].delta.content
130-
else:
130+
elif len(line.choices) > 0:
131131
yield line.choices[0]
132132

133133
def _supports_tools(self):

0 commit comments

Comments
 (0)