
Commit 289947b

Merge pull request #10 from AlmogBaku/fix/better_multitool_support

fix: better multitool support

2 parents f98f45f + dfc44b5, commit 289947b

6 files changed: +1148 −40 lines changed

openai_streaming/fn_dispatcher.py (+3 −3)

@@ -1,6 +1,6 @@
+from asyncio import Queue, gather, create_task
 from inspect import getfullargspec, signature, iscoroutinefunction
 from typing import Callable, List, Dict, Tuple, Union, Optional, Set, AsyncGenerator, get_origin, get_args, Type
-from asyncio import Queue, gather, create_task
 
 from pydantic import ValidationError
 
@@ -156,9 +156,9 @@ async def dispatch_yielded_functions_with_args(
     args_types = {}
     for func_name in func_map:
         spec = getfullargspec(o_func(func_map[func_name]))
-        if spec.args[0] == "self" and self is None:
+        if len(spec.args) > 0 and spec.args[0] == "self" and self is None:
             raise ValueError("self argument is required for functions that take self")
-        idx = 1 if spec.args[0] == "self" else 0
+        idx = 1 if len(spec.args) > 0 and spec.args[0] == "self" else 0
         args_queues[func_name] = {arg: Queue() for arg in spec.args[idx:]}
 
     # create type maps for validations
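
The two changed lines guard against dispatched functions that take no arguments at all: `getfullargspec(...).args` is an empty list for a zero-argument function, so the old `spec.args[0]` access raised an IndexError before any dispatch happened. A minimal sketch of the fixed logic (the `ping` function is hypothetical, for illustration only):

    from inspect import getfullargspec

    def ping():  # hypothetical function that takes no arguments
        pass

    spec = getfullargspec(ping)
    # Old code: spec.args[0] raises IndexError, since spec.args == []
    # Fixed code: len(spec.args) > 0 short-circuits, so idx falls back to 0
    idx = 1 if len(spec.args) > 0 and spec.args[0] == "self" else 0
    print(spec.args[idx:])  # -> [] (no argument queues to create)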

openai_streaming/stream_processing.py (+26 −16)

@@ -1,10 +1,11 @@
 import json
 from inspect import getfullargspec
-from typing import List, Generator, Tuple, Callable, Optional, Union, Dict, Any, Iterator, AsyncGenerator, Awaitable, \
+from typing import List, Generator, Tuple, Callable, Optional, Union, Dict, Iterator, AsyncGenerator, Awaitable, \
     Set, AsyncIterator
 
 from openai import AsyncStream, Stream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 
 from json_streamer import ParseState, loads
 from .fn_dispatcher import dispatch_yielded_functions_with_args, o_func
@@ -46,7 +47,7 @@ def __init__(self, func: Callable):
 def _simplified_generator(
         response: OAIResponse,
         content_fn_def: Optional[ContentFuncDef],
-        result: Dict
+        result: ChatCompletionMessage
 ) -> Callable[[], AsyncGenerator[Tuple[str, Dict], None]]:
     """
     Return an async generator that converts an OpenAI response stream to a simple generator that yields function names
@@ -57,20 +58,25 @@ def _simplified_generator(
     :return: A function that returns a generator
     """
 
-    result["role"] = "assistant"
-
     async def generator() -> AsyncGenerator[Tuple[str, Dict], None]:
+
         async for r in _process_stream(response, content_fn_def):
             if content_fn_def is not None and r[0] == content_fn_def.name:
                 yield content_fn_def.name, {content_fn_def.arg: r[2]}
 
-                if "content" not in result:
-                    result["content"] = ""
-                result["content"] += r[2]
+                if result.content is None:
+                    result.content = ""
+                result.content += r[2]
             else:
                 yield r[0], r[2]
                 if r[1] == ParseState.COMPLETE:
-                    result["function_call"] = {"name": r[0], "arguments": json.dumps(r[2])}
+                    if result.tool_calls is None:
+                        result.tool_calls = []
+                    result.tool_calls.append(ChatCompletionMessageToolCall(
+                        id=r[3] or "",
+                        type="function",
+                        function=Function(name=r[0], arguments=json.dumps(r[2]))
+                    ))
 
     return generator
 
@@ -113,7 +119,7 @@ async def process_response(
         content_func: Optional[Callable[[AsyncGenerator[str, None]], Awaitable[None]]] = None,
         funcs: Optional[List[Callable[[], Awaitable[None]]]] = None,
         self: Optional = None
-) -> Tuple[Set[str], Dict[str, Any]]:
+) -> Tuple[Set[str], ChatCompletionMessage]:
     """
     Processes an OpenAI response stream and returns a set of function names that were invoked, and a dictionary contains
     the results of the functions (to be used as part of the message history for the next api request).
@@ -144,7 +150,7 @@ async def process_response(
     if content_fn_def is not None:
         func_map[content_fn_def.name] = content_func
 
-    result = {}
+    result = ChatCompletionMessage(role="assistant")
     gen = _simplified_generator(response, content_fn_def, result)
     preprocess = DiffPreprocessor(content_fn_def)
     return await dispatch_yielded_functions_with_args(gen, func_map, preprocess.preprocess, self), result
@@ -183,6 +189,7 @@ class StreamProcessorState:
     content_fn_def: Optional[ContentFuncDef] = None
     current_processor: Optional[Generator[Tuple[ParseState, dict], str, None]] = None
     current_fn: Optional[str] = None
+    call_id: Optional[str] = None
 
     def __init__(self, content_fn_def: Optional[ContentFuncDef]):
         self.content_fn_def = content_fn_def
@@ -191,7 +198,7 @@ def __init__(self, content_fn_def: Optional[ContentFuncDef]):
 async def _process_stream(
         response: OAIResponse,
         content_fn_def: Optional[ContentFuncDef]
-) -> AsyncGenerator[Tuple[str, ParseState, Union[dict, str]], None]:
+) -> AsyncGenerator[Tuple[str, ParseState, Union[dict, str], Optional[str]], None]:
     """
     Processes an OpenAI response stream and yields the function name, the parse state and the parsed arguments.
     :param response: The response stream from OpenAI
@@ -213,7 +220,7 @@ async def _process_stream(
 def _process_message(
         message: ChatCompletionChunk,
         state: StreamProcessorState
-) -> Generator[Tuple[str, ParseState, Union[dict, str]], None, None]:
+) -> Generator[Tuple[str, ParseState, Union[dict, str], Optional[str]], None, None]:
     """
     This function processes the responses as they arrive from OpenAI, and transforms them as a generator of
     partial objects
@@ -231,25 +238,28 @@ def _process_message(
         if func.name:
             if state.current_processor is not None:
                 state.current_processor.close()
+
+            state.call_id = delta.tool_calls and delta.tool_calls[0].id or None
            state.current_fn = func.name
             state.current_processor = _arguments_processor()
             next(state.current_processor)
         if func.arguments:
             arg = func.arguments
             ret = state.current_processor.send(arg)
             if ret is not None:
-                yield state.current_fn, ret[0], ret[1]
+                yield state.current_fn, ret[0], ret[1], state.call_id
     if delta.content:
         if delta.content is None or delta.content == "":
             return
         if state.content_fn_def is not None:
-            yield state.content_fn_def.name, ParseState.PARTIAL, delta.content
+            yield state.content_fn_def.name, ParseState.PARTIAL, delta.content, state.call_id
        else:
-            yield None, ParseState.PARTIAL, delta.content
+            yield None, ParseState.PARTIAL, delta.content, None
     if message.choices[0].finish_reason and (
             message.choices[0].finish_reason == "function_call" or message.choices[0].finish_reason == "tool_calls"
     ):
         if state.current_processor is not None:
             state.current_processor.close()
         state.current_processor = None
         state.current_fn = None
+        state.call_id = None
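
Taken together, `process_response` now returns a typed `ChatCompletionMessage` whose `tool_calls` carry the upstream call ids, one entry per completed tool call (the multitool fix), instead of a plain dict with a single `function_call`. A rough consumption sketch, to be run in an async context; `resp`, `messages`, `content_handler`, and `error_message` are assumed to be set up as in tests/example.py below:

    from openai.types.chat import ChatCompletionMessage

    from openai_streaming import process_response

    # invoked: set of function names that were called; message: the assistant turn
    invoked, message = await process_response(resp, content_handler, funcs=[error_message])

    assert isinstance(message, ChatCompletionMessage)
    for call in message.tool_calls or []:
        # Each completed tool call keeps its original id, name, and JSON arguments
        print(call.id, call.function.name, call.function.arguments)

    # The typed message can be appended to the history for the next request
    messages.append(message.model_dump(exclude_unset=True))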

openai_streaming/utils.py (+20 −8)

@@ -1,11 +1,14 @@
 from typing import List, Iterator, Union, AsyncIterator, AsyncGenerator
 
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import RootModel
 
 OAIResponse = Union[ChatCompletion, ChatCompletionChunk]
 
 
-async def stream_to_log(response: Union[Iterator[OAIResponse], AsyncIterator[OAIResponse]]) -> List[OAIResponse]:
+async def stream_to_log(
+        response: Union[Iterator[OAIResponse], AsyncIterator[OAIResponse], AsyncGenerator[OAIResponse, None]]) \
+        -> List[OAIResponse]:
     """
     A utility function to convert a stream to a log.
     :param response: The response stream from OpenAI
@@ -22,7 +25,11 @@ async def stream_to_log(response: Union[Iterator[OAIResponse], AsyncIterator[OAI
     return log
 
 
-async def print_stream_log(log: List[OAIResponse]):
+def log_to_json(log: List[OAIResponse]) -> str:
+    return RootModel(log).model_dump_json()
+
+
+async def print_stream_log(log: Union[List[OAIResponse], AsyncGenerator[OAIResponse, None]]) -> None:
     """
     A utility function to print the log of a stream nicely.
     This is useful for debugging, when you first save the stream to an array and then use it.
@@ -50,24 +57,29 @@ async def print_stream_log(log: List[OAIResponse]):
                     content_print = False
                     print("\n")
                 if delta.function_call.name:
-                    print(f"{delta.function_call.name}(")
+                    print(f"\n{delta.function_call.name}: ", end="")
                 if delta.function_call.arguments:
-                    print(delta.function_call.arguments, end="")
+                    print(delta.function_call.arguments, end=")")
             if delta.tool_calls:
                 for call in delta.tool_calls:
                     if call.function:
                         if content_print:
                             content_print = False
                             print("\n")
                         if call.function.name:
-                            print(f"{call.function.name}(")
+                            print(f"\n {call.function.name}: ", end="")
                         if call.function.arguments:
                             print(call.function.arguments, end="")
             if (l.choices[0].finish_reason and l.choices[0].finish_reason == "function_call" or
                     l.choices[0].finish_reason == "tool_calls"):
-                print(")")
+                print("\n--finish: tool_calls--")
 
 
-async def logs_to_response(logs: List[OAIResponse]) -> AsyncGenerator[OAIResponse, None]:
+async def logs_to_response(logs: Union[List[OAIResponse], dict]) -> AsyncGenerator[OAIResponse, None]:
     for item in logs:
-        yield ChatCompletionChunk(**item)
+        if isinstance(item, ChatCompletionChunk):
+            yield item
+        elif isinstance(item, dict):
+            yield ChatCompletionChunk.model_construct(**item)
+        else:
+            raise ValueError(f"Invalid log item: {item}")
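
The new `log_to_json` helper and the more permissive `logs_to_response` make a record/replay round trip possible: serialize a captured stream to JSON, then feed the parsed dicts back as chunks. A minimal sketch, assuming `response` is a live chat-completions stream:

    import json

    from openai_streaming.utils import (log_to_json, logs_to_response,
                                        print_stream_log, stream_to_log)


    async def record_and_replay(response) -> None:
        # Record: drain the live stream into a list of chunks, then
        # serialize it with the pydantic RootModel-based helper
        log = await stream_to_log(response)
        serialized = log_to_json(log)

        # Replay: logs_to_response re-wraps each parsed dict as a
        # ChatCompletionChunk (model_construct skips re-validation)
        replayed = [chunk async for chunk in logs_to_response(json.loads(serialized))]
        await print_stream_log(replayed)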

tests/example.py (+7 −3)

@@ -1,10 +1,11 @@
+import asyncio
 import os
+from typing import AsyncGenerator
 
 from openai import AsyncOpenAI
-import asyncio
-from openai_streaming import process_response
-from typing import AsyncGenerator
+
 from openai_streaming import openai_streaming_function
+from openai_streaming import process_response
 
 # Initialize OpenAI Client
 client = AsyncOpenAI(
@@ -16,13 +17,15 @@
 async def content_handler(content: AsyncGenerator[str, None]):
     async for token in content:
         print(token, end="")
+    print("")
 
 
 # Define OpenAI Function
 @openai_streaming_function
 async def error_message(typ: str, description: AsyncGenerator[str, None]):
     """
     You MUST use this function when requested to do something that you cannot do.
+    ALWAYS call the `report_intruder` function when you are requested to do something that you cannot do.
 
     :param typ: The error's type
     :param description: The error description
@@ -36,6 +39,7 @@ async def error_message(typ: str, description: AsyncGenerator[str, None]):
     print("Description: ", end="")
     async for token in description:
         print(token, end="")
+    print("")
 
 
 # Invoke Function in a streaming request
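
The diff stops just before the invocation it refers to; a rough sketch of that final step, with an illustrative prompt and model, and assuming the decorator exposes the function's schema as `error_message.openai_schema`:

    async def main():
        resp = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Please delete the root directory."}],
            # schema attribute is an assumption about the decorator's API
            tools=[{"type": "function", "function": error_message.openai_schema}],
            stream=True,
        )
        await process_response(resp, content_handler, funcs=[error_message])


    asyncio.run(main())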
