1+ import time
2+ import logging
3+ from collections import defaultdict
4+
15from fastapi import Request , HTTPException , status
26from starlette .middleware .base import BaseHTTPMiddleware
7+ from starlette .responses import JSONResponse
38from typing import Dict , List , Callable , Optional
49
510from app .core .permissions import Permission , get_permissions_for_role
611from app .db .models .user import UserRole
712
13+ logger = logging .getLogger (__name__ )
14+
815
916class PermissionsMiddleware (BaseHTTPMiddleware ):
1017 """
@@ -144,4 +151,58 @@ async def dispatch(self, request: Request, call_next: Callable):
144151 "PUT" : [Permission .MANAGE_SYSTEM ],
145152 "DELETE" : [Permission .MANAGE_SYSTEM ],
146153 },
147- }
154+ }
155+
156+
157+ class RateLimitMiddleware (BaseHTTPMiddleware ):
158+ """
159+ Simple in-memory rate limiting middleware.
160+
161+ Limits requests per client IP using a sliding window approach.
162+ For production deployments with multiple workers, consider using
163+ a Redis-backed solution instead.
164+ """
165+
166+ def __init__ (
167+ self ,
168+ app ,
169+ requests_per_minute : int = 60 ,
170+ exempt_paths : Optional [List [str ]] = None ,
171+ ):
172+ super ().__init__ (app )
173+ self .requests_per_minute = requests_per_minute
174+ self .exempt_paths = exempt_paths or ["/health" , "/docs" , "/openapi.json" , "/redoc" ]
175+ # {client_ip: [timestamp, ...]}
176+ self ._requests : Dict [str , List [float ]] = defaultdict (list )
177+
178+ def _get_client_ip (self , request : Request ) -> str :
179+ forwarded = request .headers .get ("x-forwarded-for" )
180+ if forwarded :
181+ return forwarded .split ("," )[0 ].strip ()
182+ return request .client .host if request .client else "unknown"
183+
184+ def _cleanup (self , timestamps : List [float ], now : float ) -> List [float ]:
185+ """Remove timestamps older than 60 seconds."""
186+ cutoff = now - 60.0
187+ return [t for t in timestamps if t > cutoff ]
188+
189+ async def dispatch (self , request : Request , call_next : Callable ):
190+ if any (request .url .path .startswith (p ) for p in self .exempt_paths ):
191+ return await call_next (request )
192+
193+ client_ip = self ._get_client_ip (request )
194+ now = time .time ()
195+
196+ # Clean old entries and record this request
197+ self ._requests [client_ip ] = self ._cleanup (self ._requests [client_ip ], now )
198+
199+ if len (self ._requests [client_ip ]) >= self .requests_per_minute :
200+ logger .warning (f"Rate limit exceeded for { client_ip } " )
201+ return JSONResponse (
202+ status_code = 429 ,
203+ content = {"detail" : "Too many requests. Please try again later." },
204+ headers = {"Retry-After" : "60" },
205+ )
206+
207+ self ._requests [client_ip ].append (now )
208+ return await call_next (request )
0 commit comments