1111from litestar .serialization import decode_json , encode_json
1212from litestar .utils import ensure_async_callable
1313
14- __all__ = ("CacheObject" , "RateLimitConfig" , "RateLimitMiddleware" )
14+ __all__ = (
15+ "CacheObject" ,
16+ "RateLimitConfig" ,
17+ "RateLimitMiddleware" ,
18+ "get_remote_address" ,
19+ )
1520
1621
1722if TYPE_CHECKING :
@@ -38,6 +43,18 @@ class CacheObject:
3843 reset : int
3944
4045
46+ def get_remote_address (request : Request [Any , Any , Any ]) -> str :
47+ """Get a client's remote address from a ``Request``
48+
49+ Args:
50+ request: A :class:`Request <.connection.Request>` instance.
51+
52+ Returns:
53+ An address, uniquely identifying this client
54+ """
55+ return request .client .host if request .client else "127.0.0.1"
56+
57+
4158class RateLimitMiddleware (AbstractMiddleware ):
4259 """Rate-limiting middleware."""
4360
@@ -55,6 +72,7 @@ def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
5572 self .config = config
5673 self .max_requests : int = config .rate_limit [1 ]
5774 self .unit : DurationUnit = config .rate_limit [0 ]
75+ self .get_identifier_for_request = config .identifier_for_request
5876
5977 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
6078 """ASGI callable.
@@ -71,7 +89,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
7189 request : Request [Any , Any , Any ] = app .request_class (scope )
7290 store = self .config .get_store_from_app (app )
7391 if await self .should_check_request (request = request ):
74- key = self .cache_key_from_request (request = request )
92+ identifier = self .get_identifier_for_request (request )
93+ key = f"{ type (self ).__name__ } ::{ identifier } "
94+ route_handler = request .scope ["route_handler" ]
95+ if getattr (route_handler , "is_mount" , False ):
96+ key += "::mount"
97+
7598 cache_object = await self .retrieve_cached_history (key , store )
7699 if len (cache_object .history ) >= self .max_requests :
77100 raise TooManyRequestsException (
@@ -114,23 +137,6 @@ async def send_wrapper(message: Message) -> None:
114137
115138 return send_wrapper
116139
117- def cache_key_from_request (self , request : Request [Any , Any , Any ]) -> str :
118- """Get a cache-key from a ``Request``
119-
120- Args:
121- request: A :class:`Request <.connection.Request>` instance.
122-
123- Returns:
124- A cache key.
125- """
126- host = request .client .host if request .client else "anonymous"
127- identifier = request .headers .get ("X-Forwarded-For" ) or request .headers .get ("X-Real-IP" ) or host
128- route_handler = request .scope ["route_handler" ]
129- if getattr (route_handler , "is_mount" , False ):
130- identifier += "::mount"
131-
132- return f"{ type (self ).__name__ } ::{ identifier } "
133-
134140 async def retrieve_cached_history (self , key : str , store : Store ) -> CacheObject :
135141 """Retrieve a list of time stamps for the given duration unit.
136142
@@ -216,6 +222,18 @@ class RateLimitConfig:
216222 """A pattern or list of patterns to skip in the rate limiting middleware."""
217223 exclude_opt_key : str | None = field (default = None )
218224 """An identifier to use on routes to disable rate limiting for a particular route."""
225+ identifier_for_request : Callable [[Request ], str ] = get_remote_address
226+ """
227+ A callable that receives the request and returns an identifier for which the limit
228+ should be applied. Defaults to :func:`~litestar.middleware.rate_limit.get_remote_address`, which returns the client's
229+ address.
230+
231+ Note that :func:`~litestar.middleware.rate_limit.get_remote_address` does *NOT* honour ``X-FORWARDED-FOR`` headers, as these cannot be
232+ trusted implicitly. If running behind a proxy, a secure way of updating the client's
233+ address should be implemented, such as uvicorn's
234+ `ProxyHeaderMiddleware <https://github.com/encode/uvicorn/blob/master/uvicorn/middleware/proxy_headers.py>`_
235+ or hypercon's `ProxyFixMiddleware <https://hypercorn.readthedocs.io/en/latest/how_to_guides/proxy_fix.html>`_ .
236+ """
219237 check_throttle_handler : Callable [[Request [Any , Any , Any ]], SyncOrAsyncUnion [bool ]] | None = field (default = None )
220238 """Handler callable that receives the request instance, returning a boolean dictating whether or not the request
221239 should be checked for rate limiting.
0 commit comments