Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 215 additions & 25 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ azure-monitor-opentelemetry = "^1.8.1"
azure-monitor-opentelemetry-exporter = "^1.0.0b44"
xarray = "^2025.10.1"
pytest-ordering = "^0.6"
slowapi = "^0.1.9"

[tool.pytest.ini_options]
asyncio_mode = "strict"
Expand Down
4 changes: 4 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,9 @@ class Config(BaseSettings):
PROFILE: bool = False
LOGGER: bool = True

##ratelimiter settings
RATE_LIMIT_WINDOW: int = 60 # in seconds
MAX_REQUESTS_PER_WINDOW: int = 1000 # max requests per window


config = Config()
9 changes: 8 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from fastapi.middleware.cors import CORSMiddleware
from azure.monitor.opentelemetry import configure_azure_monitor # type: ignore
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor # type: ignore
from src.middleware.rate_limiter_middleware import limiter
from slowapi.middleware import SlowAPIMiddleware

from src.middleware.exception_handling_middleware import ExceptionFilterMiddleware
from src.logger import DOT_API_LOGGER_NAME, get_dot_api_logger
Expand Down Expand Up @@ -62,6 +64,10 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.info("Error occurred while configuring telemetry: %s", e)

# add rate limiter middleware
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

# Adding CORS middleware to the FastAPI application
app.add_middleware(
CORSMiddleware,
Expand All @@ -70,12 +76,13 @@ async def lifespan(app: FastAPI):
allow_methods=["*"], # Allow all HTTP methods
allow_headers=["*"], # Allow all HTTP headers
)
app.add_middleware(ExceptionFilterMiddleware)

if config.PROFILE:
# this will generate a profile.html at repository root when running any endpoint
app.add_middleware(PyInstrumentMiddleWare)

app.add_middleware(ExceptionFilterMiddleware)


@app.get("/", status_code=status.HTTP_200_OK)
async def root():
Expand Down
13 changes: 13 additions & 0 deletions src/middleware/exception_handling_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from slowapi.errors import RateLimitExceeded
from src.logger import get_dot_api_logger
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.exceptions import RequestValidationError
Expand All @@ -15,6 +16,18 @@ async def dispatch(self, request: Request, call_next): # type: ignore
# Process request and response
response = await call_next(request)
return response
except RateLimitExceeded as exc:
# Handle rate limit exceeded
logger.warning(
f"Rate limit exceeded for {request.client.host if request.client else 'unknown'}"
)
return JSONResponse(
status_code=429,
content={
"detail": "Rate limit exceeded. Please try again later.",
"retry_after": exc.detail,
},
)
except HTTPException as exc:
# Log and return custom message for HTTP exceptions
logger.error(f"HTTPException: {exc}")
Expand Down
34 changes: 34 additions & 0 deletions src/middleware/rate_limiter_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastapi import HTTPException, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from src.config import config


def get_client_key(request: Request) -> str:
"""
Get rate limit key for the client.

Priority:
1. IP address from X-Forwarded-For header (for proxied requests)
2. Direct client IP address(for local/development requests)
"""
# Try X-Forwarded-For first (for requests )
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
ip = forwarded.split(",")[0].strip()
if ip:
return f"ip:{ip}"

# Try direct client IP for local
ip = get_remote_address(request)
if ip:
return f"ip:{ip}"

# Cannot identify client - reject request
raise HTTPException(status_code=403, detail="Unable to identify client for rate limiting")


# Default rate limit string (e.g., "100/minute")
DEFAULT_RATE_LIMIT = f"{config.MAX_REQUESTS_PER_WINDOW}/{config.RATE_LIMIT_WINDOW}second"
# Create limiter instance with client key function
limiter = Limiter(key_func=get_client_key, default_limits=[DEFAULT_RATE_LIMIT])
17 changes: 17 additions & 0 deletions tests/test_rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from httpx import AsyncClient

from src.config import config


@pytest.mark.asyncio
async def test_rate_limiter_with_root_endpoint(client: AsyncClient):
"""Test rate limiting on the root endpoint."""
# Make requests up to the configured limit
max_requests = config.MAX_REQUESTS_PER_WINDOW

# First batch of requests should succeed
for i in range(min(max_requests, 10)): # Test up to 10 to keep test fast
response = await client.get("/")
assert response.status_code == 200, f"Request {i+1} failed unexpectedly"
assert response.json() == {"message": "Welcome to the DOT api"}