-
Notifications
You must be signed in to change notification settings - Fork 8.8k
Expand file tree
/
Copy pathopenai_api.py
More file actions
331 lines (291 loc) · 13.8 KB
/
openai_api.py
File metadata and controls
331 lines (291 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai_api.py
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from __future__ import annotations
import json
import re
from typing import Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
count_output_tokens,
get_max_completion_tokens,
)
@register_provider(
    [
        # Every LLM type listed here is registered with this class as its provider.
        LLMType.OPENAI,
        LLMType.FIREWORKS,
        LLMType.OPEN_LLM,
        LLMType.MOONSHOT,
        LLMType.MISTRAL,
        LLMType.YI,
        LLMType.OPEN_ROUTER,
        LLMType.DEEPSEEK,
        LLMType.SILICONFLOW,
        # NOTE(review): OPEN_ROUTER and OPENROUTER are both listed; presumably two
        # distinct enum members kept for compatibility - confirm against LLMType.
        LLMType.OPENROUTER,
        LLMType.LLAMA_API,
        LLMType.MINIMAX,
    ]
)
class OpenAILLM(BaseLLM):
    """Async LLM provider built on ``openai.AsyncOpenAI``.

    Registered for OpenAI itself and for the other LLM types in the decorator
    above, all of which are driven through the same chat-completions client.

    Check https://platform.openai.com/examples for examples
    """
def __init__(self, config: LLMConfig) -> None:
    """Store the configuration and build the async OpenAI client.

    Args:
        config: LLM settings; kept on ``self.config`` and read by ``_init_client``.
    """
    self.config = config
    # Must run after ``self.config`` is assigned: ``_init_client`` reads it.
    self._init_client()
    # While False, ``_get_max_tokens`` returns ``config.max_token`` unchanged.
    self.auto_max_tokens = False
    # No cost manager is attached at construction time.
    self.cost_manager: Optional[CostManager] = None
def _init_client(self):
    """Derive model settings from the config and create the ``AsyncOpenAI`` client.

    See https://github.com/openai/openai-python#async-usage
    """
    cfg = self.config
    # ``self.model`` is read later by _cons_kwargs, gen_image and count_tokens.
    self.model = cfg.model
    # Fall back to the model name when no dedicated pricing plan is configured.
    self.pricing_plan = cfg.pricing_plan or self.model
    self.aclient = AsyncOpenAI(**self._make_client_kwargs())
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs
def _get_proxy_params(self) -> dict:
params = {}
if self.config.proxy:
params = {"proxy": self.config.proxy}
if self.config.base_url:
params["base_url"] = self.config.base_url
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
    """Stream a chat completion and return the concatenated reply text.

    Each content delta is forwarded to ``log_llm_stream`` as it arrives. Reasoning
    deltas (``reasoning_content``, e.g. DeepSeek) are collected separately and stored
    on ``self.reasoning_content``. Usage reported by the service is passed to
    ``self._update_costs``; when the stream reports none, it is estimated with
    ``self._calc_usage``.

    Args:
        messages: Chat messages to send.
        timeout: Request timeout, resolved through ``self.get_timeout``.

    Returns:
        The full reply assembled from all content deltas.
    """

    def to_usage(raw):
        # Services report usage either as a plain dict or as an already-built object.
        return CompletionUsage(**raw) if isinstance(raw, dict) else raw

    response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
        **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
    )
    usage = None
    collected_messages = []
    collected_reasoning_messages = []
    has_finished = False
    async for chunk in response:
        chunk_usage = getattr(chunk, "usage", None)
        if not chunk.choices:
            # A chunk without choices carries no text, but it may be the trailing
            # usage-only chunk; keep its usage instead of silently dropping it.
            if chunk_usage:
                usage = to_usage(chunk_usage)
            continue

        choice0 = chunk.choices[0]
        choice_delta = choice0.delta
        if getattr(choice_delta, "reasoning_content", None):
            collected_reasoning_messages.append(choice_delta.reasoning_content)  # for deepseek
            continue

        chunk_message = choice_delta.content or ""  # extract the message
        finish_reason = getattr(choice0, "finish_reason", None)
        log_llm_stream(chunk_message)
        collected_messages.append(chunk_message)

        if has_finished and chunk_usage:
            # for oneapi, a usage chunk follows the chunk that carried finish_reason
            usage = to_usage(chunk_usage)
        if finish_reason:
            if chunk_usage:
                # Some services have usage as an attribute of the chunk, such as Fireworks
                usage = to_usage(chunk_usage)
            else:
                # The usage of some services is an attribute of chunk.choices[0], such as Moonshot.
                # A missing or empty value is ignored so the fallback below can estimate it.
                choice_usage = getattr(choice0, "usage", None)
                if choice_usage:
                    usage = to_usage(choice_usage)
            has_finished = True

    log_llm_stream("\n")
    full_reply_content = "".join(collected_messages)
    if collected_reasoning_messages:
        self.reasoning_content = "".join(collected_reasoning_messages)
    if not usage:
        # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
        usage = self._calc_usage(messages, full_reply_content)

    self._update_costs(usage)
    return full_reply_content
def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
    """Build the keyword arguments for ``chat.completions.create``.

    Args:
        messages: Chat messages to send.
        timeout: Request timeout, resolved through ``self.get_timeout``.
        **extra_kwargs: Extra request parameters; applied last, so they override
            any key computed here.

    Returns:
        The request keyword arguments as a dict.
    """
    # "n" is not sent because some services (e.g. mistral) do not provide it, and
    # "stop" is not sent because it defaults to None and gpt4-v can't have it.
    request = dict(
        messages=messages,
        max_tokens=self._get_max_tokens(messages),
        temperature=self.config.temperature,
        model=self.model,
        timeout=self.get_timeout(timeout),
    )

    if "o1-" in self.model:
        # compatible to openai o1-series: fixed temperature and no max_tokens
        request["temperature"] = 1
        del request["max_tokens"]

    is_minimax = self.config.api_type == LLMType.MINIMAX
    if is_minimax and request.get("temperature", 0) <= 0:
        # MiniMax requires temperature in (0.0, 1.0]
        request["temperature"] = 0.01

    request.update(extra_kwargs)
    return request
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
    """Send a non-streaming chat completion request and record its usage.

    Args:
        messages: Chat messages to send.
        timeout: Request timeout, resolved through ``self.get_timeout``.

    Returns:
        The ``ChatCompletion`` returned by the client.
    """
    resolved_timeout = self.get_timeout(timeout)
    request = self._cons_kwargs(messages, timeout=resolved_timeout)
    completion: ChatCompletion = await self.aclient.chat.completions.create(**request)
    self._update_costs(completion.usage)
    return completion
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
    """Return the raw ``ChatCompletion`` for ``messages`` (non-streaming).

    Thin public wrapper around ``_achat_completion``.
    """
    resolved_timeout = self.get_timeout(timeout)
    completion = await self._achat_completion(messages, timeout=resolved_timeout)
    return completion
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    after=after_log(logger, logger.level("WARNING").name),
    retry=retry_if_exception_type(APIConnectionError),
    retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
    """Return the reply text for ``messages``.

    When streaming, each token is printed in place as it arrives. The call is
    retried on ``APIConnectionError`` for up to 6 attempts with a random
    exponential wait; once attempts are exhausted ``log_and_reraise`` is invoked.

    Args:
        messages: Chat messages to send.
        stream: Use the streaming endpoint when true.
        timeout: Request timeout.

    Returns:
        The reply text.
    """
    if not stream:
        completion = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
        return self.get_choice_text(completion)
    # The streaming helper resolves the timeout itself.
    return await self._achat_completion_stream(messages, timeout=timeout)
async def _achat_completion_function(
    self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
) -> ChatCompletion:
    """Send a chat completion request that carries extra settings such as tools.

    Args:
        messages: Messages to send; normalised with ``self.format_msg`` first.
        timeout: Request timeout, resolved through ``self.get_timeout``.
        **chat_configs: Extra request parameters forwarded to ``_cons_kwargs``.

    Returns:
        The ``ChatCompletion`` returned by the client; its usage is recorded
        with ``self._update_costs``.
    """
    formatted_messages = self.format_msg(messages)
    request = self._cons_kwargs(
        messages=formatted_messages, timeout=self.get_timeout(timeout), **chat_configs
    )
    completion: ChatCompletion = await self.aclient.chat.completions.create(**request)
    self._update_costs(completion.usage)
    return completion
async def aask_code(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **kwargs) -> dict:
    """Ask for code through the tools (function calling) interface.

    Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create

    Args:
        messages: Chat messages to send.
        timeout: Request timeout; forwarded to the underlying request.
        **kwargs: Extra request parameters. When ``tools`` is absent, a single
            function tool described by ``GENERAL_FUNCTION_SCHEMA`` is used.

    Returns:
        The parsed arguments of the first tool call, as produced by
        ``get_choice_function_arguments``.

    Examples:
        >>> llm = OpenAILLM(config)
        >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
        >>> rsp = await llm.aask_code(msg)
        # -> {'language': 'python', 'code': "print('Hello, World!')"}
    """
    kwargs.setdefault("tools", [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}])
    # Forward the caller's timeout; it used to be accepted but silently dropped.
    rsp = await self._achat_completion_function(messages, timeout=timeout, **kwargs)
    return self.get_choice_function_arguments(rsp)
def _parse_arguments(self, arguments: str) -> dict:
"""parse arguments in openai function call"""
if "language" not in arguments and "code" not in arguments:
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
return {"language": "python", "code": arguments}
# 匹配language
language_pattern = re.compile(r'[\"\']?language[\"\']?\s*:\s*["\']([^"\']+?)["\']', re.DOTALL)
language_match = language_pattern.search(arguments)
language_value = language_match.group(1) if language_match else "python"
# 匹配code
code_pattern = r'(["\'`]{3}|["\'`])([\s\S]*?)\1'
try:
code_value = re.findall(code_pattern, arguments)[-1][-1]
except Exception as e:
logger.error(f"{e}, when re.findall({code_pattern}, {arguments})")
code_value = None
if code_value is None:
raise ValueError(f"Parse code error for {arguments}")
# arguments只有code的情况
return {"language": language_value, "code": code_value}
# @handle_exception
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
    """Required to provide the first function arguments of choice.

    An empty ``tool_calls`` list is handled like a missing one, so the message
    content is used instead of failing on ``tool_calls[0]``.

    :param ChatCompletion rsp: the completion whose first choice is inspected
    :return dict: return the first function arguments of choice, for example,
        {'language': 'python', 'code': "print('Hello, World!')"}
    :raises Exception: when the first choice has neither usable tool-call
        arguments nor text content
    """
    message = rsp.choices[0].message
    # Normalise an empty list to None so both "no tool call" shapes share one path.
    tool_calls = message.tool_calls or None
    if (
        tool_calls is not None
        and tool_calls[0].function is not None
        and tool_calls[0].function.arguments is not None
    ):
        # response is code
        raw_arguments = tool_calls[0].function.arguments
        try:
            return json.loads(raw_arguments, strict=False)
        except json.decoder.JSONDecodeError as e:
            error_msg = f"Got JSONDecodeError for \n{'--'*40} \n{raw_arguments}, {str(e)}"
            logger.error(error_msg)
            # Fall back to regex extraction for arguments that are not valid JSON.
            return self._parse_arguments(raw_arguments)
    elif tool_calls is None and message.content is not None:
        # response is code, fix openai tools_call respond bug:
        # the code appears in the content instead of the tool-call arguments.
        code_formats = "```"
        if message.content.startswith(code_formats) and message.content.endswith(code_formats):
            code = CodeParser.parse_code(text=message.content)
            return {"language": "python", "code": code}
        # response is message
        return {"language": "markdown", "code": self.get_choice_text(rsp)}
    else:
        logger.error(f"Failed to parse \n {rsp}\n")
        raise Exception(f"Failed to parse \n {rsp}\n")
def get_choice_text(self, rsp: ChatCompletion) -> str:
    """Return the message content of the first choice, or "" when there are no choices."""
    if not rsp.choices:
        return ""
    first_choice = rsp.choices[0]
    return first_choice.message.content
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
    """Estimate token usage locally when the service did not report it.

    Args:
        messages: The prompt messages that were sent.
        rsp: The reply text that was received.

    Returns:
        A ``CompletionUsage`` whose counts are all zero when
        ``config.calc_usage`` is off; otherwise the counted prompt and
        completion tokens with ``total_tokens`` set to their sum. A counting
        failure is logged and leaves the affected counts at their current value.
    """
    usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    if not self.config.calc_usage:
        return usage

    try:
        usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
        usage.completion_tokens = count_output_tokens(rsp, self.pricing_plan)
    except Exception as e:
        logger.warning(f"usage calculation failed: {e}")
    # Keep the total consistent with whatever was counted; it used to stay at 0.
    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
    return usage
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token
# FIXME
# https://community.openai.com/t/why-is-gpt-3-5-turbo-1106-max-tokens-limited-to-4096/494973/3
return min(get_max_completion_tokens(messages, self.model, self.config.max_token), 4096)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
    """Moderate content.

    Args:
        content: A single text or a list of texts, sent as the ``input`` of the
            moderation request.

    Returns:
        Whatever the client's ``moderations.create`` call returns.
    """
    moderation = await self.aclient.moderations.create(input=content)
    return moderation
async def atext_to_speech(self, **kwargs):
    """Text to speech: forward ``kwargs`` unchanged to ``audio.speech.create``."""
    speech = await self.aclient.audio.speech.create(**kwargs)
    return speech
async def aspeech_to_text(self, **kwargs):
    """Speech to text: forward ``kwargs`` unchanged to ``audio.transcriptions.create``."""
    transcription = await self.aclient.audio.transcriptions.create(**kwargs)
    return transcription
async def gen_image(
    self,
    prompt: str,
    size: str = "1024x1024",
    quality: str = "standard",
    model: str = None,
    resp_format: str = "url",
) -> list["Image"]:
    """Generate one image for ``prompt`` and decode the result.

    Args:
        prompt: Text description of the image.
        size: Requested image size.
        quality: Requested image quality.
        model: Image model; ``self.model`` is used when this is empty.
        resp_format: Either ``"url"`` or ``"b64_json"``.

    Returns:
        One ``decode_image`` result per item in the response data.
    """
    assert resp_format in ["url", "b64_json"]
    chosen_model = model or self.model
    res = await self.aclient.images.generate(
        model=chosen_model, prompt=prompt, size=size, quality=quality, n=1, response_format=resp_format
    )
    wants_url = resp_format == "url"
    # Each item carries either a URL or a base64 payload, matching the requested format.
    return [decode_image(item.url if wants_url else item.b64_json) for item in res.data]
def count_tokens(self, messages: list[dict]) -> int:
    """Count the tokens of ``messages`` for the current model.

    Falls back to the base-class implementation when ``count_message_tokens``
    fails for this model.
    """
    try:
        return count_message_tokens(messages, self.model)
    except Exception:
        # Catch ordinary errors only: a bare ``except`` would also swallow
        # KeyboardInterrupt, SystemExit and task cancellation.
        return super().count_tokens(messages)