Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions mcpgateway/middleware/rate_limit_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self, app):
self._violation_expiry: Dict[str, float] = {}
self._store_lock = threading.Lock()

logger.info(f"RateLimitMiddleware initialized: enabled={self.enabled}, " f"use_redis={self.use_redis}, lockout={self.lockout_enabled}")
logger.info(f"RateLimitMiddleware initialized: enabled={self.enabled}, use_redis={self.use_redis}, lockout={self.lockout_enabled}")

def _init_redis(self) -> None:
"""Initialize Redis client."""
Expand Down Expand Up @@ -166,8 +166,33 @@ def _compile_tiers(self) -> List[Tuple[re.Pattern, Dict[str, Any]]]:
compiled.append((pattern, config))
return compiled

def get_endpoint_tier(self, path: str) -> Dict[str, Any]:
"""Get tier config for endpoint."""
def get_endpoint_tier(self, request_or_path) -> Dict[str, Any]:
"""Get tier config for endpoint.

Args:
request_or_path: Either a Request object (recommended) or a string path (legacy).
When Request is passed, strips root_path before matching.

Returns:
Tier configuration dict with 'limit' and 'burst' keys.
"""
# Handle both Request object and string path for backwards compatibility
if isinstance(request_or_path, str):
path = request_or_path
else:
# Extract and normalize path from Request
path = request_or_path.url.path
root_path = request_or_path.scope.get("root_path") or settings.app_root_path or ""

# Strip root_path prefix (same logic as security_headers.py and token_scoping.py)
if root_path and len(root_path) > 1:
root_path = root_path.rstrip("/")
if path.startswith(root_path):
rest = path[len(root_path) :]
# Ensure we only strip if followed by "/" or end of string (avoid partial matches)
if rest == "" or rest.startswith("/"):
path = rest or "/"

for pattern, config in self.compiled_tiers:
if pattern.match(path):
return config
Expand Down Expand Up @@ -199,10 +224,10 @@ async def dispatch(self, request: Request, call_next):
if not self.enabled:
return await call_next(request)

tier = self.get_endpoint_tier(request.url.path)
tier = self.get_endpoint_tier(request)
dimensions = self._get_client_dimensions(request)

tier_name = self._get_tier_name(request.url.path)
tier_name = self._get_tier_name(request)

# Check lockout first — a locked-out dimension blocks regardless of
# whether the sliding window has cleared.
Expand Down Expand Up @@ -267,8 +292,32 @@ async def dispatch(self, request: Request, call_next):

return response

def _get_tier_name(self, path: str) -> str:
"""Get tier name for logging."""
def _get_tier_name(self, request_or_path) -> str:
"""Get tier name for logging.

Args:
request_or_path: Either a Request object (recommended) or a string path (legacy).
When Request is passed, strips root_path before matching.

Returns:
Tier name string (CRITICAL/CRITICAL_SSO/HIGH/MEDIUM/LOW).
"""
# Handle both Request object and string path for backwards compatibility
if isinstance(request_or_path, str):
path = request_or_path
else:
# Extract and normalize path from Request
path = request_or_path.url.path
root_path = request_or_path.scope.get("root_path") or settings.app_root_path or ""

# Strip root_path prefix (same logic as get_endpoint_tier)
if root_path and len(root_path) > 1:
root_path = root_path.rstrip("/")
if path.startswith(root_path):
rest = path[len(root_path) :]
if rest == "" or rest.startswith("/"):
path = rest or "/"

for tier_name, config in self.endpoint_tiers.items():
if re.match(config["pattern"], path):
return tier_name
Expand Down Expand Up @@ -472,7 +521,7 @@ def _create_rate_limit_response(
status_code=429,
content={
"error": "Account locked",
"message": f"Too many rate limit violations. Account locked for {self.lockout_duration_minutes} minutes. " "This may indicate suspicious activity on your account.",
"message": f"Too many rate limit violations. Account locked for {self.lockout_duration_minutes} minutes. This may indicate suspicious activity on your account.",
"lockout_duration_minutes": self.lockout_duration_minutes,
"reset_in_seconds": self.lockout_duration_minutes * 60,
},
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/mcpgateway/middleware/test_rate_limit_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,3 +1263,110 @@ def test_should_lockout_memory_initializes_expiry(self, middleware):

assert result is False
assert hasattr(middleware, "_violation_expiry")

def test_endpoint_tier_strips_root_path(self, mock_app):
"""Test tier matching strips root_path prefix before pattern matching.

Issue #4994: When APP_ROOT_PATH is set (e.g., /mcp-gateway/service/gateway),
the middleware should strip this prefix before matching tier patterns.
Without stripping, all requests fall into the LOW default tier.
"""
# Set up middleware with app_root_path in settings
with patch("mcpgateway.middleware.rate_limit_middleware.settings") as mock_settings:
mock_settings.rate_limiting_enabled = True
mock_settings.rate_limiting_redis_enabled = False
mock_settings.trust_proxy_auth = True
mock_settings.rate_limit_critical_rpm = 10
mock_settings.rate_limit_critical_burst = 0
mock_settings.rate_limit_high_rpm = 30
mock_settings.rate_limit_high_burst = 0
mock_settings.rate_limit_medium_rpm = 100
mock_settings.rate_limit_medium_burst = 20
mock_settings.rate_limit_low_rpm = 500
mock_settings.rate_limit_low_burst = 100
mock_settings.rate_limit_lockout_enabled = True
mock_settings.rate_limit_lockout_threshold = 5
mock_settings.rate_limit_lockout_duration_minutes = 15
mock_settings.app_root_path = "/mcp-gateway/service/gateway"

from mcpgateway.middleware.rate_limit_middleware import RateLimitMiddleware

middleware = RateLimitMiddleware(mock_app)

# Test cases with root_path
test_cases = [
# (full_path_with_root, root_path_in_scope, expected_limit, expected_tier_name)
("/mcp-gateway/service/gateway/servers/123/mcp", "/mcp-gateway/service/gateway", 100, "MEDIUM"),
("/mcp-gateway/service/gateway/tokens", "/mcp-gateway/service/gateway", 30, "HIGH"),
("/mcp-gateway/service/gateway/mcp/list", "/mcp-gateway/service/gateway", 100, "MEDIUM"),
("/mcp-gateway/service/gateway/auth/email/login", "/mcp-gateway/service/gateway", 10, "CRITICAL"),
("/mcp-gateway/service/gateway/rbac/roles", "/mcp-gateway/service/gateway", 30, "HIGH"),
("/mcp-gateway/service/gateway/health", "/mcp-gateway/service/gateway", 500, "LOW"),
# Test with different root_path
("/api/v1/tools/execute", "/api/v1", 100, "MEDIUM"),
("/api/v1/tokens/create", "/api/v1", 30, "HIGH"),
# Test without root_path should still work
("/servers/123/mcp", "", 100, "MEDIUM"),
("/tokens", "", 30, "HIGH"),
]

for full_path, root_path, expected_limit, expected_tier_name in test_cases:
# Create mock request with root_path in scope
mock_request = MagicMock()
mock_request.url.path = full_path
mock_request.scope = {"root_path": root_path} if root_path else {}

# Test get_endpoint_tier - it should accept Request and normalize path
tier = middleware.get_endpoint_tier(mock_request)
assert tier["limit"] == expected_limit, f"Path {full_path} with root_path {root_path} should match tier with limit {expected_limit}, got {tier['limit']}"

# Test _get_tier_name - it should accept Request and normalize path
tier_name = middleware._get_tier_name(mock_request)
assert tier_name == expected_tier_name, f"Path {full_path} with root_path {root_path} should match tier {expected_tier_name}, got {tier_name}"

def test_endpoint_tier_root_path_edge_cases(self, mock_app):
"""Test root_path edge cases: trailing slash, empty, single slash."""
with patch("mcpgateway.middleware.rate_limit_middleware.settings") as mock_settings:
mock_settings.rate_limiting_enabled = True
mock_settings.rate_limiting_redis_enabled = False
mock_settings.trust_proxy_auth = True
mock_settings.rate_limit_critical_rpm = 10
mock_settings.rate_limit_critical_burst = 0
mock_settings.rate_limit_high_rpm = 30
mock_settings.rate_limit_high_burst = 0
mock_settings.rate_limit_medium_rpm = 100
mock_settings.rate_limit_medium_burst = 20
mock_settings.rate_limit_low_rpm = 500
mock_settings.rate_limit_low_burst = 100
mock_settings.rate_limit_lockout_enabled = True
mock_settings.rate_limit_lockout_threshold = 5
mock_settings.rate_limit_lockout_duration_minutes = 15
mock_settings.app_root_path = ""

from mcpgateway.middleware.rate_limit_middleware import RateLimitMiddleware

middleware = RateLimitMiddleware(mock_app)

test_cases = [
# root_path with trailing slash should be normalized
("/app/servers/list", "/app/", 100, "MEDIUM"),
# root_path that is just "/" should not strip anything
("/servers/list", "/", 100, "MEDIUM"),
# Empty root_path should behave normally
("/servers/list", "", 100, "MEDIUM"),
# root_path must not partially match (e.g., /app should not strip from /application)
("/application/data", "/app", 500, "LOW"),
# Test that stripping produces a valid path (not empty)
("/app", "/app", 500, "LOW"), # After stripping becomes "/" or "" which should default to LOW
]

for full_path, root_path, expected_limit, expected_tier_name in test_cases:
mock_request = MagicMock()
mock_request.url.path = full_path
mock_request.scope = {"root_path": root_path} if root_path else {}

tier = middleware.get_endpoint_tier(mock_request)
assert tier["limit"] == expected_limit, f"Path {full_path} with root_path {root_path} should have limit {expected_limit}, got {tier['limit']}"

tier_name = middleware._get_tier_name(mock_request)
assert tier_name == expected_tier_name, f"Path {full_path} with root_path {root_path} should match tier {expected_tier_name}, got {tier_name}"
Loading