diff --git a/docs/release-notes/changelog.rst b/docs/release-notes/changelog.rst index 02fc9335c2..111a3b3df4 100644 --- a/docs/release-notes/changelog.rst +++ b/docs/release-notes/changelog.rst @@ -16,6 +16,15 @@ applications already using Pydantic 2. Applications using Pydantic 1 or a mix between 1 and 2 should upgrade to Pydantic 2. + .. change:: Add ``ping_interval`` to ``ServerSentEvent`` for keepalive pings + :type: feature + :pr: 4623 + :issue: 4082 + + Added optional ``ping_interval`` parameter to :class:`ServerSentEvent ` that + sends SSE comment keepalive pings at the specified interval to prevent connection timeouts from reverse proxies + or clients. + .. change:: Remove logging config and related constructs :type: feature :breaking: diff --git a/docs/usage/responses.rst b/docs/usage/responses.rst index f965216a69..64436f57c1 100644 --- a/docs/usage/responses.rst +++ b/docs/usage/responses.rst @@ -721,6 +721,28 @@ If you want to send a different event type, you can use a dictionary with the ke You can further customize all the sse parameters, add comments, and set the retry duration by using the :class:`ServerSentEvent <.response.ServerSentEvent>` class directly or by using the :class:`ServerSentEventMessage <.response.ServerSentEventMessage>` or dictionaries with the appropriate keys. +To prevent reverse proxies or clients from closing idle SSE connections, use the ``ping_interval`` parameter: + +.. code-block:: python + + @get("/stream") + async def stream_handler() -> ServerSentEvent: + async def generator(): + while True: + data = await get_data() + yield data + + return ServerSentEvent(generator(), ping_interval=15) + +The ``ping_interval`` value is in **seconds**. This sends an SSE comment (``: ping``) every 15 seconds to keep the +connection alive. SSE comments are invisible to ``EventSource`` clients and will not trigger message events. +Pings begin after the first interval elapses (no immediate ping on connect). + +.. tip:: + + Common values are 15–30 seconds, depending on your reverse proxy's idle timeout + (e.g., nginx defaults to 60 seconds, Telegram Mini Apps time out after 60 seconds). + Template Responses ------------------ diff --git a/litestar/response/sse.py b/litestar/response/sse.py index 512044b1eb..a760dbb559 100644 --- a/litestar/response/sse.py +++ b/litestar/response/sse.py @@ -1,19 +1,30 @@ from __future__ import annotations +import itertools import re from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator from dataclasses import dataclass +from functools import partial from io import StringIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +import anyio from litestar.concurrency import sync_to_thread +from litestar.enums import MediaType from litestar.exceptions import ImproperlyConfiguredException -from litestar.response.streaming import Stream +from litestar.response.streaming import ASGIStreamingResponse, Stream from litestar.utils import AsyncIteratorWrapper +from litestar.utils.helpers import get_enum_string_value + +__all__ = ("ASGIStreamingSSEResponse", "ServerSentEvent", "ServerSentEventMessage") if TYPE_CHECKING: from litestar.background_tasks import BackgroundTask, BackgroundTasks - from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType + from litestar.connection import Request + from litestar.datastructures.cookie import Cookie + from litestar.response.base import ASGIResponse + from litestar.types import Receive, ResponseCookies, ResponseHeaders, Send, SSEData, StreamType, TypeEncodersMap _LINE_BREAK_RE = re.compile(r"\r\n|\r|\n") DEFAULT_SEPARATOR = "\r\n" @@ -129,6 +140,55 @@ def encode(self) -> bytes: return buffer.getvalue().encode("utf-8") +class ASGIStreamingSSEResponse(ASGIStreamingResponse): + """ASGI streaming response with optional keepalive ping support for SSE.""" + + __slots__ = ("_ping_interval", "_send_lock") + + def __init__(self, *, ping_interval: float | None = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._ping_interval = ping_interval + self._send_lock = anyio.Lock() if ping_interval is not None else None + + async def _send(self, send: Send, payload: bytes) -> None: + """Send a body chunk with lock for concurrent ping/stream safety.""" + if self._send_lock is None: + raise RuntimeError("_send called without a send lock; ping_interval must be set") + async with self._send_lock: + await send({"type": "http.response.body", "body": payload, "more_body": True}) + + async def _ping(self, send: Send, stop_event: anyio.Event) -> None: + """Send SSE comment keepalive pings at the configured interval.""" + if self._ping_interval is None: + raise RuntimeError("_ping called without a ping interval configured") + while not stop_event.is_set(): + with anyio.move_on_after(self._ping_interval): + await stop_event.wait() + if not stop_event.is_set(): + await self._send(send, b": ping\r\n\r\n") + + async def send_body(self, send: Send, receive: Receive) -> None: + """Emit the response body, with optional keepalive pings.""" + if self._ping_interval is None: + await super().send_body(send, receive) + return + + stop_event = anyio.Event() + + async with anyio.create_task_group() as tg: + tg.start_soon(partial(self._listen_for_disconnect, tg.cancel_scope, receive)) + tg.start_soon(self._ping, send, stop_event) + + async for chunk in self.iterator: + data = chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding) + await self._send(send, data) + + stop_event.set() + tg.cancel_scope.cancel() + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + class ServerSentEvent(Stream): def __init__( self, @@ -143,6 +203,7 @@ def __init__( retry_duration: int | None = None, comment_message: str | None = None, status_code: int | None = None, + ping_interval: float | None = None, ) -> None: """Initialize the response. @@ -161,7 +222,13 @@ def __init__( event_id: The event ID. This sets the event source's 'last event id'. retry_duration: Retry duration in milliseconds. comment_message: A comment message. This value is ignored by clients and is used mostly for pinging. + ping_interval: Interval in seconds between keepalive pings. When set, an SSE comment + (``: ping``) is sent at the specified interval to prevent connection timeouts from + reverse proxies or clients. Defaults to ``None`` (no pings). """ + if ping_interval is not None and ping_interval <= 0: + raise ImproperlyConfiguredException("ping_interval must be a positive number") + self.ping_interval = ping_interval super().__init__( content=_ServerSentEventIterator( content=content, @@ -180,3 +247,63 @@ def __init__( self.headers.setdefault("Cache-Control", "no-cache") self.headers["Connection"] = "keep-alive" self.headers["X-Accel-Buffering"] = "no" + + def to_asgi_response( + self, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + """Create an ASGI streaming response, with optional keepalive ping support. + + When ``ping_interval`` is set, returns an :class:`ASGIStreamingSSEResponse` that + sends periodic SSE comment pings. Otherwise delegates to the parent implementation. + + Args: + request: The :class:`Request <.connection.Request>` instance. + background: Background task(s) to be executed after the response is sent. + cookies: A list of cookies to be set on the response. + headers: Additional headers to be merged with the response headers. Response headers take precedence. + is_head_response: Whether the response is a HEAD response. + media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. + status_code: Status code for the response. If ``status_code`` is already set on the response, this is + ignored. + type_encoders: A dictionary of type encoders to use for encoding the response content. + + Returns: + An ASGIStreamingResponse (or ASGIStreamingSSEResponse when ping_interval is set). + """ + if self.ping_interval is None: + return super().to_asgi_response( + request, + background=background, + cookies=cookies, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=status_code, + type_encoders=type_encoders, + ) + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON) + + return ASGIStreamingSSEResponse( + ping_interval=self.ping_interval, + background=self.background or background, + content_length=0, + cookies=cookies, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + iterator=self.iterator, + media_type=media_type, + status_code=self.status_code or status_code, + ) diff --git a/tests/unit/test_response/test_sse.py b/tests/unit/test_response/test_sse.py index e825b9ec28..77976a5277 100644 --- a/tests/unit/test_response/test_sse.py +++ b/tests/unit/test_response/test_sse.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from collections.abc import AsyncIterator, Iterator +from typing import TYPE_CHECKING import anyio import pytest @@ -8,10 +11,13 @@ from litestar import get from litestar.exceptions import ImproperlyConfiguredException from litestar.response import ServerSentEvent -from litestar.response.sse import ServerSentEventMessage +from litestar.response.sse import ASGIStreamingSSEResponse, ServerSentEventMessage from litestar.testing import create_async_test_client from litestar.types import SSEData +if TYPE_CHECKING: + from litestar.types.asgi_types import HTTPDisconnectEvent, Message + async def test_sse_steaming_response() -> None: @get( @@ -99,3 +105,258 @@ async def numbers() -> AsyncIterator[SSEData]: def test_invalid_content_type_raises() -> None: with pytest.raises(ImproperlyConfiguredException): ServerSentEvent(content=object()) # type: ignore[arg-type] + + +async def test_sse_without_ping_interval_works_unchanged() -> None: + """Regression test: SSE without ping_interval behaves exactly as before.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + for i in range(3): + yield str(i) + + return ServerSentEvent(gen()) + + async with create_async_test_client(handler) as client: + async with aconnect_sse(client, "GET", f"{client.base_url}/test") as event_source: + events = [sse async for sse in event_source.aiter_sse()] + assert len(events) == 3 + for idx, sse in enumerate(events): + assert sse.data == str(idx) + + +async def test_sse_ping_interval_sends_keepalive_comments() -> None: + """SSE with ping_interval sends keepalive comments during idle periods.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + await anyio.sleep(0.35) + yield "done" + + return ServerSentEvent(gen(), ping_interval=0.1) + + async with create_async_test_client(handler) as client: + response = await client.get("/test") + body = response.content.decode() + # Should contain at least one ping comment + assert ": ping\r\n\r\n" in body + # Should also contain the actual data + assert "data: done" in body + + +async def test_sse_ping_uses_comments_not_events() -> None: + """Pings must be SSE comments (start with ':'), not 'event: ping'.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + await anyio.sleep(0.25) + yield "data" + + return ServerSentEvent(gen(), ping_interval=0.1) + + async with create_async_test_client(handler) as client: + response = await client.get("/test") + body = response.content.decode() + assert ": ping\r\n" in body + assert "event: ping" not in body + + +async def test_sse_ping_stops_when_stream_ends() -> None: + """Ping task should stop cleanly when the stream ends without task leaks.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + yield "hello" + yield "world" + + return ServerSentEvent(gen(), ping_interval=0.1) + + async with create_async_test_client(handler) as client: + async with aconnect_sse(client, "GET", f"{client.base_url}/test") as event_source: + events = [sse async for sse in event_source.aiter_sse()] + assert len(events) == 2 + assert events[0].data == "hello" + assert events[1].data == "world" + + +async def test_sse_concurrent_ping_and_data() -> None: + """Rapid data emission with short ping interval should not corrupt the response.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + for i in range(20): + await anyio.sleep(0.02) + yield str(i) + + return ServerSentEvent(gen(), ping_interval=0.05) + + async with create_async_test_client(handler) as client: + response = await client.get("/test") + body = response.content.decode() + # All 20 data events should be present + for i in range(20): + assert f"data: {i}\r\n" in body + + +def test_sse_ping_interval_rejects_zero() -> None: + """ping_interval=0 should raise ImproperlyConfiguredException.""" + + async def gen() -> AsyncIterator[str]: + yield "data" + + with pytest.raises(ImproperlyConfiguredException, match="ping_interval must be a positive number"): + ServerSentEvent(gen(), ping_interval=0) + + +def test_sse_ping_interval_rejects_negative() -> None: + """ping_interval=-1 should raise ImproperlyConfiguredException.""" + + async def gen() -> AsyncIterator[str]: + yield "data" + + with pytest.raises(ImproperlyConfiguredException, match="ping_interval must be a positive number"): + ServerSentEvent(gen(), ping_interval=-1) + + +async def test_sse_ping_with_empty_generator() -> None: + """Ping task shuts down cleanly when generator yields nothing.""" + + async def empty_gen() -> AsyncIterator[str]: + return + yield # make it an async generator + + @get("/test") + async def handler() -> ServerSentEvent: + return ServerSentEvent(empty_gen(), ping_interval=1) + + async with create_async_test_client(handler) as client: + response = await client.get("/test") + assert response.status_code == 200 + + +async def test_sse_large_ping_interval_no_pings_sent() -> None: + """With a very large ping_interval, no pings should be sent for a short-lived stream.""" + + @get("/test") + async def handler() -> ServerSentEvent: + async def gen() -> AsyncIterator[str]: + yield "data1" + yield "data2" + + return ServerSentEvent(gen(), ping_interval=99999) + + async with create_async_test_client(handler) as client: + response = await client.get("/test") + assert response.status_code == 200 + assert b": ping" not in response.content + assert b"data: data1" in response.content + assert b"data: data2" in response.content + + +async def test_sse_ping_with_str_chunks() -> None: + """ASGIStreamingSSEResponse handles str chunks correctly when ping is enabled.""" + + async def str_iterator() -> AsyncIterator[str]: + yield "hello" + yield "world" + + response = ASGIStreamingSSEResponse( + iterator=str_iterator(), + ping_interval=0.1, + media_type="text/event-stream", + status_code=200, + ) + + received: list[bytes] = [] + + async def mock_send(message: Message) -> None: + if message.get("type") == "http.response.body": + body = message.get("body", b"") + received.append(body if isinstance(body, bytes) else b"") + + async def mock_receive() -> HTTPDisconnectEvent: + await anyio.sleep(10) + return {"type": "http.disconnect"} + + await response.send_body(mock_send, mock_receive) + + body = b"".join(received).decode() + assert "hello" in body + assert "world" in body + + +async def test_sse_send_raises_without_ping_interval() -> None: + """_send requires a lock (and therefore ping_interval); RuntimeError if missing.""" + + async def empty() -> AsyncIterator[str]: + return + yield + + response = ASGIStreamingSSEResponse( + iterator=empty(), + media_type="text/event-stream", + status_code=200, + ) + + async def mock_send(message: Message) -> None: + pass + + with pytest.raises(RuntimeError, match="_send called without a send lock"): + await response._send(mock_send, b"data") + + +async def test_sse_ping_raises_without_ping_interval() -> None: + """_ping requires ping_interval; RuntimeError if called without one.""" + + async def empty() -> AsyncIterator[str]: + return + yield + + response = ASGIStreamingSSEResponse( + iterator=empty(), + media_type="text/event-stream", + status_code=200, + ) + + async def mock_send(message: Message) -> None: + pass + + with pytest.raises(RuntimeError, match="_ping called without a ping interval"): + await response._ping(mock_send, anyio.Event()) + + +async def test_sse_send_body_delegates_to_parent_without_ping() -> None: + """Without ping_interval, send_body falls through to the parent (no lock, no task group).""" + + async def gen() -> AsyncIterator[str]: + yield "chunk1" + yield "chunk2" + + response = ASGIStreamingSSEResponse( + iterator=gen(), + media_type="text/event-stream", + status_code=200, + ) + + received: list[bytes] = [] + + async def mock_send(message: Message) -> None: + if message.get("type") == "http.response.body": + body = message.get("body", b"") + received.append(body if isinstance(body, bytes) else b"") + + async def mock_receive() -> HTTPDisconnectEvent: + await anyio.sleep(10) + return {"type": "http.disconnect"} + + await response.send_body(mock_send, mock_receive) + + body = b"".join(received).decode() + assert "chunk1" in body + assert "chunk2" in body + assert ": ping" not in body