Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 118 additions & 24 deletions litestar/middleware/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
)
self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler)
self.config = config
self.max_requests: int = config.rate_limit[1]
self.unit: DurationUnit = config.rate_limit[0]
self.rate_limits: list[tuple[DurationUnit, int]] = config._all_rate_limits
self.get_identifier_for_request = config.identifier_for_request

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand All @@ -95,29 +94,60 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if getattr(route_handler, "is_mount", False):
key += "::mount"

cache_object = await self.retrieve_cached_history(key, store)
if len(cache_object.history) >= self.max_requests:
raise TooManyRequestsException(
headers=self.create_response_headers(cache_object=cache_object)
if self.config.set_rate_limit_headers
else None
)
await self.set_cached_history(key=key, cache_object=cache_object, store=store)
# Check every rate limit condition before updating any cache entry so that a
# request that violates a later condition does not consume quota from earlier ones.
checked: list[tuple[DurationUnit, int, CacheObject, str]] = []
for unit, max_requests in self.rate_limits:
limit_key = f"{key}::{unit}"
cache_object = await self.retrieve_cached_history(limit_key, unit, store)
if len(cache_object.history) >= max_requests:
raise TooManyRequestsException(
headers=self.create_response_headers(
cache_object=cache_object,
max_requests=max_requests,
unit=unit,
)
if self.config.set_rate_limit_headers
else None
)
checked.append((unit, max_requests, cache_object, limit_key))

# All limits passed — persist updated histories
for unit, max_requests, cache_object, limit_key in checked:
await self.set_cached_history(key=limit_key, cache_object=cache_object, unit=unit, store=store)

if self.config.set_rate_limit_headers:
send = self.create_send_wrapper(send=send, cache_object=cache_object)
# Use the most restrictive limit (fewest remaining requests) for response headers
most_restrictive = min(checked, key=lambda x: x[1] - len(x[2].history))
r_unit, r_max, r_cache, _ = most_restrictive
send = self.create_send_wrapper(send=send, cache_object=r_cache, max_requests=r_max, unit=r_unit)

await self.app(scope, receive, send) # pyright: ignore

def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send:
def create_send_wrapper(
self,
send: Send,
cache_object: CacheObject,
max_requests: int | None = None,
unit: DurationUnit | None = None,
) -> Send:
"""Create a ``send`` function that wraps the original send to inject response headers.

Args:
send: The ASGI send function.
cache_object: A StorageObject instance.
max_requests: Maximum number of requests for the selected rate limit window.
Defaults to the first configured rate limit's max.
unit: The duration unit for the selected rate limit window.
Defaults to the first configured rate limit's unit.

Returns:
Send wrapper callable.
"""
if max_requests is None:
max_requests = self.rate_limits[0][1]
if unit is None:
unit = self.rate_limits[0][0]

async def send_wrapper(message: Message) -> None:
"""Wrap the ASGI ``Send`` callable.
Expand All @@ -131,23 +161,26 @@ async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
message.setdefault("headers", [])
headers = MutableScopeHeaders(message)
for key, value in self.create_response_headers(cache_object=cache_object).items():
for key, value in self.create_response_headers(
cache_object=cache_object, max_requests=max_requests, unit=unit
).items():
headers[key] = value
await send(message)

return send_wrapper

async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
async def retrieve_cached_history(self, key: str, unit: DurationUnit, store: Store) -> CacheObject:
"""Retrieve a list of time stamps for the given duration unit.

Args:
key: Cache key.
unit: The :data:`DurationUnit` for this rate limit window.
store: A :class:`Store <.stores.base.Store>`

Returns:
An :class:`CacheObject`.
"""
duration = DURATION_VALUES[self.unit]
duration = DURATION_VALUES[unit]
now = int(time())
cached_string = await store.get(key)
if cached_string:
Expand All @@ -161,19 +194,20 @@ async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:

return CacheObject(history=[], reset=now + duration)

async def set_cached_history(self, key: str, cache_object: CacheObject, unit: DurationUnit, store: Store) -> None:
    """Store history extended with the current timestamp in cache.

    Args:
        key: Cache key.
        cache_object: A :class:`CacheObject`.
        unit: The :data:`DurationUnit` for this rate limit window.
        store: A :class:`Store <.stores.base.Store>`

    Returns:
        None
    """
    # Prepend the current timestamp so history stays newest-first; the cache
    # entry expires after one full window so stale keys clean themselves up.
    cache_object.history = [int(time()), *cache_object.history]
    await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[unit])

async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
"""Return a boolean indicating if a request should be checked for rate limiting.
Expand All @@ -188,25 +222,39 @@ async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
return await self.check_throttle_handler(request)
return True

def create_response_headers(
    self,
    cache_object: CacheObject,
    max_requests: int | None = None,
    unit: DurationUnit | None = None,
) -> dict[str, str]:
    """Create ratelimit response headers.

    Notes:
        * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>`_

    Args:
        cache_object: A :class:`CacheObject`.
        max_requests: Maximum number of requests for the chosen rate limit window.
            Defaults to the first configured rate limit's max.
        unit: The :data:`DurationUnit` for the chosen rate limit window.
            Defaults to the first configured rate limit's unit.

    Returns:
        A dict of http headers.
    """
    # Fall back to the first configured limit so pre-multi-limit callers that
    # pass only ``cache_object`` keep working unchanged.
    if max_requests is None:
        max_requests = self.rate_limits[0][1]
    if unit is None:
        unit = self.rate_limits[0][0]

    # Clamp at zero: history may exceed the quota when a 429 was raised.
    remaining_requests = str(
        max_requests - len(cache_object.history) if len(cache_object.history) <= max_requests else 0
    )

    return {
        self.config.rate_limit_policy_header_key: f"{max_requests}; w={DURATION_VALUES[unit]}",
        self.config.rate_limit_limit_header_key: str(max_requests),
        self.config.rate_limit_remaining_header_key: remaining_requests,
        self.config.rate_limit_reset_header_key: str(cache_object.reset - int(time())),
    }
Expand All @@ -216,8 +264,25 @@ def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
class RateLimitConfig:
"""Configuration for ``RateLimitMiddleware``"""

rate_limit: tuple[DurationUnit, int]
"""A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5)."""
rate_limit: tuple[DurationUnit, int] | None = field(default=None)
"""A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ``("day", 1)`` or
``("minute", 5)``.

Use :attr:`rate_limits` to specify multiple simultaneous rate limit conditions. When both ``rate_limit``
and ``rate_limits`` are ``None`` a :exc:`ValueError` is raised at initialisation time.
"""
rate_limits: list[tuple[DurationUnit, int]] | None = field(default=None)
"""A list of ``(unit, max_requests)`` tuples that are ALL enforced simultaneously.

A ``429 Too Many Requests`` response is returned as soon as *any* condition is breached. This lets you
combine multiple time windows, for example::

RateLimitConfig(rate_limits=[("second", 10), ("minute", 100), ("hour", 2000)])

When ``rate_limit`` is also provided a :exc:`ValueError` is raised. When only ``rate_limit`` is
provided it is normalised to a single-element list internally, so all existing code continues to work
without modification.
"""
exclude: str | list[str] | None = field(default=None)
"""A pattern or list of patterns to skip in the rate limiting middleware."""
exclude_opt_key: str | None = field(default=None)
Expand Down Expand Up @@ -254,9 +319,27 @@ class RateLimitConfig:
"""Name of the :class:`Store <.stores.base.Store>` to use"""

def __post_init__(self) -> None:
    """Validate that exactly one limit field is configured and normalise the handler.

    Raises:
        ValueError: If neither, or both, of ``rate_limit`` and ``rate_limits`` are provided.
    """
    no_limit_configured = self.rate_limit is None and not self.rate_limits
    if no_limit_configured:
        raise ValueError("Either 'rate_limit' or 'rate_limits' must be provided to RateLimitConfig.")
    both_fields_given = self.rate_limit is not None and self.rate_limits is not None
    if both_fields_given:
        raise ValueError(
            "Provide either 'rate_limit' or 'rate_limits' to RateLimitConfig, not both."
        )
    if self.check_throttle_handler:
        # Wrap a sync handler so the middleware can always await it.
        self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler)  # type: ignore[arg-type]

@property
def _all_rate_limits(self) -> list[tuple[DurationUnit, int]]:
    """Return every configured limit as a list of ``(unit, max_requests)`` tuples.

    Always go through this property instead of reading :attr:`rate_limit` or
    :attr:`rate_limits` directly, so the backward-compatible single-limit form
    is normalised transparently.
    """
    if self.rate_limits is None:
        # __post_init__ guarantees rate_limit is set when rate_limits is not.
        assert self.rate_limit is not None
        return [self.rate_limit]
    return self.rate_limits

@property
def middleware(self) -> DefineMiddleware:
"""Use this property to insert the config into a middleware list on one of the application layers.
Expand All @@ -277,6 +360,17 @@ def my_handler(request: Request) -> None: ...

app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware])

Multiple simultaneous conditions are also supported:

.. code-block:: python

from litestar.middleware.rate_limit import RateLimitConfig

# max 5/second AND max 100/minute AND max 1000/hour — all enforced at once
throttle_config = RateLimitConfig(
rate_limits=[("second", 5), ("minute", 100), ("hour", 1000)]
)

Returns:
An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the
config kwarg value.
Expand Down
91 changes: 89 additions & 2 deletions tests/unit/test_middleware/test_rate_limit_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def handler() -> None:
return None

app = Litestar(
[handler], middleware=[RateLimitConfig(("second", 10)).middleware], stores={"rate_limit": memory_store}
[handler], middleware=[RateLimitConfig(rate_limit=("second", 10)).middleware], stores={"rate_limit": memory_store}
)

with TestClient(app) as client:
Expand All @@ -98,7 +98,7 @@ def handler() -> None:

app = Litestar(
[handler],
middleware=[RateLimitConfig(("second", 10), store="some_store").middleware],
middleware=[RateLimitConfig(rate_limit=("second", 10), store="some_store").middleware],
stores={"some_store": memory_store},
)

Expand Down Expand Up @@ -303,3 +303,90 @@ def get_id_from_random_header(request: Request[Any, Any, Any]) -> str:

response = client.get("/", headers={"x-private-header": "value"})
assert response.status_code == HTTP_429_TOO_MANY_REQUESTS


# ---------------------------------------------------------------------------
# Tests for multi-condition rate limiting (rate_limits=)
# ---------------------------------------------------------------------------


def test_rate_limit_config_requires_at_least_one_limit() -> None:
    """Constructing RateLimitConfig with no limit at all must fail fast."""
    import pytest

    expected_message = "Either 'rate_limit' or 'rate_limits'"
    with pytest.raises(ValueError, match=expected_message):
        RateLimitConfig()


def test_rate_limit_config_rejects_both_fields() -> None:
    """Supplying both the single-limit and the multi-limit field must fail fast."""
    import pytest

    single_limit = ("second", 5)
    multi_limits = [("minute", 100)]
    with pytest.raises(ValueError, match="not both"):
        RateLimitConfig(rate_limit=single_limit, rate_limits=multi_limits)


@travel(datetime.utcnow, tick=False)
def test_multiple_rate_limits_passes_when_all_satisfied() -> None:
    """Requests that stay inside every configured window must all succeed."""

    @get("/")
    def handler() -> None:
        return None

    throttle_config = RateLimitConfig(rate_limits=[("second", 3), ("minute", 5)])

    with create_test_client(route_handlers=[handler], middleware=[throttle_config.middleware]) as client:
        # Three requests fit inside both the per-second and the per-minute quota.
        statuses = [client.get("/").status_code for _ in range(3)]
        assert statuses == [HTTP_200_OK] * 3


@travel(datetime.utcnow, tick=False)
def test_multiple_rate_limits_blocked_by_tighter_window() -> None:
    """A narrow per-second window must block before the generous per-minute window."""

    @get("/")
    def handler() -> None:
        return None

    # 2/second is the bottleneck; 100/minute leaves plenty of headroom.
    throttle_config = RateLimitConfig(rate_limits=[("second", 2), ("minute", 100)])

    with create_test_client(route_handlers=[handler], middleware=[throttle_config.middleware]) as client:
        expected = [HTTP_200_OK, HTTP_200_OK, HTTP_429_TOO_MANY_REQUESTS]
        observed = [client.get("/").status_code for _ in expected]
        assert observed == expected


@travel(datetime.utcnow, tick=False)
def test_multiple_rate_limits_blocked_by_wider_window() -> None:
    """The per-minute limit should trigger once exhausted, even with per-second quota left."""

    @get("/")
    def handler() -> None:
        return None

    # 5/second but only 3/minute — the wider minute window is exhausted first.
    throttle_config = RateLimitConfig(rate_limits=[("second", 5), ("minute", 3)])

    with create_test_client(route_handlers=[handler], middleware=[throttle_config.middleware]) as client:
        expected = [HTTP_200_OK, HTTP_200_OK, HTTP_200_OK, HTTP_429_TOO_MANY_REQUESTS]
        observed = [client.get("/").status_code for _ in expected]
        assert observed == expected


def test_rate_limits_all_rate_limits_property_single() -> None:
    """A lone ``rate_limit`` is normalised into a one-element list."""
    config = RateLimitConfig(rate_limit=("hour", 50))
    assert [("hour", 50)] == config._all_rate_limits


def test_rate_limits_all_rate_limits_property_multi() -> None:
    """An explicit ``rate_limits`` list is returned unchanged."""
    expected = [("second", 10), ("minute", 200), ("hour", 5000)]
    config = RateLimitConfig(rate_limits=expected)
    assert config._all_rate_limits == expected
Loading