Skip to content

Commit 917f3b4

Browse files
committed
refactor: handle athrow edge cases at the same site
1 parent dd4ae6b commit 917f3b4

4 files changed

Lines changed: 25 additions & 33 deletions

File tree

litestar/response/sse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
if retry_duration is not None:
5050
chunks.append(f"retry: {retry_duration}\r\n".encode())
5151

52-
super().__init__(iterable=chunks)
52+
super().__init__(iterator=chunks)
5353

5454
if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
5555
content = content() # type: ignore[unreachable]

litestar/response/streaming.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ async def async_iterator_to_generator(
5050
class ASGIStreamingResponse(ASGIResponse):
5151
"""A streaming response."""
5252

53-
__slots__ = ("_original_generator", "disconnect_event", "iterator")
53+
__slots__ = ("disconnect_event", "iterator")
5454

5555
_should_set_content_length = False
5656

@@ -93,18 +93,7 @@ def __init__(
9393
)
9494

9595
self.disconnect_event = Event()
96-
self.iterator: AsyncGenerator[str | bytes, None]
97-
98-
if isinstance(iterator, (AsyncIteratorWrapper, AsyncGenerator)):
99-
self.iterator = iterator
100-
elif isinstance(iterator, (Iterable, Iterator)):
101-
self.iterator = AsyncIteratorWrapper(iterator)
102-
elif isinstance(iterator, (AsyncIterable, AsyncIterator)):
103-
self.iterator = async_iterator_to_generator(iterator)
104-
105-
self._original_generator = (
106-
iterator.content_async_iterator if hasattr(iterator, "content_async_iterator") else self.iterator # pyright: ignore[reportAttributeAccessIssue]
107-
)
96+
self.iterator = iterator if isinstance(iterator, AsyncIterable) else AsyncIteratorWrapper(iterator)
10897

10998
async def _listen_for_disconnect(self, receive: Receive) -> None:
11099
"""Listen for a cancellation message, and if received - call cancel on the cancel scope.
@@ -133,8 +122,11 @@ async def _stream(self, send: Send) -> None:
133122
async for chunk in self.iterator:
134123
if self.disconnect_event.is_set():
135124
try:
136-
await self.iterator.athrow(ClientDisconnectError)
137-
except BaseException: # noqa: BLE001, S110
125+
if hasattr(self.iterator, "content_async_iterator"):
126+
await self.iterator.content_async_iterator.athrow(ClientDisconnectError) # pyright: ignore[reportAttributeAccessIssue]
127+
elif isinstance(self.iterator, AsyncGenerator):
128+
await self.iterator.athrow(ClientDisconnectError)
129+
except (ClientDisconnectError, StopAsyncIteration):
138130
pass
139131
finally:
140132
break # noqa: B012

litestar/utils/sync.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323
S = TypeVar("S", default=None)
2424

2525

26-
def iterable_to_generator(iterator: Iterable[T]) -> Generator[T, S, None]:
26+
def iterable_to_generator(iterable: Iterable[T]) -> Generator[T, S, None]:
2727
"""Convert an iterable to a generator.
2828
2929
Args:
30-
iterator: An iterable.
30+
iterable: An iterable.
3131
3232
Returns:
3333
A generator.
3434
"""
35-
yield from iterator
35+
yield from iterable
3636

3737

3838
@overload
@@ -71,22 +71,22 @@ class AsyncIteratorWrapper(AsyncGenerator[T, S]):
7171

7272
__slots__ = ("_original_generator", "generator", "iterator")
7373

74-
def __init__(self, iterable: Iterable[T]) -> None:
75-
"""Take a sync iterator or iterable and yields values from it asynchronously.
74+
def __init__(self, iterator: Iterable[T]) -> None:
75+
"""Take a sync iterable and yields values from it asynchronously.
7676
7777
Args:
78-
iterable: A sync iterable.
78+
iterator: A sync iterable.
7979
"""
8080
self._original_generator: Generator[T, S, None]
8181

82-
if isinstance(iterable, Generator):
83-
self._original_generator = iterable
84-
elif isinstance(iterable, AsyncIteratorWrapper):
85-
self._original_generator = iterable._original_generator
82+
if isinstance(iterator, Generator):
83+
self._original_generator = iterator
84+
elif isinstance(iterator, AsyncIteratorWrapper):
85+
self._original_generator = iterator._original_generator
8686
else:
87-
self._original_generator = iterable_to_generator(iterable)
87+
self._original_generator = iterable_to_generator(iterator)
8888

89-
self.iterator = iter(iterable)
89+
self.iterator = iter(iterator)
9090
self.generator = self._async_generator()
9191

9292
def _call_next(self) -> T:
@@ -109,10 +109,10 @@ async def __anext__(self) -> T:
109109
return await self.generator.__anext__()
110110

111111
async def aclose(self) -> None:
112-
self._original_generator.close()
112+
await sync_to_thread(self._original_generator.close)
113113

114114
async def asend(self, value: S) -> T:
115-
return self._original_generator.send(value)
115+
return await sync_to_thread(self._original_generator.send, value)
116116

117117
async def athrow(
118118
self,
@@ -122,9 +122,9 @@ async def athrow(
122122
) -> T:
123123
try:
124124
return (
125-
self._original_generator.throw(typ)
125+
await sync_to_thread(self._original_generator.throw, typ)
126126
if isinstance(typ, BaseException)
127-
else self._original_generator.throw(typ, val, tb)
127+
else await sync_to_thread(self._original_generator.throw, typ, val, tb)
128128
)
129129
except StopIteration as e:
130130
raise StopAsyncIteration from e

tests/unit/test_response/test_response_to_asgi_response.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __next__(self) -> str:
6363

6464
class MyAsyncIterator(AsyncIteratorWrapper[str]):
6565
def __init__(self) -> None:
66-
super().__init__(iterable=MySyncIterator())
66+
super().__init__(iterator=MySyncIterator())
6767

6868

6969
async def test_to_response_returning_litestar_response() -> None:

0 commit comments

Comments
 (0)