Skip to content

Commit a80358b

Browse files
mjsir911pgjones
authored andcommitted
Accept AsyncIterables being passed to Response
Fixes pallets/flask#5322
1 parent de9d9ef commit a80358b

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

src/quart/utils.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from pathlib import Path
1111
from typing import (
1212
Any,
13-
AsyncGenerator,
13+
AsyncIterator,
1414
Awaitable,
1515
Callable,
1616
Coroutine,
17-
Generator,
1817
Iterable,
18+
Iterator,
1919
TYPE_CHECKING,
20+
TypeVar,
2021
)
2122

2223
from werkzeug.datastructures import Headers
@@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
6667
return _wrapper
6768

6869

69-
def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]:
70-
async def _gen_wrapper() -> AsyncGenerator[Any, None]:
70+
T = TypeVar("T")
71+
72+
73+
def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]:
74+
async def _gen_wrapper() -> AsyncIterator[T]:
7175
# Wrap the generator such that each iteration runs
7276
# in the executor. Then rationalise the raised
7377
# errors so that it ends.
74-
def _inner() -> Any:
78+
def _inner() -> T:
7579
# https://bugs.python.org/issue26221
7680
# StopIteration errors are swallowed by the
7781
# run_in_exector method

src/quart/wrappers/response.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from abc import ABC, abstractmethod
44
from hashlib import md5
5-
from inspect import isasyncgen, isgenerator
65
from io import BytesIO
76
from os import PathLike
87
from types import TracebackType
@@ -101,27 +100,21 @@ async def __anext__(self) -> bytes:
101100

102101

103102
class IterableBody(ResponseBody):
104-
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
105-
self.iter: AsyncGenerator[bytes, None]
106-
if isasyncgen(iterable):
107-
self.iter = iterable
108-
elif isgenerator(iterable):
109-
self.iter = run_sync_iterable(iterable)
103+
def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None:
104+
self.iter: AsyncIterator[Any]
105+
if isinstance(iterable, Iterable):
106+
self.iter = run_sync_iterable(iter(iterable))
110107
else:
111-
112-
async def _aiter() -> AsyncGenerator[bytes, None]:
113-
for data in iterable: # type: ignore
114-
yield data
115-
116-
self.iter = _aiter()
108+
self.iter = iterable.__aiter__() # Can't use aiter() until 3.10
117109

118110
async def __aenter__(self) -> IterableBody:
119111
return self
120112

121113
async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
122-
await self.iter.aclose()
114+
if hasattr(self.iter, "aclose"):
115+
await self.iter.aclose()
123116

124-
def __aiter__(self) -> AsyncIterator:
117+
def __aiter__(self) -> AsyncIterator[Any]:
125118
return self.iter
126119

127120

@@ -261,7 +254,7 @@ class Response(SansIOResponse):
261254

262255
def __init__(
263256
self,
264-
response: ResponseBody | str | bytes | Iterable | None = None,
257+
response: ResponseBody | str | bytes | Iterable | AsyncIterable | None = None,
265258
status: int | None = None,
266259
headers: dict | Headers | None = None,
267260
mimetype: str | None = None,

tests/test_templating.py

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
g,
1010
Quart,
1111
render_template_string,
12+
Response,
1213
ResponseReturnValue,
1314
session,
1415
stream_template_string,
@@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
148149
test_client = app.test_client()
149150
response = await test_client.get("/")
150151
assert (await response.data) == b"42"
152+
153+
@app.get("/2")
154+
async def index2() -> ResponseReturnValue:
155+
return Response(await stream_template_string("{{ config }}", config=43))
156+
157+
test_client = app.test_client()
158+
response = await test_client.get("/2")
159+
assert (await response.data) == b"43"

0 commit comments

Comments
 (0)