Skip to content

Commit 5cf4ba0

Browse files
committed
fix: functions must use AsyncGenerator type or no type
1 parent 1f0f721 commit 5cf4ba0

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

openai_streaming/decorator.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncGenerator
2+
from inspect import iscoroutinefunction
23
from types import FunctionType
34
from typing import Generator, get_origin, Union, Optional, Any
45
from typing import get_args
@@ -13,6 +14,9 @@ def openai_streaming_function(func: FunctionType) -> Any:
1314
:param func: The function to convert
1415
:return: Wrapped function with a `openai_schema` attribute
1516
"""
17+
if not iscoroutinefunction(func):
18+
raise ValueError("openai_streaming_function can only be applied to async functions")
19+
1620
for key, val in func.__annotations__.items():
1721
optional = False
1822

@@ -31,7 +35,9 @@ def openai_streaming_function(func: FunctionType) -> Any:
3135
val = gen
3236

3337
args = get_args(val)
34-
if get_origin(val) is get_origin(Generator) or get_origin(val) is AsyncGenerator:
38+
if get_origin(val) is get_origin(Generator):
39+
raise ValueError("openai_streaming_function does not support Generator type. Use AsyncGenerator instead.")
40+
if get_origin(val) is AsyncGenerator:
3541
val = args[0]
3642

3743
if optional:

tests/example.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
# Define content handler
14-
async def content_handler(content: Generator[str, None, None]):
15-
async for token in content: # <-- the content is an AsyncGenerator and not a Generator!
14+
async def content_handler(content: AsyncGenerator[str, None]):
15+
async for token in content:
1616
print(token, end="")
1717

1818

0 commit comments

Comments
 (0)