1
1
import json
2
2
from inspect import getfullargspec
3
- from typing import List , Generator , Tuple , Callable , Optional , Union , Dict , Set , Any , Iterator , AsyncGenerator , \
4
- Awaitable
3
+ from typing import List , Generator , Tuple , Callable , Optional , Union , Dict , Any , Iterator , AsyncGenerator , Awaitable , \
4
+ Set , Coroutine , AsyncIterator
5
5
6
6
from openai .openai_object import OpenAIObject
7
7
8
8
from json_streamer import ParseState , loads
9
9
from .fn_dispatcher import dispatch_yielded_functions_with_args , o_func
10
10
11
+ OAIObject = Union [OpenAIObject , Dict ]
12
+ OAIGenerator = Union [Generator [OAIObject , Any , None ], List [OAIObject ], Iterator [OAIObject ]]
13
+ OAIAsyncGenerator = Union [AsyncGenerator [OAIObject , None ], AsyncIterator [OAIObject ]]
14
+ OAIStream = Union [OAIGenerator , OAIAsyncGenerator ]
15
+
11
16
12
17
class ContentFuncDef :
13
18
"""
@@ -34,7 +39,7 @@ def __init__(self, func: Callable):
34
39
35
40
36
41
def _simplified_generator (
37
- response : Union [ Iterator [ OpenAIObject ], List [ OpenAIObject ]] ,
42
+ response : OAIStream ,
38
43
content_fn_def : Optional [ContentFuncDef ],
39
44
result : Dict
40
45
) -> Callable [[], AsyncGenerator [Tuple [str , Dict ], None ]]:
@@ -50,7 +55,7 @@ def _simplified_generator(
50
55
result ["role" ] = "assistant"
51
56
52
57
async def generator () -> AsyncGenerator [Tuple [str , Dict ], None ]:
53
- for r in _process_stream (response , content_fn_def ):
58
+ async for r in _process_stream (response , content_fn_def ):
54
59
if content_fn_def is not None and r [0 ] == content_fn_def .name :
55
60
yield content_fn_def .name , {content_fn_def .arg : r [2 ]}
56
61
@@ -99,7 +104,7 @@ def preprocess(self, key, current_dict):
99
104
100
105
101
106
async def process_response (
102
- response : Union [ Iterator [ OpenAIObject ], List [ OpenAIObject ]] ,
107
+ response : OAIStream ,
103
108
content_func : Optional [Callable [[AsyncGenerator [str , None ]], Awaitable [None ]]] = None ,
104
109
funcs : Optional [List [Callable [[], Awaitable [None ]]]] = None ,
105
110
self : Optional = None
@@ -123,7 +128,8 @@ async def process_response(
123
128
# assert content_func signature is Generator[str, None, None]
124
129
content_fn_def = ContentFuncDef (content_func ) if content_func is not None else None
125
130
126
- if not isinstance (response , Iterator ) and not isinstance (response , list ):
131
+ if (not isinstance (response , Iterator ) and not isinstance (response , List )
132
+ and not isinstance (response , AsyncIterator ) and not isinstance (response , AsyncGenerator )):
127
133
raise ValueError ("response must be an iterator (generator's stream from OpenAI or a log as a list)" )
128
134
129
135
func_map : Dict [str , Callable ] = {}
@@ -168,43 +174,66 @@ def _arguments_processor(json_loader=loads) -> Generator[Tuple[ParseState, dict]
168
174
break
169
175
170
176
171
- def _process_stream (response : Iterator [OpenAIObject ], content_fn_def : Optional [ContentFuncDef ]) \
172
- -> Generator [Tuple [str , ParseState , Union [dict , str ]], None , None ]:
177
+ class StreamProcessorState :
178
+ content_fn_def : Optional [ContentFuncDef ] = None
179
+ current_processor : Optional [Generator [Tuple [ParseState , dict ], str , None ]] = None
180
+ current_fn : Optional [str ] = None
181
+
182
+ def __init__ (self , content_fn_def : Optional [ContentFuncDef ]):
183
+ self .content_fn_def = content_fn_def
184
+
185
+
186
+ async def _process_stream (
187
+ response : OAIStream ,
188
+ content_fn_def : Optional [ContentFuncDef ]
189
+ ) -> AsyncGenerator [Tuple [str , ParseState , Union [dict , str ]], None ]:
173
190
"""
174
191
Processes an OpenAI response stream and yields the function name, the parse state and the parsed arguments.
175
192
:param response: The response stream from OpenAI
176
193
:param content_fn_def: The content function definition
177
194
:return: A generator that yields the function name, the parse state and the parsed arguments
178
195
"""
179
196
180
- current_processor = None
181
- current_fn = None
182
- for message in response :
183
- choice = message ["choices" ][0 ]
184
- if "delta" not in choice :
185
- raise LookupError ("No delta in choice" )
186
-
187
- delta = message ["choices" ][0 ]["delta" ]
188
- if "function_call" in delta :
189
- if "name" in delta ["function_call" ]:
190
- if current_processor is not None :
191
- current_processor .close ()
192
- current_fn = delta ["function_call" ]["name" ]
193
- current_processor = _arguments_processor ()
194
- next (current_processor )
195
- if "arguments" in delta ["function_call" ]:
196
- arg = delta ["function_call" ]["arguments" ]
197
- ret = current_processor .send (arg )
198
- if ret is not None :
199
- yield current_fn , ret [0 ], ret [1 ]
200
- if "content" in delta :
201
- if delta ["content" ] is None or delta ["content" ] == "" :
202
- continue
203
- if content_fn_def is not None :
204
- yield content_fn_def .name , ParseState .PARTIAL , delta ["content" ]
205
- else :
206
- yield None , ParseState .PARTIAL , delta ["content" ]
207
- if "finish_reason" in message and message ["finish_reason" ] == "finish_reason" :
208
- current_processor .close ()
209
- current_processor = None
210
- current_fn = None
197
+ state = StreamProcessorState (content_fn_def = content_fn_def )
198
+ if isinstance (response , AsyncGenerator ) or isinstance (response , AsyncIterator ):
199
+ async for message in response :
200
+ for res in _process_message (message , state ):
201
+ yield res
202
+ else :
203
+ for message in response :
204
+ for res in _process_message (message , state ):
205
+ yield res
206
+
207
+
208
+ def _process_message (
209
+ message : OAIObject ,
210
+ state : StreamProcessorState
211
+ ) -> Generator [Tuple [str , ParseState , Union [dict , str ]], None , None ]:
212
+ choice = message ["choices" ][0 ]
213
+ if "delta" not in choice :
214
+ raise LookupError ("No delta in choice" )
215
+
216
+ delta = message ["choices" ][0 ]["delta" ]
217
+ if "function_call" in delta :
218
+ if "name" in delta ["function_call" ]:
219
+ if state .current_processor is not None :
220
+ state .current_processor .close ()
221
+ state .current_fn = delta ["function_call" ]["name" ]
222
+ state .current_processor = _arguments_processor ()
223
+ next (state .current_processor )
224
+ if "arguments" in delta ["function_call" ]:
225
+ arg = delta ["function_call" ]["arguments" ]
226
+ ret = state .current_processor .send (arg )
227
+ if ret is not None :
228
+ yield state .current_fn , ret [0 ], ret [1 ]
229
+ if "content" in delta :
230
+ if delta ["content" ] is None or delta ["content" ] == "" :
231
+ return
232
+ if state .content_fn_def is not None :
233
+ yield state .content_fn_def .name , ParseState .PARTIAL , delta ["content" ]
234
+ else :
235
+ yield None , ParseState .PARTIAL , delta ["content" ]
236
+ if "finish_reason" in message and message ["finish_reason" ] == "finish_reason" :
237
+ state .current_processor .close ()
238
+ state .current_processor = None
239
+ state .current_fn = None
0 commit comments