Skip to content
2 changes: 1 addition & 1 deletion litestar/response/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.content_async_iterator = AsyncIteratorWrapper([content])
elif isinstance(content, (Iterable, Iterator)):
self.content_async_iterator = AsyncIteratorWrapper(content)
elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
elif isinstance(content, AsyncIterable):
self.content_async_iterator = content
else:
raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")
Expand Down
49 changes: 30 additions & 19 deletions litestar/response/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import itertools
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union

from anyio import CancelScope, create_task_group
from anyio import Event, create_task_group

from litestar.enums import MediaType
from litestar.response.base import ASGIResponse, Response
Expand All @@ -25,11 +25,17 @@
"Stream",
)

T = TypeVar("T")


class ClientDisconnectError(Exception):
"""Exception raised when the client disconnects."""


class ASGIStreamingResponse(ASGIResponse):
"""A streaming response."""

__slots__ = ("iterator",)
__slots__ = ("disconnect_event", "iterator")

_should_set_content_length = False

Expand Down Expand Up @@ -70,28 +76,24 @@ def __init__(
media_type=media_type,
status_code=status_code,
)
self.iterator: AsyncIterable[str | bytes] | AsyncGenerator[str | bytes, None] = (
iterator if isinstance(iterator, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(iterator)
)

async def _listen_for_disconnect(self, cancel_scope: CancelScope, receive: Receive) -> None:
self.disconnect_event = Event()
self.iterator = iterator if isinstance(iterator, AsyncIterable) else AsyncIteratorWrapper(iterator)

async def _listen_for_disconnect(self, receive: Receive) -> None:
"""Listen for a cancellation message, and if received - call cancel on the cancel scope.

Args:
cancel_scope: A task group cancel scope instance.
receive: The ASGI receive function.

Returns:
None
"""
if not cancel_scope.cancel_called:
message = await receive()
if message["type"] == "http.disconnect":
# despite the IDE warning, this is not a coroutine because anyio 3+ changed this.
# therefore make sure not to await this.
cancel_scope.cancel()
else:
await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive)
while message := await receive():
if message["type"].endswith(".disconnect"):
break

self.disconnect_event.set()

async def _stream(self, send: Send) -> None:
"""Send the chunks from the iterator as a stream of ASGI 'http.response.body' events.
Expand All @@ -103,6 +105,16 @@ async def _stream(self, send: Send) -> None:
None
"""
async for chunk in self.iterator:
if self.disconnect_event.is_set():
try:
if isinstance(self.iterator.content_async_iterator, AsyncGenerator): # type: ignore[attr-defined]
await self.iterator.content_async_iterator.athrow(ClientDisconnectError) # type: ignore[attr-defined]
Comment on lines +110 to +111
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I do not understand why you have to be so over specific here. It should be good enough to simply check:

  • Is it an async-generator? athrow
  • Is it an async-iterable? Just break

Copy link
Copy Markdown
Contributor Author

@winstxnhdw winstxnhdw Oct 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue here is how _ServerSentEventIterator is designed. The content is kept in self.content_async_iterator, which means to close the original generator from the handler, we need to specifically throw self.iterator.content_async_iterator because self.iterator will only contain the header chunks like event_id, event_type, etc.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ServerSentEventIterator should handle that though, not the code that's calling it. It knows about its internal state, and is best suited to decide where to throw what exception. From the outside, it should be treated as any other async iterator / generator.

elif isinstance(self.iterator, AsyncGenerator):
await self.iterator.athrow(ClientDisconnectError)
except (ClientDisconnectError, StopAsyncIteration):
pass
finally:
break # noqa: B012
stream_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding),
Expand All @@ -122,10 +134,9 @@ async def send_body(self, send: Send, receive: Receive) -> None:
Returns:
None
"""

async with create_task_group() as task_group:
task_group.start_soon(partial(self._stream, send))
await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)
task_group.start_soon(partial(self._listen_for_disconnect, receive))
await self._stream(send)


class Stream(Response[StreamType[Union[str, bytes]]]):
Expand Down
66 changes: 55 additions & 11 deletions litestar/utils/sync.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
from __future__ import annotations

from collections.abc import AsyncGenerator, Awaitable, Iterable, Iterator
from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable
from typing import (
TYPE_CHECKING,
Callable,
Generic,
TypeVar,
overload,
)

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeVar

from litestar.concurrency import sync_to_thread
from litestar.utils.predicates import is_async_callable

if TYPE_CHECKING:
from types import TracebackType

__all__ = ("AsyncCallable", "AsyncIteratorWrapper", "ensure_async_callable", "is_async_callable")


P = ParamSpec("P")
T = TypeVar("T")
S = TypeVar("S", default=None)


def _iterable_to_generator(iterable: Iterable[T]) -> Generator[T, S, None]:
"""Convert an iterable to a generator.

Args:
iterable: An iterable.

Returns:
A generator.
"""
yield from iterable


@overload
Expand Down Expand Up @@ -51,18 +66,27 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: # type:
return sync_to_thread(self.func, *args, **kwargs) # type: ignore[arg-type]


class AsyncIteratorWrapper(Generic[T]):
class AsyncIteratorWrapper(AsyncGenerator[T, S]):
"""Asynchronous generator, wrapping an iterable or iterator."""

__slots__ = ("generator", "iterator")
__slots__ = ("_original_generator", "generator", "iterator")

def __init__(self, iterator: Iterator[T] | Iterable[T]) -> None:
"""Take a sync iterator or iterable and yields values from it asynchronously.
def __init__(self, iterator: Iterable[T]) -> None:
"""Take a sync iterable and yields values from it asynchronously.

Args:
iterator: A sync iterator or iterable.
iterator: A sync iterable.
"""
self.iterator = iterator if isinstance(iterator, Iterator) else iter(iterator)
self._original_generator: Generator[T, S, None]

if isinstance(iterator, Generator):
self._original_generator = iterator
elif isinstance(iterator, AsyncIteratorWrapper):
self._original_generator = iterator._original_generator
Comment on lines +84 to +85
Copy link
Copy Markdown
Member

@provinzkraut provinzkraut Oct 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fundamentally changes the functionality of AsyncIteratorWrapper. It's supposed to wrap a synchronous iterable / iterator, and turn it into an async iterable. Now with the added functionality of also wrapping instances of itself (?), it gets really confusing.

else:
self._original_generator = _iterable_to_generator(iterator)

self.iterator = iter(iterator)
self.generator = self._async_generator()

def _call_next(self) -> T:
Expand All @@ -78,8 +102,28 @@ async def _async_generator(self) -> AsyncGenerator[T, None]:
except ValueError:
return

def __aiter__(self) -> AsyncIteratorWrapper[T]:
def __aiter__(self) -> AsyncIteratorWrapper[T, S]:
return self

async def __anext__(self) -> T:
return await self.generator.__anext__()

async def aclose(self) -> None:
await sync_to_thread(self._original_generator.close)

async def asend(self, value: S) -> T:
return await sync_to_thread(self._original_generator.send, value)

async def athrow(
self,
typ: BaseException | type[BaseException],
val: BaseException | object = None,
tb: TracebackType | None = None,
) -> T:
try:
return await sync_to_thread(
self._original_generator.throw,
typ if isinstance(typ, BaseException) else typ(val, tb),
)
except StopIteration as e:
raise StopAsyncIteration from e
36 changes: 36 additions & 0 deletions tests/unit/test_response/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import AsyncIterator

from anyio import Path, open_file, sleep
from msgspec import Struct

from litestar import Litestar, post
from litestar.response import ServerSentEvent
from litestar.response.streaming import ClientDisconnectError


class CleanupRequest(Struct):
file_path: str
file_content: str


@post("/cleanup")
async def get_notified(data: CleanupRequest) -> ServerSentEvent:
async with await open_file(data.file_path, "w") as file:
await file.write(data.file_content)

async def generator() -> AsyncIterator[str]:
try:
for _ in range(10):
yield data.file_content
await sleep(0.1)
except ClientDisconnectError:
await Path(data.file_path).unlink()

return ServerSentEvent(generator())


def create_test_app() -> Litestar:
return Litestar(route_handlers=[get_notified])


app = create_test_app()
34 changes: 33 additions & 1 deletion tests/unit/test_response/test_sse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pathlib
from collections.abc import AsyncIterator, Iterator

import anyio
import httpx
import pytest
from httpx_sse import ServerSentEvent as HTTPXServerSentEvent
from httpx_sse import aconnect_sse
Expand All @@ -9,9 +11,18 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response import ServerSentEvent
from litestar.response.sse import ServerSentEventMessage
from litestar.testing import create_async_test_client
from litestar.testing import create_async_test_client, subprocess_async_client
from litestar.types import SSEData

ROOT = pathlib.Path(__file__).parent
APP = "demo:app"


@pytest.fixture(name="async_client")
async def fx_async_client() -> AsyncIterator[httpx.AsyncClient]:
async with subprocess_async_client(workdir=ROOT, app=APP) as client:
yield client


async def test_sse_steaming_response() -> None:
@get(
Expand Down Expand Up @@ -96,6 +107,27 @@ async def numbers() -> AsyncIterator[SSEData]:
assert events[i].retry == expected_events[i].retry


async def test_sse_cleanup(async_client: httpx.AsyncClient, tmp_path: pathlib.Path) -> None:
file_content = "cleanup"
file_path = tmp_path / "cleanup_file.txt"
data = {
"file_path": str(file_path),
"file_content": file_content,
}

assert not file_path.exists()

async with aconnect_sse(async_client, "POST", "/cleanup", json=data) as event_source:
async for sse in event_source.aiter_sse():
async with await anyio.open_file(file_path, "r") as file:
assert sse.data == file_content
assert await file.read() == file_content
break

await anyio.sleep(3)
assert not file_path.exists()


def test_invalid_content_type_raises() -> None:
with pytest.raises(ImproperlyConfiguredException):
ServerSentEvent(content=object()) # type: ignore[arg-type]
Loading
Loading