Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/api-base/fastApiContainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ export class FastApiContainer extends Construct {
USER_GROUP: config.authConfig!.userGroup,
RAG_ADMIN_GROUP: config.authConfig!.ragAdminGroup,
JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty,
MANAGEMENT_KEY_NAME: managementKeyName
MANAGEMENT_KEY_NAME: managementKeyName,
// Per-user rate limiting (in-process token bucket)
RATE_LIMIT_RPM: (config.restApiConfig.rateLimitRpm ?? 60).toString(),
RATE_LIMIT_BURST: (config.restApiConfig.rateLimitBurst ?? 10).toString(),
RATE_LIMIT_ENABLED: (config.restApiConfig.rateLimitEnabled ?? true).toString(),
RATE_LIMIT_OVERRIDES: JSON.stringify(config.restApiConfig.rateLimitOverrides ?? {}),
};

if (tokenTable) {
Expand Down
19 changes: 19 additions & 0 deletions lib/schema/configSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,25 @@ const FastApiContainerConfigSchema = z.object({
'APIs. Please do not define the dbHost field for the FastAPI container DB config.',
},
),
rateLimitRpm: z.number().int().positive().default(60).describe(
'Per-user sustained request rate limit in requests per minute. Each ECS task enforces this independently.'
),
rateLimitBurst: z.number().int().nonnegative().default(10).describe(
'Per-user burst allowance above the sustained rate limit. Allows short spikes without throttling.'
),
rateLimitEnabled: z.boolean().default(true).describe(
'Enable or disable per-user rate limiting on the serve API.'
),
rateLimitOverrides: z.record(
z.string(),
z.object({
rpm: z.number().int().positive().optional().describe('Override RPM for this user/token.'),
burst: z.number().int().nonnegative().optional().describe('Override burst for this user/token.'),
})
).default({}).describe(
'Per-user rate limit overrides. Keys use the format "token:<tokenUUID>" for API tokens, ' +
'"oidc:<sub>" for OIDC users, or "user:<username>". Values can override "rpm" and/or "burst".'
),
}).describe('Configuration schema for REST API.');

/** Custom domain / TLS for the MCP Workbench ALB only (separate from Serve’s `restApiConfig`). */
Expand Down
14 changes: 13 additions & 1 deletion lib/serve/rest-api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from middleware import (
auth_middleware,
process_request_middleware,
rate_limit_middleware,
register_exception_handlers,
security_middleware,
validate_input_middleware,
Expand Down Expand Up @@ -82,12 +83,23 @@ async def lifespan(app: FastAPI): # type: ignore
##############


@app.middleware("http")
async def rate_limit(request, call_next):  # type: ignore
    """Apply per-user rate limiting to the incoming request.

    Thin adapter that hands the request to ``rate_limit_middleware``. It is
    meant to run once authentication has populated the caller identity, so
    limits can be enforced per API key / per user.
    """
    response = await rate_limit_middleware(request, call_next)
    return response
Comment thread
estohlmann marked this conversation as resolved.


@app.middleware("http")
async def authenticate(request, call_next):  # type: ignore
    """Authenticate the incoming request.

    Delegates to ``auth_middleware``, which validates tokens and records the
    user context on ``request.state``. Runs after security checks but before
    request processing.

    NOTE: FastAPI executes function middleware in reverse registration order,
    so this function has to be declared *after* ``rate_limit()`` in order to
    run *before* it on each request.
    """
    response = await auth_middleware(request, call_next)
    return response

Expand Down
2 changes: 2 additions & 0 deletions lib/serve/rest-api/src/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from .auth_middleware import auth_middleware, require_admin, require_auth
from .exception_handlers import register_exception_handlers
from .input_validation import validate_input_middleware
from .rate_limit_middleware import rate_limit_middleware
from .request_middleware import process_request_middleware
from .security_middleware import security_middleware

__all__ = [
"auth_middleware",
"process_request_middleware",
"rate_limit_middleware",
"register_exception_handlers",
"require_admin",
"require_auth",
Expand Down
3 changes: 3 additions & 0 deletions lib/serve/rest-api/src/middleware/auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo
if os.getenv("USE_AUTH", "true").lower() != "true":
request.state.authenticated = True
request.state.jwt_data = None
request.state.is_management_token = False
request.state.is_admin = True
request.state.username = "anonymous"
request.state.groups = []
Expand All @@ -93,6 +94,7 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo
# Set authentication context on request state
request.state.authenticated = True
request.state.jwt_data = jwt_data
request.state.is_management_token = False

# Determine admin status based on auth type
if jwt_data:
Expand All @@ -115,6 +117,7 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo
request.state.groups = token_info.get("groups", [])
else:
# Management token - full admin access
request.state.is_management_token = True
request.state.is_admin = True
request.state.username = "management-token"
request.state.groups = []
Expand Down
280 changes: 280 additions & 0 deletions lib/serve/rest-api/src/middleware/rate_limit_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Per-user rate limiting middleware using an in-memory token bucket.

Runs after authentication so the caller identity is available on ``request.state``.
Each ECS task tracks limits independently — the effective per-user limit across the
fleet is ``N_tasks × RATE_LIMIT_RPM``, which naturally scales with capacity.

Configuration (environment variables):
RATE_LIMIT_RPM – sustained requests per minute per user (default 60)
RATE_LIMIT_BURST – extra burst allowance above the sustained rate (default 10)
RATE_LIMIT_ENABLED – set to "false" to disable (default "true")
RATE_LIMIT_OVERRIDES – JSON map of per-user/per-token overrides (default "{}")
Keys match the user_key format: "token:<tokenUUID>" or "oidc:<sub>" or "user:<username>"
Values are objects with optional "rpm" and "burst" fields.
Example: {"token:abc-123": {"rpm": 120, "burst": 20}, "oidc:user-456": {"rpm": 10}}
"""

import asyncio
import hashlib
import json
import os
import time
from collections.abc import Callable

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Read once at import time; each falls back to its default when the variable is unset.
# Sustained requests per minute allowed for each user.
RATE_LIMIT_RPM = int(os.getenv("RATE_LIMIT_RPM", "60"))
# Extra tokens on top of the sustained rate, to absorb short spikes.
RATE_LIMIT_BURST = int(os.getenv("RATE_LIMIT_BURST", "10"))
# Only the (case-insensitive) literal "true" turns rate limiting on.
RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"


# Per-user overrides: { "token:<uuid>": {"rpm": N, "burst": N}, ... }
def _parse_overrides(raw: str) -> dict[str, dict[str, int]]:
    """Parse the RATE_LIMIT_OVERRIDES JSON env var into a validated dict.

    Args:
        raw: JSON text mapping a user key to an object with optional ``rpm``
            and/or ``burst`` fields. An empty string means "no overrides".

    Returns:
        A mapping of user key to the integer overrides found for that key.
        Malformed input never raises: an unparsable document or a non-object
        top level yields ``{}``, and an individual entry that is not an object
        or carries a value that cannot be converted to ``int`` is skipped
        (with a warning) while the remaining entries are kept.
    """
    if not raw:
        return {}

    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, ValueError) as e:
        logger.warning(f"Failed to parse RATE_LIMIT_OVERRIDES: {e}")
        return {}

    if not isinstance(parsed, dict):
        logger.warning("RATE_LIMIT_OVERRIDES is not a JSON object, ignoring")
        return {}

    result: dict[str, dict[str, int]] = {}
    for key, val in parsed.items():
        if not isinstance(val, dict):
            logger.warning(f"RATE_LIMIT_OVERRIDES[{key}] is not an object, skipping")
            continue
        try:
            entry = {field: int(val[field]) for field in ("rpm", "burst") if field in val}
        except (TypeError, ValueError, OverflowError) as e:
            # int() raises TypeError for null/list/object, ValueError for
            # non-numeric strings and NaN, OverflowError for Infinity. This
            # runs at import time, so a bad value must not abort start-up.
            logger.warning(f"RATE_LIMIT_OVERRIDES[{key}] has a non-integer value, skipping: {e}")
            continue
        result[str(key)] = entry
    return result


# Per-user overrides, parsed once at import time from the environment.
RATE_LIMIT_OVERRIDES: dict[str, dict[str, int]] = _parse_overrides(os.getenv("RATE_LIMIT_OVERRIDES", ""))

# System-default refill speed, expressed in tokens per second.
_REFILL_RATE = RATE_LIMIT_RPM / 60.0

# Request paths that are never rate limited.
_EXEMPT_PATHS = {
    "/health",
    "/health/readiness",
    "/health/liveliness",
}

# Once this many users are tracked, stale entries get pruned.
_MAX_BUCKETS = 10_000
# A bucket untouched for longer than this many seconds counts as stale.
_STALE_SECONDS = 300.0


# ---------------------------------------------------------------------------
# Token bucket implementation
# ---------------------------------------------------------------------------


class _TokenBucket:
    """Token bucket tracking the remaining allowance of one user.

    Holds the current token count and the monotonic timestamp of the last
    refill. Capacity and refill rate are supplied on every call rather than
    stored. The class does no locking of its own — callers must hold ``_lock``.
    """

    __slots__ = ("tokens", "last_refill")

    def __init__(self, max_tokens: float) -> None:
        # A new bucket starts full.
        self.tokens: float = max_tokens
        self.last_refill: float = time.monotonic()

    def try_consume(self, max_tokens: float, refill_rate: float) -> tuple[bool, float]:
        """Top the bucket up for the elapsed time, then try to take one token.

        Args:
            max_tokens: Capacity the bucket is clamped to after refilling.
            refill_rate: Tokens added per second.

        Returns:
            ``(allowed, retry_after_seconds)``. ``retry_after_seconds`` is
            ``0.0`` when the request is allowed.
        """
        current = time.monotonic()
        replenished = self.tokens + (current - self.last_refill) * refill_rate
        self.tokens = min(max_tokens, replenished)
        self.last_refill = current

        if self.tokens < 1.0:
            if refill_rate <= 0:
                # No refill configured (rpm=0): only the burst capacity was
                # ever available, so report a fixed back-off.
                return False, 60.0
            # Time needed for the shortfall to reach one whole token.
            return False, (1.0 - self.tokens) / refill_rate

        self.tokens -= 1.0
        return True, 0.0

Comment thread
estohlmann marked this conversation as resolved.

# Global bucket store — keyed by user identity string (the ``user_key`` passed
# to ``_check_rate_limit``). Module-level, so it lives only in this process.
_buckets: dict[str, _TokenBucket] = {}
Comment thread
estohlmann marked this conversation as resolved.
# Held by ``_check_rate_limit`` while it prunes, looks up, creates and consumes
# from entries in ``_buckets``.
_lock = asyncio.Lock()


Comment thread
estohlmann marked this conversation as resolved.
def _get_max_tokens() -> float:
    """Return the system-default bucket capacity.

    The capacity is the sustained per-minute rate plus the burst allowance.
    Per-user values come from ``_get_user_limits`` instead.
    """
    sustained = float(RATE_LIMIT_RPM)
    extra = float(RATE_LIMIT_BURST)
    return sustained + extra


def _get_user_limits(user_key: str) -> tuple[float, float, float]:
    """Resolve the effective limits for *user_key*.

    An entry in ``RATE_LIMIT_OVERRIDES`` takes precedence; any field it leaves
    out, and any user without an entry, uses the system defaults. Negative
    values are clamped to zero.

    Returns:
        ``(max_tokens, refill_rate, rpm)`` where ``max_tokens`` is rpm plus
        burst and ``refill_rate`` is tokens per second.
    """
    custom = RATE_LIMIT_OVERRIDES.get(user_key)
    if not custom:
        rpm, burst = max(RATE_LIMIT_RPM, 0), max(RATE_LIMIT_BURST, 0)
    else:
        rpm = max(int(custom.get("rpm", RATE_LIMIT_RPM)), 0)
        burst = max(int(custom.get("burst", RATE_LIMIT_BURST)), 0)
    return float(rpm) + float(burst), rpm / 60.0, float(rpm)


def _prune_stale_buckets() -> None:
    """Drop every bucket whose last refill is older than ``_STALE_SECONDS``.

    The caller must hold ``_lock``. Buckets refilled more recently are kept,
    so this does not by itself cap the size of ``_buckets``.
    """
    current = time.monotonic()
    # Iterate over a snapshot so entries can be deleted during the loop.
    for key, bucket in list(_buckets.items()):
        if (current - bucket.last_refill) > _STALE_SECONDS:
            del _buckets[key]


async def _check_rate_limit(user_key: str) -> tuple[bool, float]:
    """Consume one token from the bucket belonging to *user_key*.

    Limits are resolved through ``_get_user_limits``, so per-user overrides
    apply. A bucket is created (full) the first time a key is seen, and stale
    buckets are pruned first whenever the store has reached ``_MAX_BUCKETS``.

    Returns:
        ``(allowed, retry_after_seconds)``.
    """
    capacity, rate, _rpm = _get_user_limits(user_key)

    async with _lock:
        if len(_buckets) >= _MAX_BUCKETS:
            _prune_stale_buckets()

        try:
            bucket = _buckets[user_key]
        except KeyError:
            bucket = _buckets[user_key] = _TokenBucket(capacity)

        return bucket.try_consume(capacity, rate)


# ---------------------------------------------------------------------------
# User identity extraction
# ---------------------------------------------------------------------------


def _get_user_key(request: Request) -> str | None:
"""Derive a rate-limit key from the authenticated request.

Returns ``None`` for requests that should bypass rate limiting
(management tokens, unauthenticated/public paths).
"""
if not getattr(request.state, "authenticated", False):
return None

api_token_info = getattr(request.state, "api_token_info", None)
jwt_data = getattr(request.state, "jwt_data", None)
username = getattr(request.state, "username", None)

# Management tokens bypass rate limiting — they're internal automation.
if getattr(request.state, "is_management_token", False):
return None
if api_token_info is None and not jwt_data and username == "management-token":
return None

# API token users — key on tokenUUID (unique per key)
if api_token_info and isinstance(api_token_info, dict):
token_uuid = api_token_info.get("tokenUUID")
if token_uuid:
return f"token:{token_uuid}"
# Fallback to username if no UUID (shouldn't happen for valid tokens)
return f"token:{api_token_info.get('username', 'unknown')}"

# OIDC users — key on subject claim
if jwt_data and isinstance(jwt_data, dict):
sub = jwt_data.get("sub") or jwt_data.get("username")
if sub:
return f"oidc:{sub}"

# Fallback to username set by auth middleware
if username:
return f"user:{username}"

return None


# ---------------------------------------------------------------------------
# Middleware entry point
# ---------------------------------------------------------------------------


async def rate_limit_middleware(request: Request, call_next: Callable[[Request], Response]) -> Response:
    """Enforce the per-user rate limit for one request.

    Must run **after** the authentication middleware, because the caller
    identity is read from ``request.state``.

    The request is passed straight through when rate limiting is disabled,
    the path is exempt, the method is ``OPTIONS`` (CORS preflight), or no
    rate-limit key can be derived for the caller. Otherwise one token is
    consumed; if none is available, a 429 JSON response with a
    ``Retry-After`` header is returned instead of calling downstream.
    """
    bypass = (
        not RATE_LIMIT_ENABLED
        or request.url.path in _EXEMPT_PATHS
        or request.method == "OPTIONS"
    )
    # A missing key means the caller is unidentified or in an exempt category.
    user_key = None if bypass else _get_user_key(request)
    if user_key is None:
        return await call_next(request)

    allowed, retry_after = await _check_rate_limit(user_key)
    if allowed:
        return await call_next(request)

    # Log the key type plus a short hash rather than the raw identity.
    user_type = user_key.split(":", 1)[0]
    user_hash = hashlib.sha256(user_key.encode("utf-8")).hexdigest()[:10]
    logger.warning(f"Rate limit exceeded for {user_type}:{user_hash}, retry_after={retry_after:.1f}s")

    error_body = {
        "error": {
            "message": "Rate limit exceeded. Please slow down and retry.",
            "type": "rate_limit_error",
            "code": "rate_limit_exceeded",
        }
    }
    # Truncate to whole seconds and add one so the advertised wait is never short.
    retry_header = {"Retry-After": str(int(retry_after) + 1)}
    return JSONResponse(
        status_code=HTTP_429_TOO_MANY_REQUESTS,
        content=error_body,
        headers=retry_header,
    )
Loading
Loading