diff --git a/litestar/middleware/rate_limit.py b/litestar/middleware/rate_limit.py
index 0ab528a54e..cb5cc49f73 100644
--- a/litestar/middleware/rate_limit.py
+++ b/litestar/middleware/rate_limit.py
@@ -70,8 +70,7 @@ def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
         )
         self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler)
         self.config = config
-        self.max_requests: int = config.rate_limit[1]
-        self.unit: DurationUnit = config.rate_limit[0]
+        self.rate_limits: list[tuple[DurationUnit, int]] = config._all_rate_limits
         self.get_identifier_for_request = config.identifier_for_request
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -95,29 +94,60 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             if getattr(route_handler, "is_mount", False):
                 key += "::mount"
 
-            cache_object = await self.retrieve_cached_history(key, store)
-            if len(cache_object.history) >= self.max_requests:
-                raise TooManyRequestsException(
-                    headers=self.create_response_headers(cache_object=cache_object)
-                    if self.config.set_rate_limit_headers
-                    else None
-                )
-            await self.set_cached_history(key=key, cache_object=cache_object, store=store)
+            # Check every rate limit condition before updating any cache entry so that a
+            # request that violates a later condition does not consume quota from earlier ones.
+            checked: list[tuple[DurationUnit, int, CacheObject, str]] = []
+            for unit, max_requests in self.rate_limits:
+                limit_key = f"{key}::{unit}"
+                cache_object = await self.retrieve_cached_history(limit_key, store, unit=unit)
+                if len(cache_object.history) >= max_requests:
+                    raise TooManyRequestsException(
+                        headers=self.create_response_headers(
+                            cache_object=cache_object,
+                            max_requests=max_requests,
+                            unit=unit,
+                        )
+                        if self.config.set_rate_limit_headers
+                        else None
+                    )
+                checked.append((unit, max_requests, cache_object, limit_key))
+
+            # All limits passed — persist updated histories
+            for unit, _, cache_object, limit_key in checked:
+                await self.set_cached_history(key=limit_key, cache_object=cache_object, unit=unit, store=store)
+
             if self.config.set_rate_limit_headers:
-                send = self.create_send_wrapper(send=send, cache_object=cache_object)
+                # Use the most restrictive limit (fewest remaining requests) for response headers
+                most_restrictive = min(checked, key=lambda x: x[1] - len(x[2].history))
+                r_unit, r_max, r_cache, _ = most_restrictive
+                send = self.create_send_wrapper(send=send, cache_object=r_cache, max_requests=r_max, unit=r_unit)
 
         await self.app(scope, receive, send)  # pyright: ignore
 
-    def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send:
+    def create_send_wrapper(
+        self,
+        send: Send,
+        cache_object: CacheObject,
+        max_requests: int | None = None,
+        unit: DurationUnit | None = None,
+    ) -> Send:
         """Create a ``send`` function that wraps the original send to inject response headers.
 
         Args:
             send: The ASGI send function.
             cache_object: A StorageObject instance.
+            max_requests: Maximum number of requests for the selected rate limit window.
+                Defaults to the first configured rate limit's max.
+            unit: The duration unit for the selected rate limit window.
+                Defaults to the first configured rate limit's unit.
 
         Returns:
             Send wrapper callable.
         """
+        if max_requests is None:
+            max_requests = self.rate_limits[0][1]
+        if unit is None:
+            unit = self.rate_limits[0][0]
 
         async def send_wrapper(message: Message) -> None:
             """Wrap the ASGI ``Send`` callable.
@@ -131,23 +161,29 @@ async def send_wrapper(message: Message) -> None:
             if message["type"] == "http.response.start":
                 message.setdefault("headers", [])
                 headers = MutableScopeHeaders(message)
-                for key, value in self.create_response_headers(cache_object=cache_object).items():
+                for key, value in self.create_response_headers(
+                    cache_object=cache_object, max_requests=max_requests, unit=unit
+                ).items():
                     headers[key] = value
             await send(message)
 
         return send_wrapper
 
-    async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
+    async def retrieve_cached_history(self, key: str, store: Store, unit: DurationUnit | None = None) -> CacheObject:
         """Retrieve a list of time stamps for the given duration unit.
 
         Args:
             key: Cache key.
             store: A :class:`Store <.stores.base.Store>`
+            unit: The :data:`DurationUnit` for this rate limit window.
+                Defaults to the first configured rate limit's unit.
 
         Returns:
             An :class:`CacheObject`.
         """
-        duration = DURATION_VALUES[self.unit]
+        if unit is None:
+            unit = self.rate_limits[0][0]
+        duration = DURATION_VALUES[unit]
         now = int(time())
         cached_string = await store.get(key)
         if cached_string:
@@ -161,19 +197,29 @@ async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
 
         return CacheObject(history=[], reset=now + duration)
 
-    async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None:
+    async def set_cached_history(
+        self,
+        key: str,
+        cache_object: CacheObject,
+        store: Store,
+        unit: DurationUnit | None = None,
+    ) -> None:
         """Store history extended with the current timestamp in cache.
 
         Args:
             key: Cache key.
             cache_object: A :class:`CacheObject`.
             store: A :class:`Store <.stores.base.Store>`
+            unit: The :data:`DurationUnit` for this rate limit window.
+                Defaults to the first configured rate limit's unit.
 
         Returns:
             None
         """
+        if unit is None:
+            unit = self.rate_limits[0][0]
         cache_object.history = [int(time()), *cache_object.history]
-        await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit])
+        await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[unit])
 
     async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
         """Return a boolean indicating if a request should be checked for rate limiting.
@@ -188,25 +234,37 @@ async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
             return await self.check_throttle_handler(request)
         return True
 
-    def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
+    def create_response_headers(
+        self,
+        cache_object: CacheObject,
+        max_requests: int | None = None,
+        unit: DurationUnit | None = None,
+    ) -> dict[str, str]:
         """Create ratelimit response headers.
 
         Notes:
             * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>`_
 
         Args:
-            cache_object:A :class:`CacheObject`.
+            cache_object: A :class:`CacheObject`.
+            max_requests: Maximum number of requests for the chosen rate limit window.
+                Defaults to the first configured rate limit's max.
+            unit: The :data:`DurationUnit` for the chosen rate limit window.
+                Defaults to the first configured rate limit's unit.
 
         Returns:
             A dict of http headers.
         """
-        remaining_requests = str(
-            self.max_requests - len(cache_object.history) if len(cache_object.history) <= self.max_requests else 0
-        )
+        if max_requests is None:
+            max_requests = self.rate_limits[0][1]
+        if unit is None:
+            unit = self.rate_limits[0][0]
+
+        remaining_requests = str(max(max_requests - len(cache_object.history), 0))
 
         return {
-            self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}",
-            self.config.rate_limit_limit_header_key: str(self.max_requests),
+            self.config.rate_limit_policy_header_key: f"{max_requests}; w={DURATION_VALUES[unit]}",
+            self.config.rate_limit_limit_header_key: str(max_requests),
             self.config.rate_limit_remaining_header_key: remaining_requests,
             self.config.rate_limit_reset_header_key: str(cache_object.reset - int(time())),
         }
@@ -216,8 +274,25 @@ def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
 class RateLimitConfig:
     """Configuration for ``RateLimitMiddleware``"""
 
-    rate_limit: tuple[DurationUnit, int]
-    """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5)."""
+    rate_limit: tuple[DurationUnit, int] | None = field(default=None)
+    """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ``("day", 1)`` or
+    ``("minute", 5)``.
+
+    Use :attr:`rate_limits` to specify multiple simultaneous rate limit conditions. When both ``rate_limit``
+    and ``rate_limits`` are ``None`` a :exc:`ValueError` is raised at initialisation time.
+    """
+    rate_limits: list[tuple[DurationUnit, int]] | None = field(default=None)
+    """A list of ``(unit, max_requests)`` tuples that are ALL enforced simultaneously.
+
+    A ``429 Too Many Requests`` response is returned as soon as *any* condition is breached. This lets you
+    combine multiple time windows, for example::
+
+        RateLimitConfig(rate_limits=[("second", 10), ("minute", 100), ("hour", 2000)])
+
+    When ``rate_limit`` is also provided a :exc:`ValueError` is raised. When only ``rate_limit`` is
+    provided it is normalised to a single-element list internally, so all existing code continues to work
+    without modification.
+    """
     exclude: str | list[str] | None = field(default=None)
     """A pattern or list of patterns to skip in the rate limiting middleware."""
     exclude_opt_key: str | None = field(default=None)
@@ -254,9 +329,31 @@ class RateLimitConfig:
     """Name of the :class:`Store <.stores.base.Store>` to use"""
 
     def __post_init__(self) -> None:
+        if self.rate_limit is None and not self.rate_limits:
+            raise ValueError("Either 'rate_limit' or 'rate_limits' must be provided to RateLimitConfig.")
+        if self.rate_limit is not None and self.rate_limits is not None:
+            raise ValueError("Provide either 'rate_limit' or 'rate_limits' to RateLimitConfig, not both.")
         if self.check_throttle_handler:
             self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler)  # type: ignore[arg-type]
 
+    @property
+    def _all_rate_limits(self) -> list[tuple[DurationUnit, int]]:
+        """Return a normalised list of ``(unit, max_requests)`` tuples.
+
+        Always use this property rather than accessing :attr:`rate_limit` or :attr:`rate_limits`
+        directly so that the single-limit backward-compatible form is handled transparently.
+
+        Raises:
+            ValueError: If neither :attr:`rate_limit` nor :attr:`rate_limits` is set.
+        """
+        if self.rate_limits is not None:
+            return self.rate_limits
+        # Explicit check instead of ``assert``: asserts are stripped under ``python -O`` and the
+        # fields of this dataclass can be reassigned after ``__post_init__`` has run.
+        if self.rate_limit is None:
+            raise ValueError("Either 'rate_limit' or 'rate_limits' must be provided to RateLimitConfig.")
+        return [self.rate_limit]
+
     @property
     def middleware(self) -> DefineMiddleware:
         """Use this property to insert the config into a middleware list on one of the application layers.
@@ -277,6 +374,17 @@ def my_handler(request: Request) -> None: ...
 
                 app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware])
 
+            Multiple simultaneous conditions are also supported:
+
+            .. code-block:: python
+
+                from litestar.middleware.rate_limit import RateLimitConfig
+
+                # max 5/second AND max 100/minute AND max 1000/hour — all enforced at once
+                throttle_config = RateLimitConfig(
+                    rate_limits=[("second", 5), ("minute", 100), ("hour", 1000)]
+                )
+
         Returns:
             An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the config
             kwarg value.
diff --git a/tests/unit/test_middleware/test_rate_limit_middleware.py b/tests/unit/test_middleware/test_rate_limit_middleware.py
index b73899ccfc..754f997b8f 100644
--- a/tests/unit/test_middleware/test_rate_limit_middleware.py
+++ b/tests/unit/test_middleware/test_rate_limit_middleware.py
@@ -303,3 +303,90 @@ def get_id_from_random_header(request: Request[Any, Any, Any]) -> str:
 
         response = client.get("/", headers={"x-private-header": "value"})
         assert response.status_code == HTTP_429_TOO_MANY_REQUESTS
+
+
+# ---------------------------------------------------------------------------
+# Tests for multi-condition rate limiting (rate_limits=)
+# ---------------------------------------------------------------------------
+
+
+def test_rate_limit_config_requires_at_least_one_limit() -> None:
+    """RateLimitConfig should raise when neither rate_limit nor rate_limits is given."""
+    import pytest
+
+    with pytest.raises(ValueError, match="Either 'rate_limit' or 'rate_limits'"):
+        RateLimitConfig()
+
+
+def test_rate_limit_config_rejects_both_fields() -> None:
+    """RateLimitConfig should raise when both rate_limit and rate_limits are given."""
+    import pytest
+
+    with pytest.raises(ValueError, match="not both"):
+        RateLimitConfig(rate_limit=("second", 5), rate_limits=[("minute", 100)])
+
+
+@travel(datetime.utcnow, tick=False)
+def test_multiple_rate_limits_passes_when_all_satisfied() -> None:
+    """Requests within all limits should succeed."""
+
+    @get("/")
+    def handler() -> None:
+        return None
+
+    config = RateLimitConfig(rate_limits=[("second", 3), ("minute", 5)])
+
+    with create_test_client(route_handlers=[handler], middleware=[config.middleware]) as client:
+        # First 3 requests are within both windows — all must succeed
+        for _ in range(3):
+            assert client.get("/").status_code == HTTP_200_OK
+
+
+@travel(datetime.utcnow, tick=False)
+def test_multiple_rate_limits_blocked_by_tighter_window() -> None:
+    """The per-second limit should trigger even though the per-minute limit is not yet reached."""
+
+    @get("/")
+    def handler() -> None:
+        return None
+
+    # Allow 2/second but 100/minute — the second window is the bottleneck
+    config = RateLimitConfig(rate_limits=[("second", 2), ("minute", 100)])
+
+    with create_test_client(route_handlers=[handler], middleware=[config.middleware]) as client:
+        assert client.get("/").status_code == HTTP_200_OK
+        assert client.get("/").status_code == HTTP_200_OK
+        # Third request in the same second exceeds the per-second limit
+        assert client.get("/").status_code == HTTP_429_TOO_MANY_REQUESTS
+
+
+@travel(datetime.utcnow, tick=False)
+def test_multiple_rate_limits_blocked_by_wider_window() -> None:
+    """The per-minute limit should trigger once it is exhausted, even though the per-second limit still has quota."""
+
+    @get("/")
+    def handler() -> None:
+        return None
+
+    # 5/second but only 3/minute — the minute window will be exhausted first
+    config = RateLimitConfig(rate_limits=[("second", 5), ("minute", 3)])
+
+    with create_test_client(route_handlers=[handler], middleware=[config.middleware]) as client:
+        assert client.get("/").status_code == HTTP_200_OK
+        assert client.get("/").status_code == HTTP_200_OK
+        assert client.get("/").status_code == HTTP_200_OK
+        # 4th request is still within the second window but exhausts the minute quota
+        assert client.get("/").status_code == HTTP_429_TOO_MANY_REQUESTS
+
+
+def test_rate_limits_all_rate_limits_property_single() -> None:
+    """_all_rate_limits returns a one-element list when only rate_limit is given."""
+    config = RateLimitConfig(rate_limit=("hour", 50))
+    assert config._all_rate_limits == [("hour", 50)]
+
+
+def test_rate_limits_all_rate_limits_property_multi() -> None:
+    """_all_rate_limits returns the full list when rate_limits is given."""
+    limits = [("second", 10), ("minute", 200), ("hour", 5000)]
+    config = RateLimitConfig(rate_limits=limits)
+    assert config._all_rate_limits == limits