Skip to content

Commit 1a646b5

Browse files
committed
Merge commit from fork
1 parent 08297fd commit 1a646b5

3 files changed

Lines changed: 102 additions & 19 deletions

File tree

docs/usage/middleware/builtin-middleware.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,25 @@ The only required configuration kwarg is ``rate_limit``, which expects a tuple c
269269
``"minute"``, ``"hour"``, ``"day"``\ ) and a value for the request quota (integer).
270270

271271

272+
Using behind a proxy
273+
^^^^^^^^^^^^^^^^^^^^
274+
275+
The default mode for uniquely identifiying client uses the client's address. When an
276+
application is running behind a proxy, that address will be the proxy's, not the "real"
277+
address of the end-user.
278+
279+
While there are special headers set by proxies to retrieve the remote client's actual
280+
address (``X-FORWARDED-FOR``), their values should not implicitly be trusted, as any
281+
client is free to set them to whatever value they want. A rate-limit could easily be
282+
circumvented by spoofing these, and simply attaching a new, random address to each
283+
request.
284+
285+
The best way to handle applications running behind a proxy is to use a middleware that
286+
updates the client's address in a secure way, such as uvicorn's
287+
`ProxyHeaderMiddleware <https://github.com/encode/uvicorn/blob/master/uvicorn/middleware/proxy_headers.py>`_
288+
or hypercon's `ProxyFixMiddleware <https://hypercorn.readthedocs.io/en/latest/how_to_guides/proxy_fix.html>`_ .
289+
290+
272291
Logging Middleware
273292
------------------
274293

litestar/middleware/rate_limit.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from litestar.serialization import decode_json, encode_json
1212
from litestar.utils import ensure_async_callable
1313

14-
__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware")
14+
__all__ = (
15+
"CacheObject",
16+
"RateLimitConfig",
17+
"RateLimitMiddleware",
18+
"get_remote_address",
19+
)
1520

1621

1722
if TYPE_CHECKING:
@@ -38,6 +43,18 @@ class CacheObject:
3843
reset: int
3944

4045

46+
def get_remote_address(request: Request[Any, Any, Any]) -> str:
47+
"""Get a client's remote address from a ``Request``
48+
49+
Args:
50+
request: A :class:`Request <.connection.Request>` instance.
51+
52+
Returns:
53+
An address, uniquely identifying this client
54+
"""
55+
return request.client.host if request.client else "127.0.0.1"
56+
57+
4158
class RateLimitMiddleware(AbstractMiddleware):
4259
"""Rate-limiting middleware."""
4360

@@ -55,6 +72,7 @@ def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
5572
self.config = config
5673
self.max_requests: int = config.rate_limit[1]
5774
self.unit: DurationUnit = config.rate_limit[0]
75+
self.get_identifier_for_request = config.identifier_for_request
5876

5977
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6078
"""ASGI callable.
@@ -71,7 +89,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
7189
request: Request[Any, Any, Any] = app.request_class(scope)
7290
store = self.config.get_store_from_app(app)
7391
if await self.should_check_request(request=request):
74-
key = self.cache_key_from_request(request=request)
92+
identifier = self.get_identifier_for_request(request)
93+
key = f"{type(self).__name__}::{identifier}"
94+
route_handler = request.scope["route_handler"]
95+
if getattr(route_handler, "is_mount", False):
96+
key += "::mount"
97+
7598
cache_object = await self.retrieve_cached_history(key, store)
7699
if len(cache_object.history) >= self.max_requests:
77100
raise TooManyRequestsException(
@@ -114,23 +137,6 @@ async def send_wrapper(message: Message) -> None:
114137

115138
return send_wrapper
116139

117-
def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str:
118-
"""Get a cache-key from a ``Request``
119-
120-
Args:
121-
request: A :class:`Request <.connection.Request>` instance.
122-
123-
Returns:
124-
A cache key.
125-
"""
126-
host = request.client.host if request.client else "anonymous"
127-
identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host
128-
route_handler = request.scope["route_handler"]
129-
if getattr(route_handler, "is_mount", False):
130-
identifier += "::mount"
131-
132-
return f"{type(self).__name__}::{identifier}"
133-
134140
async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
135141
"""Retrieve a list of time stamps for the given duration unit.
136142
@@ -216,6 +222,18 @@ class RateLimitConfig:
216222
"""A pattern or list of patterns to skip in the rate limiting middleware."""
217223
exclude_opt_key: str | None = field(default=None)
218224
"""An identifier to use on routes to disable rate limiting for a particular route."""
225+
identifier_for_request: Callable[[Request], str] = get_remote_address
226+
"""
227+
A callable that receives the request and returns an identifier for which the limit
228+
should be applied. Defaults to :func:`~litestar.middleware.rate_limit.get_remote_address`, which returns the client's
229+
address.
230+
231+
Note that :func:`~litestar.middleware.rate_limit.get_remote_address` does *NOT* honour ``X-FORWARDED-FOR`` headers, as these cannot be
232+
trusted implicitly. If running behind a proxy, a secure way of updating the client's
233+
address should be implemented, such as uvicorn's
234+
`ProxyHeaderMiddleware <https://github.com/encode/uvicorn/blob/master/uvicorn/middleware/proxy_headers.py>`_
235+
or hypercon's `ProxyFixMiddleware <https://hypercorn.readthedocs.io/en/latest/how_to_guides/proxy_fix.html>`_ .
236+
"""
219237
check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None)
220238
"""Handler callable that receives the request instance, returning a boolean dictating whether or not the request
221239
should be checked for rate limiting.

tests/unit/test_middleware/test_rate_limit_middleware.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,49 @@ def handler() -> None:
257257

258258
response = client.get("/")
259259
assert response.status_code == HTTP_429_TOO_MANY_REQUESTS
260+
261+
262+
def test_ignore_x_forwarded_for() -> None:
263+
@get("/")
264+
def handler() -> None:
265+
return None
266+
267+
app = Litestar(
268+
route_handlers=[handler],
269+
middleware=[RateLimitConfig(rate_limit=("minute", 2)).middleware],
270+
)
271+
272+
with TestClient(app=app) as client:
273+
response = client.get("/")
274+
assert response.status_code == HTTP_200_OK
275+
response = client.get("/")
276+
assert response.status_code == HTTP_200_OK
277+
278+
# this shouldn't have any effect
279+
response = client.get("/", headers={"x-forwarded-for": "1.2.3.4"})
280+
assert response.status_code == HTTP_429_TOO_MANY_REQUESTS
281+
282+
283+
def test_custom_identity_function() -> None:
284+
@get("/")
285+
def handler() -> None:
286+
return None
287+
288+
def get_id_from_random_header(request: Request[Any, Any, Any]) -> str:
289+
return request.headers["x-private-header"]
290+
291+
app = Litestar(
292+
route_handlers=[handler],
293+
middleware=[
294+
RateLimitConfig(rate_limit=("minute", 2), identifier_for_request=get_id_from_random_header).middleware
295+
],
296+
)
297+
298+
with TestClient(app=app) as client:
299+
response = client.get("/", headers={"x-private-header": "value"})
300+
assert response.status_code == HTTP_200_OK
301+
response = client.get("/", headers={"x-private-header": "value"})
302+
assert response.status_code == HTTP_200_OK
303+
304+
response = client.get("/", headers={"x-private-header": "value"})
305+
assert response.status_code == HTTP_429_TOO_MANY_REQUESTS

0 commit comments

Comments
 (0)