Skip to content

Commit

Permalink
Accept AsyncIterables being passed to Response
Browse files Browse the repository at this point in the history
  • Loading branch information
mjsir911 committed May 21, 2024
1 parent 2fc6d4f commit 7b428f2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
14 changes: 9 additions & 5 deletions src/quart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterable,
Iterator,
TYPE_CHECKING,
TypeVar,
)

from werkzeug.datastructures import Headers
Expand Down Expand Up @@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
return _wrapper


def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]:
async def _gen_wrapper() -> AsyncGenerator[Any, None]:
T = TypeVar("T")


def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]:
async def _gen_wrapper() -> AsyncIterator[T]:
# Wrap the generator such that each iteration runs
# in the executor. Then rationalise the raised
# errors so that it ends.
def _inner() -> Any:
def _inner() -> T:
# https://bugs.python.org/issue26221
# StopIteration errors are swallowed by the
# run_in_exector method
Expand Down
25 changes: 9 additions & 16 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from abc import ABC, abstractmethod
from hashlib import md5
from inspect import isasyncgen, isgenerator
from io import BytesIO
from os import PathLike
from types import TracebackType
Expand Down Expand Up @@ -102,27 +101,21 @@ async def __anext__(self) -> bytes:


class IterableBody(ResponseBody):
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
self.iter: AsyncGenerator[bytes, None]
if isasyncgen(iterable):
self.iter = iterable
elif isgenerator(iterable):
self.iter = run_sync_iterable(iterable)
def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None:
self.iter: AsyncIterator[Any]
if isinstance(iterable, Iterable):
self.iter = run_sync_iterable(iter(iterable))
else:

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable: # type: ignore
yield data

self.iter = _aiter()
self.iter = aiter(iterable)

async def __aenter__(self) -> IterableBody:
return self

async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
await self.iter.aclose()
if hasattr(self.iter, "aclose"): # Is a generator?
await self.iter.aclose()

def __aiter__(self) -> AsyncIterator:
def __aiter__(self) -> AsyncIterator[Any]:
return self.iter


Expand Down Expand Up @@ -262,7 +255,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
g,
Quart,
render_template_string,
Response,
ResponseReturnValue,
session,
stream_template_string,
Expand Down Expand Up @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
test_client = app.test_client()
response = await test_client.get("/")
assert (await response.data) == b"42"

@app.get("/2")
async def index2() -> ResponseReturnValue:
return Response(await stream_template_string("{{ config }}", config=43))

test_client = app.test_client()
response = await test_client.get("/2")
assert (await response.data) == b"43"

0 comments on commit 7b428f2

Please sign in to comment.