1
1
import json
2
2
from inspect import getfullargspec
3
- from typing import List , Generator , Tuple , Callable , Optional , Union , Dict , Any , Iterator , AsyncGenerator , Awaitable , \
3
+ from typing import List , Generator , Tuple , Callable , Optional , Union , Dict , Iterator , AsyncGenerator , Awaitable , \
4
4
Set , AsyncIterator
5
5
6
6
from openai import AsyncStream , Stream
7
- from openai .types .chat import ChatCompletion , ChatCompletionChunk
7
+ from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessage , ChatCompletionMessageToolCall
8
+ from openai .types .chat .chat_completion_message_tool_call import Function
8
9
9
10
from json_streamer import ParseState , loads
10
11
from .fn_dispatcher import dispatch_yielded_functions_with_args , o_func
@@ -46,7 +47,7 @@ def __init__(self, func: Callable):
46
47
def _simplified_generator (
47
48
response : OAIResponse ,
48
49
content_fn_def : Optional [ContentFuncDef ],
49
- result : Dict
50
+ result : ChatCompletionMessage
50
51
) -> Callable [[], AsyncGenerator [Tuple [str , Dict ], None ]]:
51
52
"""
52
53
Return an async generator that converts an OpenAI response stream to a simple generator that yields function names
@@ -57,20 +58,25 @@ def _simplified_generator(
57
58
:return: A function that returns a generator
58
59
"""
59
60
60
- result ["role" ] = "assistant"
61
-
62
61
async def generator () -> AsyncGenerator [Tuple [str , Dict ], None ]:
62
+
63
63
async for r in _process_stream (response , content_fn_def ):
64
64
if content_fn_def is not None and r [0 ] == content_fn_def .name :
65
65
yield content_fn_def .name , {content_fn_def .arg : r [2 ]}
66
66
67
- if " content" not in result :
68
- result [ " content" ] = ""
69
- result [ " content" ] += r [2 ]
67
+ if result . content is None :
68
+ result . content = ""
69
+ result . content += r [2 ]
70
70
else :
71
71
yield r [0 ], r [2 ]
72
72
if r [1 ] == ParseState .COMPLETE :
73
- result ["function_call" ] = {"name" : r [0 ], "arguments" : json .dumps (r [2 ])}
73
+ if result .tool_calls is None :
74
+ result .tool_calls = []
75
+ result .tool_calls .append (ChatCompletionMessageToolCall (
76
+ id = r [3 ] or "" ,
77
+ type = "function" ,
78
+ function = Function (name = r [0 ], arguments = json .dumps (r [2 ]))
79
+ ))
74
80
75
81
return generator
76
82
@@ -113,7 +119,7 @@ async def process_response(
113
119
content_func : Optional [Callable [[AsyncGenerator [str , None ]], Awaitable [None ]]] = None ,
114
120
funcs : Optional [List [Callable [[], Awaitable [None ]]]] = None ,
115
121
self : Optional = None
116
- ) -> Tuple [Set [str ], Dict [ str , Any ] ]:
122
+ ) -> Tuple [Set [str ], ChatCompletionMessage ]:
117
123
"""
118
124
Processes an OpenAI response stream and returns a set of function names that were invoked, and a dictionary contains
119
125
the results of the functions (to be used as part of the message history for the next api request).
@@ -144,7 +150,7 @@ async def process_response(
144
150
if content_fn_def is not None :
145
151
func_map [content_fn_def .name ] = content_func
146
152
147
- result = {}
153
+ result = ChatCompletionMessage ( role = "assistant" )
148
154
gen = _simplified_generator (response , content_fn_def , result )
149
155
preprocess = DiffPreprocessor (content_fn_def )
150
156
return await dispatch_yielded_functions_with_args (gen , func_map , preprocess .preprocess , self ), result
@@ -183,6 +189,7 @@ class StreamProcessorState:
183
189
content_fn_def : Optional [ContentFuncDef ] = None
184
190
current_processor : Optional [Generator [Tuple [ParseState , dict ], str , None ]] = None
185
191
current_fn : Optional [str ] = None
192
+ call_id : Optional [str ] = None
186
193
187
194
def __init__ (self , content_fn_def : Optional [ContentFuncDef ]):
188
195
self .content_fn_def = content_fn_def
@@ -191,7 +198,7 @@ def __init__(self, content_fn_def: Optional[ContentFuncDef]):
191
198
async def _process_stream (
192
199
response : OAIResponse ,
193
200
content_fn_def : Optional [ContentFuncDef ]
194
- ) -> AsyncGenerator [Tuple [str , ParseState , Union [dict , str ]], None ]:
201
+ ) -> AsyncGenerator [Tuple [str , ParseState , Union [dict , str ], Optional [ str ] ], None ]:
195
202
"""
196
203
Processes an OpenAI response stream and yields the function name, the parse state and the parsed arguments.
197
204
:param response: The response stream from OpenAI
@@ -213,7 +220,7 @@ async def _process_stream(
213
220
def _process_message (
214
221
message : ChatCompletionChunk ,
215
222
state : StreamProcessorState
216
- ) -> Generator [Tuple [str , ParseState , Union [dict , str ]], None , None ]:
223
+ ) -> Generator [Tuple [str , ParseState , Union [dict , str ], Optional [ str ] ], None , None ]:
217
224
"""
218
225
This function processes the responses as they arrive from OpenAI, and transforms them as a generator of
219
226
partial objects
@@ -231,25 +238,28 @@ def _process_message(
231
238
if func .name :
232
239
if state .current_processor is not None :
233
240
state .current_processor .close ()
241
+
242
+ state .call_id = delta .tool_calls and delta .tool_calls [0 ].id or None
234
243
state .current_fn = func .name
235
244
state .current_processor = _arguments_processor ()
236
245
next (state .current_processor )
237
246
if func .arguments :
238
247
arg = func .arguments
239
248
ret = state .current_processor .send (arg )
240
249
if ret is not None :
241
- yield state .current_fn , ret [0 ], ret [1 ]
250
+ yield state .current_fn , ret [0 ], ret [1 ], state . call_id
242
251
if delta .content :
243
252
if delta .content is None or delta .content == "" :
244
253
return
245
254
if state .content_fn_def is not None :
246
- yield state .content_fn_def .name , ParseState .PARTIAL , delta .content
255
+ yield state .content_fn_def .name , ParseState .PARTIAL , delta .content , state . call_id
247
256
else :
248
- yield None , ParseState .PARTIAL , delta .content
257
+ yield None , ParseState .PARTIAL , delta .content , None
249
258
if message .choices [0 ].finish_reason and (
250
259
message .choices [0 ].finish_reason == "function_call" or message .choices [0 ].finish_reason == "tool_calls"
251
260
):
252
261
if state .current_processor is not None :
253
262
state .current_processor .close ()
254
263
state .current_processor = None
255
264
state .current_fn = None
265
+ state .call_id = None
0 commit comments