55from openai .types .chat .chat_completion_message_tool_call import ChatCompletionMessageToolCall , Function
66
77from modelscope_agent .llm .llm import LLM
8+ from modelscope_agent .utils .llm_utils import retry
89from modelscope_agent .llm .utils import Message , Tool , ToolCall
910from modelscope_agent .utils .utils import assert_package_exist
1011
1112
12-
1313class OpenAI (LLM ):
14+ input_msg = {'role' , 'content' , 'tool_calls' , 'partial' , 'prefix' }
1415
1516 def __init__ (self , config : DictConfig , base_url : Optional [str ] = None , api_key : Optional [str ] = None ):
1617 super ().__init__ (config )
@@ -23,10 +24,10 @@ def __init__(self, config: DictConfig, base_url: Optional[str] = None, api_key:
2324 api_key = api_key ,
2425 base_url = base_url ,
2526 )
26- exclude_fields = {"model" , "base_url" , "api_key" }
27- self .args : Dict = {k : v for k , v in OmegaConf .to_container (getattr (config , 'generation_config' , {}), resolve = True ).items () if k not in exclude_fields }
27+ self .args : Dict = {k : v for k , v in getattr (config .llm , 'generation_config' , {}).items ()}
2828
29- def generate (self , messages : List [Message ], model : Optional [str ] = None , tools : List [Tool ] = None , ** kwargs ) -> Message | Generator [Message , None , None ]:
29+ @retry (max_attempts = 3 )
30+ def generate (self , messages : List [Message ], tools : List [Tool ] = None , ** kwargs ) -> Message | Generator [Message , None , None ]:
3031 parameters = inspect .signature (self .client .chat .completions .create ).parameters
3132 args = self .args .copy ()
3233 args .update (kwargs )
@@ -45,44 +46,23 @@ def generate(self, messages: List[Message], model: Optional[str] = None, tools:
4546 }
4647 } for tool in tools
4748 ]
48- completion = self ._call_llm (model or self . model , messages , tools , ** args )
49+ completion = self ._call_llm (messages , tools , ** args )
4950
5051 # 考虑到复杂任务可能存在 单次调用llm生成不完整的情况。需要调用continue_gen判断是否应多次调用以获得完整输出
5152 if stream :
5253 return self .stream_continue_generate (messages , completion , tools , ** args )
5354 else :
5455 return self .continue_generate (messages , completion , tools , ** args )
5556
56- def _call_llm (self , model , messages , tools , ** kwargs ):
57+ def _call_llm (self , messages , tools , ** kwargs ):
5758 messages = self .format_input_message (messages )
5859 return self .client .chat .completions .create (
59- model = model ,
60+ model = self . model ,
6061 messages = messages ,
6162 tools = tools ,
6263 ** kwargs
6364 )
6465
65- def _stream_continue_generate (self , messages : List [Message ], new_message , tools : List [Tool ] = None , ** kwargs ):
66- # 如果上一条消息也和new_message一样不完整,则进行拼接
67- if messages and messages [- 1 ].to_dict ().get ('partial' , False ):
68- # 更新最后一条消息的内容
69- messages [- 1 ].reasoning_content += new_message .reasoning_content
70- messages [- 1 ].content += new_message .content
71- if new_message .tool_calls :
72- if messages [- 1 ].tool_calls :
73- messages [- 1 ].tool_calls += new_message .tool_calls
74- else :
75- messages [- 1 ].tool_calls = new_message .tool_calls
76- else :
77- # 否则添加为新的 partial 消息
78- new_message .partial = True
79- messages .append (new_message )
80-
81- messages = self .format_input_message (messages )
82-
83- # 继续调用 LLM 并流式返回后续结果
84- return self ._call_llm (messages , tools , ** kwargs )
85-
8666 def stream_continue_generate (self , messages : List [Message ], completion , tools : List [Tool ] = None , ** kwargs ) -> Generator [Message , None , None ]:
8767 message = None
8868 for chunk in completion :
@@ -114,13 +94,13 @@ def stream_continue_generate(self, messages: List[Message], completion, tools: L
11494 yield message_chunk
11595 if chunk .choices [0 ].finish_reason in ['length' , 'null' ]:
11696 print (f'finish_reason: { chunk .choices [0 ].finish_reason } , continue generate.' )
117- completion = self ._stream_continue_generate (messages , message , tools , ** kwargs )
97+ completion = self ._continue_generate (messages , message , tools , ** kwargs )
11898 for chunk in self .stream_continue_generate (messages , completion , tools , ** kwargs ):
11999 yield chunk
120100
121101 def stream_format_output_message (self , completion_chunk ) -> Message :
122- content = completion_chunk .choices [0 ].delta .content
123- reasoning_content = completion_chunk .choices [0 ].delta .reasoning_content
102+ content = completion_chunk .choices [0 ].delta .content or ''
103+ reasoning_content = completion_chunk .choices [0 ].delta .reasoning_content or ''
124104 tool_calls = None
125105 if completion_chunk .choices [0 ].delta .tool_calls :
126106 func = completion_chunk .choices [0 ].delta .tool_calls
@@ -135,8 +115,8 @@ def stream_format_output_message(self, completion_chunk) -> Message:
135115 return Message (role = 'assistant' , content = content , reasoning_content = reasoning_content , tool_calls = tool_calls , id = completion_chunk .id )
136116
137117 def format_output_message (self , completion ) -> Message :
138- content = completion .choices [0 ].message .content
139- reasoning_content = completion .choices [0 ].message .reasoning_content
118+ content = completion .choices [0 ].message .content or ''
119+ reasoning_content = completion .choices [0 ].message .reasoning_content or ''
140120 tool_calls = None
141121 if completion .choices [0 ].message .tool_calls :
142122 tool_calls = [ToolCall (
@@ -149,11 +129,10 @@ def format_output_message(self, completion) -> Message:
149129 ]
150130 return Message (role = 'assistant' , content = content , reasoning_content = reasoning_content , tool_calls = tool_calls , id = completion .id )
151131
152- def _continue_generate (self , messages : List [Message ], completion , tools : List [Tool ] = None , ** kwargs ):
132+ def _continue_generate (self , messages : List [Message ], new_message , tools : List [Tool ] = None , ** kwargs ):
153133 # ref: https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=https%3A%2F%2Fhelp.aliyun.com%2Fdocument_detail%2F2862210.html&renderType=iframe
154134 # TODO: 移到dashscope_llm并找到真正openai的续写方式
155135 if messages [- 1 ].to_dict ().get ('partial' , False ):
156- new_message = self .format_output_message (completion )
157136 messages [- 1 ].reasoning_content += new_message .reasoning_content
158137 messages [- 1 ].content += new_message .content
159138 if new_message .tool_calls :
@@ -162,22 +141,21 @@ def _continue_generate(self, messages: List[Message], completion, tools: List[To
162141 else :
163142 messages [- 1 ].tool_calls = new_message .tool_calls
164143 else :
165- messages .append (self . format_output_message ( completion ) )
144+ messages .append (new_message )
166145 messages [- 1 ].partial = True
167146
168147 messages = self .format_input_message (messages )
169148 return self ._call_llm (messages , tools , ** kwargs )
170149
171150 def continue_generate (self , messages : List [Message ], completion , tools : List [Tool ] = None , ** kwargs ) -> Message :
172- # finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
173-
174-
151+ new_message = self .format_output_message (completion )
175152 if completion .choices [0 ].finish_reason in ['length' , 'null' ]:
153+ print (f'new_message: { new_message } ' )
176154 print (f'finish_reason: { completion .choices [0 ].finish_reason } , continue generate.' )
177- completion = self ._continue_generate (messages , completion , tools , ** kwargs )
155+ completion = self ._continue_generate (messages , new_message , tools , ** kwargs )
178156 return self .continue_generate (messages , completion , tools , ** kwargs )
179157 else :
180- return self . format_output_message ( completion )
158+ return new_message
181159
182160 def format_input_message (self , messages : List [Message ]) -> List [Dict [str , Any ]]:
183161 openai_messages = []
@@ -200,8 +178,7 @@ def format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]:
200178 tool_calls .append (tool_call )
201179 message ['tool_calls' ] = tool_calls
202180
203- input_msg = {'role' , 'content' , 'tool_calls' , 'partial' }
204- message = {key : value for key , value in message .items () if key in input_msg and value }
181+ message = {key : value for key , value in message .items () if key in self .input_msg and value }
205182
206183 openai_messages .append (message )
207184
@@ -233,7 +210,7 @@ def format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]:
233210
234211 # tools = [
235212 # Tool(server_name='amap-maps', tool_name='maps_regeocode', description='将一个高德经纬度坐标转换为行政区划地址信息', parameters={'type': 'object', 'properties': {'location': {'type': 'string', 'description': '经纬度'}}, 'required': ['location']}),
236- # Tool(tool_name='mkdir', description='在文件系统创建目录', parameters={'type': 'object', 'properties': {'dir_name': {'type': 'string', 'description': '目录名'}}, 'required': ['location ']})
213+ # Tool(tool_name='mkdir', description='在文件系统创建目录', parameters={'type': 'object', 'properties': {'dir_name': {'type': 'string', 'description': '目录名'}}, 'required': ['dir_name ']})
237214 # ]
238215 tools = None
239216
0 commit comments