Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/release-notes/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
applications already using Pydantic 2. Applications using Pydantic 1 or a mix
between 1 and 2 should upgrade to Pydantic 2.

.. change:: Add ``ping_interval`` to ``ServerSentEvent`` for keepalive pings
:type: feature
:pr: 4623
:issue: 4082

Added optional ``ping_interval`` parameter to :class:`ServerSentEvent <litestar.response.ServerSentEvent>` that
sends SSE comment keepalive pings at the specified interval to prevent connection timeouts from reverse proxies
or clients.

.. change:: Remove logging config and related constructs
:type: feature
:breaking:
Expand Down
22 changes: 22 additions & 0 deletions docs/usage/responses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,28 @@ If you want to send a different event type, you can use a dictionary with the ke

You can further customize all the sse parameters, add comments, and set the retry duration by using the :class:`ServerSentEvent <.response.ServerSentEvent>` class directly or by using the :class:`ServerSentEventMessage <.response.ServerSentEventMessage>` or dictionaries with the appropriate keys.

To prevent reverse proxies or clients from closing idle SSE connections, use the ``ping_interval`` parameter:

.. code-block:: python

@get("/stream")
async def stream_handler() -> ServerSentEvent:
async def generator():
while True:
data = await get_data()
yield data

return ServerSentEvent(generator(), ping_interval=15)

The ``ping_interval`` value is in **seconds**. This sends an SSE comment (``: ping``) every 15 seconds to keep the
connection alive. SSE comments are invisible to ``EventSource`` clients and will not trigger message events.
Pings begin after the first interval elapses (no immediate ping on connect).

.. tip::

Common values are 15–30 seconds, depending on your reverse proxy's idle timeout
(e.g., nginx defaults to 60 seconds, Telegram Mini Apps time out after 60 seconds).


Template Responses
------------------
Expand Down
133 changes: 130 additions & 3 deletions litestar/response/sse.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
from __future__ import annotations

import itertools
import re
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator
from dataclasses import dataclass
from functools import partial
from io import StringIO
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import anyio

from litestar.concurrency import sync_to_thread
from litestar.enums import MediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import Stream
from litestar.response.streaming import ASGIStreamingResponse, Stream
from litestar.utils import AsyncIteratorWrapper
from litestar.utils.helpers import get_enum_string_value

__all__ = ("ASGIStreamingSSEResponse", "ServerSentEvent", "ServerSentEventMessage")

if TYPE_CHECKING:
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType
from litestar.connection import Request
from litestar.datastructures.cookie import Cookie
from litestar.response.base import ASGIResponse
from litestar.types import Receive, ResponseCookies, ResponseHeaders, Send, SSEData, StreamType, TypeEncodersMap

_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n")
DEFAULT_SEPARATOR = "\r\n"
Expand Down Expand Up @@ -129,6 +140,55 @@ def encode(self) -> bytes:
return buffer.getvalue().encode("utf-8")


class ASGIStreamingSSEResponse(ASGIStreamingResponse):
"""ASGI streaming response with optional keepalive ping support for SSE."""

__slots__ = ("_ping_interval", "_send_lock")

def __init__(self, *, ping_interval: float | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._ping_interval = ping_interval
self._send_lock = anyio.Lock() if ping_interval is not None else None

async def _send(self, send: Send, payload: bytes) -> None:
"""Send a body chunk with lock for concurrent ping/stream safety."""
if self._send_lock is None:
raise RuntimeError("_send called without a send lock; ping_interval must be set")
async with self._send_lock:
await send({"type": "http.response.body", "body": payload, "more_body": True})

async def _ping(self, send: Send, stop_event: anyio.Event) -> None:
"""Send SSE comment keepalive pings at the configured interval."""
if self._ping_interval is None:
raise RuntimeError("_ping called without a ping interval configured")
while not stop_event.is_set():
with anyio.move_on_after(self._ping_interval):
await stop_event.wait()
if not stop_event.is_set():
await self._send(send, b": ping\r\n\r\n")

async def send_body(self, send: Send, receive: Receive) -> None:
"""Emit the response body, with optional keepalive pings."""
if self._ping_interval is None:
await super().send_body(send, receive)
return

stop_event = anyio.Event()

async with anyio.create_task_group() as tg:
tg.start_soon(partial(self._listen_for_disconnect, tg.cancel_scope, receive))
tg.start_soon(self._ping, send, stop_event)

async for chunk in self.iterator:
data = chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding)
await self._send(send, data)

stop_event.set()
tg.cancel_scope.cancel()

await send({"type": "http.response.body", "body": b"", "more_body": False})


class ServerSentEvent(Stream):
def __init__(
self,
Expand All @@ -143,6 +203,7 @@ def __init__(
retry_duration: int | None = None,
comment_message: str | None = None,
status_code: int | None = None,
ping_interval: float | None = None,
) -> None:
"""Initialize the response.

Expand All @@ -161,7 +222,13 @@ def __init__(
event_id: The event ID. This sets the event source's 'last event id'.
retry_duration: Retry duration in milliseconds.
comment_message: A comment message. This value is ignored by clients and is used mostly for pinging.
ping_interval: Interval in seconds between keepalive pings. When set, an SSE comment
(``: ping``) is sent at the specified interval to prevent connection timeouts from
reverse proxies or clients. Defaults to ``None`` (no pings).
"""
if ping_interval is not None and ping_interval <= 0:
raise ImproperlyConfiguredException("ping_interval must be a positive number")
self.ping_interval = ping_interval
super().__init__(
content=_ServerSentEventIterator(
content=content,
Expand All @@ -180,3 +247,63 @@ def __init__(
self.headers.setdefault("Cache-Control", "no-cache")
self.headers["Connection"] = "keep-alive"
self.headers["X-Accel-Buffering"] = "no"

def to_asgi_response(
self,
request: Request,
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: Iterable[Cookie] | None = None,
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
status_code: int | None = None,
type_encoders: TypeEncodersMap | None = None,
) -> ASGIResponse:
"""Create an ASGI streaming response, with optional keepalive ping support.

When ``ping_interval`` is set, returns an :class:`ASGIStreamingSSEResponse` that
sends periodic SSE comment pings. Otherwise delegates to the parent implementation.

Args:
request: The :class:`Request <.connection.Request>` instance.
background: Background task(s) to be executed after the response is sent.
cookies: A list of cookies to be set on the response.
headers: Additional headers to be merged with the response headers. Response headers take precedence.
is_head_response: Whether the response is a HEAD response.
media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored.
status_code: Status code for the response. If ``status_code`` is already set on the response, this is
ignored.
type_encoders: A dictionary of type encoders to use for encoding the response content.

Returns:
An ASGIStreamingResponse (or ASGIStreamingSSEResponse when ping_interval is set).
"""
if self.ping_interval is None:
return super().to_asgi_response(
request,
background=background,
cookies=cookies,
headers=headers,
is_head_response=is_head_response,
media_type=media_type,
status_code=status_code,
type_encoders=type_encoders,
)

headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)
media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON)

return ASGIStreamingSSEResponse(
ping_interval=self.ping_interval,
background=self.background or background,
content_length=0,
cookies=cookies,
encoding=self.encoding,
headers=headers,
is_head_response=is_head_response,
iterator=self.iterator,
media_type=media_type,
status_code=self.status_code or status_code,
)
Loading
Loading