From d559ad90be441ff04db1bb552a1e3b50ecd64a1f Mon Sep 17 00:00:00 2001 From: Jitesh Nair Date: Tue, 10 Mar 2026 00:38:53 +0000 Subject: [PATCH 1/3] fix(plugins): propagate HTTP status codes and headers from plugin violations Add support for plugins to specify HTTP status codes (e.g., 429 for rate limiting) and custom headers (e.g., Retry-After) in PluginViolation responses. - Add http_status_code and http_headers fields to PluginViolation model - Implement PLUGIN_VIOLATION_CODE_MAPPING for common violation types - Update plugin_violation_exception_handler to use explicit status codes with fallback to code mapping, defaulting to 200 for JSON-RPC compliance - Add RFC 9110 header validation to prevent header injection - Enhance rate limiter plugin with multi-dimensional rate limiting, proper HTTP 429 responses, and X-RateLimit-* / Retry-After headers - Add comprehensive test coverage for status code precedence, header propagation, and header validation Fixes #2668 Signed-off-by: Jitesh Nair Signed-off-by: Mihai Criveti --- mcpgateway/main.py | 79 ++- mcpgateway/plugins/framework/constants.py | 92 ++++ mcpgateway/plugins/framework/models.py | 6 + plugins/rate_limiter/rate_limiter.py | 192 +++++-- tests/integration/test_rate_limiter.py | 455 +++++++++++++++++ .../plugins/framework/test_constants.py | 226 +++++++++ .../plugins/rate_limiter/test_rate_limiter.py | 469 +++++++++++++++++- tests/unit/mcpgateway/test_main.py | 341 ++++++++++++- tests/unit/mcpgateway/test_main_helpers.py | 122 +++++ 9 files changed, 1934 insertions(+), 48 deletions(-) create mode 100644 tests/integration/test_rate_limiter.py create mode 100644 tests/unit/mcpgateway/plugins/framework/test_constants.py diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 2750f6073e..24ba896caf 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -33,6 +33,7 @@ from functools import lru_cache import hashlib import html +import re import sys from typing import Any, AsyncIterator, Dict, List, Optional, 
Union from urllib.parse import urlparse, urlunparse @@ -86,6 +87,7 @@ from mcpgateway.middleware.validation_middleware import ValidationMiddleware from mcpgateway.observability import init_telemetry from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError +from mcpgateway.plugins.framework.constants import PLUGIN_VIOLATION_CODE_MAPPING, PluginViolationCode, VALID_HTTP_STATUS_CODES from mcpgateway.routers.server_well_known import router as server_well_known_router from mcpgateway.routers.well_known import router as well_known_router from mcpgateway.schemas import ( @@ -1450,6 +1452,51 @@ async def database_exception_handler(_request: Request, exc: IntegrityError): return ORJSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc)) +def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]: + """Validate headers according to RFC 9110. + + Args: + headers: dict of headers + + Returns: + Optional[dict[str, str]]: dictionary of valid headers + + Rules enforced: + - Header name must match RFC 9110 'token'. + - No whitespace before colon (enforced by dictionary usage). + - Header value must not contain CTL characters (0x00–0x1F, 0x7F). + """ + + # RFC 9110 'token' definition: + # token = 1*tchar + # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + # / DIGIT / ALPHA + header_key = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") + validated: dict[str, str] = {} + for key, value in headers.items(): + # Validate header name (RFC 9110) + if not re.match(header_key, key): + logger.warning(f"Invalid header name: {key}") + continue + # Validate header value (no CRLF) + if "\r" in value or "\n" in value: + logger.warning(f"Header value contains CRLF: {key}") + continue + # RFC 9110: Reject CTLs (0x00–0x1F, 0x7F). Allow SP (0x20) and HTAB (0x09). + # Further structure (quoted-string, lists, parameters) is left to higher-level parsers. 
+ valid = True + for ch in value: + code = ord(ch) + if (0 <= code <= 31 or code == 127) and code not in (9, 32): + valid = False + break + if not valid: + continue + validated[key] = value + return validated if validated else None + + @app.exception_handler(PluginViolationError) async def plugin_violation_exception_handler(_request: Request, exc: PluginViolationError): """Handle plugins violations globally. @@ -1464,7 +1511,9 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola violation details. Returns: - JSONResponse: A 200 response with error details in JSON-RPC format. + JSONResponse: A response with error details in JSON-RPC format. + Uses HTTP status code from violation if present and valid (e.g., 429 for rate limiting), + otherwise the status mapped from the violation code (PLUGIN_VIOLATION_CODE_MAPPING), or 200 for JSON-RPC compliance when neither applies. Examples: >>> from mcpgateway.plugins.framework import PluginViolationError @@ -1482,7 +1531,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola ... )) >>> result = asyncio.run(plugin_violation_exception_handler(None, mock_error)) >>> result.status_code - 200 + 422 >>> content = orjson.loads(result.body.decode()) >>> content["error"]["code"] -32602 @@ -1496,6 +1545,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola policy_violation["message"] = exc.message status_code = exc.violation.mcp_error_code if exc.violation and exc.violation.mcp_error_code else -32602 violation_details: dict[str, Any] = {} + http_status = 200 if exc.violation: if exc.violation.description: violation_details["description"] = exc.violation.description @@ -1505,8 +1555,31 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola violation_details["plugin_error_code"] = exc.violation.code if exc.violation.plugin_name: violation_details["plugin_name"] = exc.violation.plugin_name + + # Use HTTP status code from violation if present (e.g., 429 for rate limiting) + http_status = 
exc.violation.http_status_code if exc.violation.http_status_code else None + if http_status and not VALID_HTTP_STATUS_CODES.get(http_status): + logger.warning(f"Invalid HTTP status code {http_status} from violation, defaulting to 200") + http_status = None + if not http_status: + logger.debug("Using Plugin violation code mapping for lack of http_status_code") + mapping: Optional[PluginViolationCode] = PLUGIN_VIOLATION_CODE_MAPPING.get(exc.violation.code) if exc.violation.code else None + if not mapping: + http_status = 200 + else: + http_status = mapping.code + json_rpc_error = PydanticJSONRPCError(code=status_code, message="Plugin Violation: " + message, data=violation_details) - return ORJSONResponse(status_code=200, content={"error": json_rpc_error.model_dump()}) + + # Collect HTTP headers from violation if present + headers = exc.violation.http_headers if exc.violation and exc.violation.http_headers else None + + response = ORJSONResponse(status_code=http_status, content={"error": json_rpc_error.model_dump()}) + if headers: + validatated_headers = _validate_http_headers(headers) + if validatated_headers: + response.headers.update(validatated_headers) + return response @app.exception_handler(PluginError) diff --git a/mcpgateway/plugins/framework/constants.py b/mcpgateway/plugins/framework/constants.py index be4f8fd736..fdae0c5492 100644 --- a/mcpgateway/plugins/framework/constants.py +++ b/mcpgateway/plugins/framework/constants.py @@ -10,6 +10,11 @@ # Standard +# Standard +from dataclasses import dataclass +from types import MappingProxyType +from typing import Mapping + # Model constants. # Specialized plugin types. EXTERNAL_PLUGIN_TYPE = "external" @@ -43,3 +48,90 @@ GET_PLUGIN_CONFIG = "get_plugin_config" HOOK_TYPE = "hook_type" INVOKE_HOOK = "invoke_hook" + + +@dataclass(frozen=True) +class PluginViolationCode: + """ + Plugin violation codes as an immutable dataclass object. 
+ + Provides a mapping from violation codes to their corresponding HTTP status codes for proper error responses. + """ + + code: int + name: str + message: str + + +PLUGIN_VIOLATION_CODE_MAPPING: Mapping[str, PluginViolationCode] = MappingProxyType( + # MappingProxyType will make sure the resulting object is immutable and hence this will act as a constant. + { + # Rate Limiting + "RATE_LIMIT": PluginViolationCode(429, "RATE_LIMIT", "Used when rate limit is exceeded (rate_limiter plugin)"), + # Resource & URI Validation + "INVALID_URI": PluginViolationCode(400, "INVALID_URI", "Used when URI cannot be parsed or has invalid format (resource_filter, cedar, opa)"), + "PROTOCOL_BLOCKED": PluginViolationCode(403, "PROTOCOL_BLOCKED", "Used when protocol/scheme is not allowed (resource_filter)"), + "DOMAIN_BLOCKED": PluginViolationCode(403, "DOMAIN_BLOCKED", "Used when domain is in blocklist (resource_filter)"), + "CONTENT_TOO_LARGE": PluginViolationCode(413, "CONTENT_TOO_LARGE", "Used when resource content exceeds size limit (resource_filter)"), + # Content Moderation & Safety + "CONTENT_MODERATION": PluginViolationCode(422, "CONTENT_MODERATION", "Used when harmful content is detected (content_moderation plugin)"), + "MODERATION_ERROR": PluginViolationCode(503, "MODERATION_ERROR", "Used when moderation service fails (content_moderation plugin)"), + "PII_DETECTED": PluginViolationCode(422, "PII_DETECTED", "Used when PII is detected in content (pii_filter plugin)"), + "SENSITIVE_CONTENT": PluginViolationCode(422, "SENSITIVE_CONTENT", "Used when sensitive information is detected"), + # Authentication & Authorization + "INVALID_TOKEN": PluginViolationCode(401, "INVALID_TOKEN", "Used for invalid/expired tokens (simple_token_auth example)"), # nosec B105 - Not a password; INVALID_TOKEN is a plugin violation code + "API_KEY_REVOKED": PluginViolationCode(401, "API_KEY_REVOKED", "Used when API key has been revoked (custom_auth_example)"), + "AUTH_REQUIRED": PluginViolationCode(401, 
"AUTH_REQUIRED", "Used when authentication is missing"), + # Generic Violation Codes + "PROHIBITED_CONTENT": PluginViolationCode(422, "PROHIBITED_CONTENT", "Used when content violates policy rules"), + "BLOCKED_CONTENT": PluginViolationCode(403, "BLOCKED_CONTENT", "Used when content is explicitly blocked by policy"), + "BLOCKED": PluginViolationCode(403, "BLOCKED", "Generic blocking violation"), + "EXECUTION_ERROR": PluginViolationCode(500, "EXECUTION_ERROR", "Used when plugin execution fails"), + "PROCESSING_ERROR": PluginViolationCode(500, "PROCESSING_ERROR", "Used when processing encounters an error"), + } +) + +VALID_HTTP_STATUS_CODES: dict[int, str] = { # RFC 9110 + # 4xx — Client Error + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Content Too Large", # (was "Payload Too Large" before RFC 9110) + 414: "URI Too Long", + 415: "Unsupported Media Type", + 416: "Range Not Satisfiable", + 417: "Expectation Failed", + 418: "(Unused)", + 421: "Misdirected Request", + 422: "Unprocessable Content", # (was "Unprocessable Entity") + 423: "Locked", + 424: "Failed Dependency", + 425: "Too Early", + 426: "Upgrade Required", + 428: "Precondition Required", + 429: "Too Many Requests", + 431: "Request Header Fields Too Large", + 451: "Unavailable For Legal Reasons", + # 5xx — Server Error + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", + 506: "Variant Also Negotiates", + 507: "Insufficient Storage", + 508: "Loop Detected", + 510: "Not Extended", + 511: "Network Authentication Required", +} diff --git a/mcpgateway/plugins/framework/models.py 
b/mcpgateway/plugins/framework/models.py index 82c4be5f27..6687d1f691 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -1231,6 +1231,8 @@ class PluginViolation(BaseModel): details: (dict[str, Any]): additional violation details. _plugin_name (str): the plugin name, private attribute set by the plugin manager. mcp_error_code(Optional[int]): A valid mcp error code which will be sent back to the client if plugin enabled. + http_status_code (Optional[int]): HTTP status code to return (e.g., 429 for rate limiting). + http_headers (Optional[dict[str, str]]): HTTP headers to include in the response. Examples: >>> violation = PluginViolation( @@ -1254,6 +1256,8 @@ class PluginViolation(BaseModel): details: Optional[dict[str, Any]] = Field(default_factory=dict) _plugin_name: str = PrivateAttr(default="") mcp_error_code: Optional[int] = None + http_status_code: Optional[int] = None + http_headers: Optional[dict[str, str]] = None @property def plugin_name(self) -> str: @@ -1327,6 +1331,7 @@ class PluginResult(BaseModel, Generic[T]): modified_payload (Optional[Any]): The modified payload if the plugin is a transformer. violation (Optional[PluginViolation]): violation object. metadata (Optional[dict[str, Any]]): additional metadata. + http_headers (Optional[dict[str, str]]): HTTP headers to include in successful responses. 
Examples: >>> result = PluginResult() @@ -1355,6 +1360,7 @@ class PluginResult(BaseModel, Generic[T]): modified_payload: Optional[T] = None violation: Optional[PluginViolation] = None metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + http_headers: Optional[dict[str, str]] = None class GlobalContext(BaseModel): diff --git a/plugins/rate_limiter/rate_limiter.py b/plugins/rate_limiter/rate_limiter.py index 78eccafa4c..b61939545c 100644 --- a/plugins/rate_limiter/rate_limiter.py +++ b/plugins/rate_limiter/rate_limiter.py @@ -87,7 +87,7 @@ class _Window: _store: Dict[str, _Window] = {} -def _allow(key: str, limit: Optional[str]) -> tuple[bool, dict[str, Any]]: +def _allow(key: str, limit: Optional[str]) -> tuple[bool, int, int, dict[str, Any]]: """Check if a request is allowed under the rate limit. Args: @@ -95,23 +95,121 @@ def _allow(key: str, limit: Optional[str]) -> tuple[bool, dict[str, Any]]: limit: Rate limit string (e.g., '60/m') or None to allow unlimited. Returns: - Tuple of (allowed, metadata) where allowed is True if the request is allowed, - and metadata contains rate limiting information. 
+ Tuple of (allowed, limit_count, reset_timestamp, metadata) where: + - allowed: True if the request is allowed + - limit_count: The rate limit count (0 if unlimited) + - reset_timestamp: Unix timestamp when the window resets (0 if unlimited) + - metadata: Additional rate limiting information """ if not limit: - return True, {"limited": False} + return True, 0, 0, {"limited": False} count, window_seconds = _parse_rate(limit) now = int(time.time()) win_key = f"{key}:{window_seconds}" wnd = _store.get(win_key) + if not wnd or now - wnd.window_start >= window_seconds: + # New window + reset_timestamp = now + window_seconds _store[win_key] = _Window(window_start=now, count=1) - return True, {"limited": True, "remaining": count - 1, "reset_in": window_seconds} + return True, count, reset_timestamp, {"limited": True, "remaining": count - 1, "reset_in": window_seconds} + + reset_timestamp = wnd.window_start + window_seconds if wnd.count < count: + # Within limit wnd.count += 1 - return True, {"limited": True, "remaining": count - wnd.count, "reset_in": window_seconds - (now - wnd.window_start)} - # exceeded - return False, {"limited": True, "remaining": 0, "reset_in": window_seconds - (now - wnd.window_start)} + reset_in = window_seconds - (now - wnd.window_start) + return True, count, reset_timestamp, {"limited": True, "remaining": count - wnd.count, "reset_in": reset_in} + + # Exceeded + reset_in = window_seconds - (now - wnd.window_start) + return False, count, reset_timestamp, {"limited": True, "remaining": 0, "reset_in": reset_in} + + +def _make_headers(limit: int, remaining: int, reset_timestamp: int, retry_after: int, include_retry_after: bool = True) -> dict[str, str]: + """Create RFC-compliant rate limit headers. + + Args: + limit: The rate limit count. + remaining: Number of requests remaining in the current window. + reset_timestamp: Unix timestamp when the window resets. + retry_after: Seconds until the window resets (for Retry-After header). 
+ include_retry_after: Whether to include Retry-After header (only for violations). + + Returns: + Dictionary of HTTP headers for rate limiting. + """ + headers = { + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(reset_timestamp), + } + if include_retry_after: + headers["Retry-After"] = str(retry_after) + return headers + + +def _select_most_restrictive( + results: list[tuple[bool, int, int, dict[str, Any]]] +) -> tuple[bool, int, int, int, dict[str, Any]]: + """Select the most restrictive rate limit from multiple dimensions. + + Args: + results: List of (allowed, limit, reset_timestamp, metadata) tuples from _allow(). + - allowed: True if the request is allowed + - limit_count: The rate limit count (0 if unlimited) + - reset_timestamp: Unix timestamp when the window resets (0 if unlimited) + - metadata: Additional rate limiting information + + Returns: + Tuple of (allowed, limit, remaining, reset_timestamp, metadata) representing + the most restrictive limit. If any dimension is violated, allowed is False. + The metadata includes aggregated information from all dimensions. 
+ """ + # Filter out unlimited results (limit == 0) + limited_results = [(allowed, limit, reset_ts, meta) for allowed, limit, reset_ts, meta in results if limit > 0] + + if not limited_results: + # All unlimited + return True, 0, 0, 0, {"limited": False} + + # Separate violated and allowed dimensions + violated = [(allowed, limit, reset_ts, meta) for allowed, limit, reset_ts, meta in limited_results if not allowed] + allowed_dims = [(allowed, limit, reset_ts, meta) for allowed, limit, reset_ts, meta in limited_results if allowed] + + # If any dimension is violated, pick the one with shortest retry_after (resets soonest) + if violated: + most_restrictive = min(violated, key=lambda x: x[3].get("reset_in", float("inf"))) + _, limit, reset_ts, meta = most_restrictive + remaining = meta.get("remaining", 0) + retry_after = meta.get("reset_in", 0) + + # Aggregate metadata from all dimensions for observability + aggregated_meta = { + "limited": True, + "remaining": remaining, + "reset_in": retry_after, + "dimensions": { + "violated": [m for _, _, _, m in violated], + "allowed": [m for _, _, _, m in allowed_dims], + } + } + return False, limit, remaining, reset_ts, aggregated_meta + + # All dimensions allowed - find the most restrictive (lowest remaining) + most_restrictive = min(allowed_dims, key=lambda x: x[3].get("remaining", float("inf"))) + _, limit, reset_ts, meta = most_restrictive + remaining = meta.get("remaining", 0) + retry_after = meta.get("reset_in", 0) + + # Aggregate metadata from all dimensions + aggregated_meta = { + "limited": True, + "remaining": remaining, + "reset_in": retry_after, + "dimensions": {"allowed": [m for _, _, _, m in allowed_dims]}, + } + return True, limit, remaining, reset_ts, aggregated_meta class RateLimiterPlugin(Plugin): @@ -139,31 +237,36 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC user = context.global_context.user or "anonymous" tenant = context.global_context.tenant_id or "default" - ok_u, 
meta_u = _allow(f"user:{user}", self._cfg.by_user) - if not ok_u: - return PromptPrehookResult( - continue_processing=False, - violation=PluginViolation( - reason="Rate limit exceeded", - description=f"User {user} rate limit exceeded", - code="RATE_LIMIT", - details=meta_u, - ), - ) + # Check all dimensions + results = [ + _allow(f"user:{user}", self._cfg.by_user), + _allow(f"tenant:{tenant}", self._cfg.by_tenant), + ] - ok_t, meta_t = _allow(f"tenant:{tenant}", self._cfg.by_tenant) - if not ok_t: + # Select most restrictive + allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results) + retry_after = meta.get("reset_in", 0) + + if not allowed: + # Rate limit exceeded - include Retry-After header + headers = _make_headers(limit, remaining, reset_ts, retry_after, include_retry_after=True) return PromptPrehookResult( continue_processing=False, violation=PluginViolation( reason="Rate limit exceeded", - description=f"Tenant {tenant} rate limit exceeded", + description=f"Rate limit exceeded for user {user} or tenant {tenant}", code="RATE_LIMIT", - details=meta_t, + details=meta, + http_status_code=429, + http_headers=headers, ), ) - meta = {"by_user": meta_u, "by_tenant": meta_t} + # Success - include informational headers (without Retry-After) + if limit > 0: + headers = _make_headers(limit, remaining, reset_ts, retry_after, include_retry_after=False) + return PromptPrehookResult(metadata=meta, http_headers=headers) + return PromptPrehookResult(metadata=meta) async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: @@ -180,27 +283,40 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo user = context.global_context.user or "anonymous" tenant = context.global_context.tenant_id or "default" - meta: dict[str, Any] = {} - ok_u, meta_u = _allow(f"user:{user}", self._cfg.by_user) - ok_t, meta_t = _allow(f"tenant:{tenant}", self._cfg.by_tenant) - ok_tool = True - meta_tool: 
dict[str, Any] | None = None + # Check all dimensions + results = [ + _allow(f"user:{user}", self._cfg.by_user), + _allow(f"tenant:{tenant}", self._cfg.by_tenant), + ] + + # Check per-tool limit if configured by_tool_config = self._cfg.by_tool - if hasattr(by_tool_config, "__contains__"): - if tool in by_tool_config: # pylint: disable=unsupported-membership-test - ok_tool, meta_tool = _allow(f"tool:{tool}", by_tool_config[tool]) - meta.update({"by_user": meta_u, "by_tenant": meta_t}) - if meta_tool is not None: - meta["by_tool"] = meta_tool - - if not (ok_u and ok_t and ok_tool): + if by_tool_config: + if hasattr(by_tool_config, "__contains__") and tool in by_tool_config: # pylint: disable=unsupported-membership-test + results.append(_allow(f"tool:{tool}", by_tool_config[tool])) + + # Select most restrictive + allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results) + retry_after = meta.get("reset_in", 0) + + if not allowed: + # Rate limit exceeded - include Retry-After header + headers = _make_headers(limit, remaining, reset_ts, retry_after, include_retry_after=True) return ToolPreInvokeResult( continue_processing=False, violation=PluginViolation( reason="Rate limit exceeded", - description=f"Rate limit exceeded for {'tool ' + tool if not ok_tool else ('user' if not ok_u else 'tenant')}", + description=f"Rate limit exceeded for tool {tool}, user {user}, or tenant {tenant}", code="RATE_LIMIT", details=meta, + http_status_code=429, + http_headers=headers, ), ) + + # Success - include informational headers (without Retry-After) + if limit > 0: + headers = _make_headers(limit, remaining, reset_ts, retry_after, include_retry_after=False) + return ToolPreInvokeResult(metadata=meta, http_headers=headers) + return ToolPreInvokeResult(metadata=meta) diff --git a/tests/integration/test_rate_limiter.py b/tests/integration/test_rate_limiter.py new file mode 100644 index 0000000000..765a4be0f1 --- /dev/null +++ b/tests/integration/test_rate_limiter.py @@ 
-0,0 +1,455 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/integration/test_rate_limiter.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Integration tests for rate limiter plugin. + +Tests verify: +1. Rate limit enforcement via plugin hooks +2. HTTP 429 status code on limit exceeded +3. Retry-After and X-RateLimit-* headers +4. Multi-dimensional rate limiting (user, tenant, tool) +5. Window reset behavior +6. Header propagation through exception handler +7. Plugin configuration from config file + +Note: This tests the PLUGIN-based rate limiting, not middleware. +The rate limiter is implemented as a plugin that hooks into +prompt_pre_fetch and tool_pre_invoke. +""" + +import asyncio +import time +from typing import AsyncIterator, Dict +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from mcpgateway.main import app +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginConfig, + PluginContext, + PromptPrehookPayload, + ToolPreInvokePayload, +) +from mcpgateway.utils.create_jwt_token import _create_jwt_token +from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _store + + +# API Endpoints +PROMPT_ENDPOINT = "/api/v1/prompts/" +TOOL_INVOKE_ENDPOINT = "/api/v1/tools/invoke" + + +@pytest.fixture(autouse=True) +def clear_rate_limit_store(): + """Clear rate limit store before and after each test.""" + _store.clear() + yield + _store.clear() + + +@pytest.fixture +def jwt_token_alice(): + """JWT token for user alice in team1.""" + return _create_jwt_token( + {"sub": "alice", "username": "alice"}, + expires_in_minutes=60, + user_data={"email": "alice@example.com", "full_name": "Alice", "is_admin": False, "auth_provider": "test"}, + teams=["team1"], + ) + + +@pytest.fixture +def jwt_token_bob(): + """JWT token for user bob in team2.""" + return _create_jwt_token( + {"sub": "bob", "username": "bob"}, + expires_in_minutes=60, + 
user_data={"email": "bob@example.com", "full_name": "Bob", "is_admin": False, "auth_provider": "test"}, + teams=["team2"], + ) + + +@pytest.fixture +def rate_limit_plugin_2_per_second(): + """Rate limiter plugin configured for 2 requests per second.""" + config = PluginConfig( + name="RateLimiter", + kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin", + hooks=["prompt_pre_fetch", "tool_pre_invoke"], + priority=100, + config={ + "by_user": "2/s", + "by_tenant": None, + "by_tool": {} + } + ) + return RateLimiterPlugin(config) + + +@pytest.fixture +def rate_limit_plugin_multi_dimensional(): + """Rate limiter plugin with multi-dimensional limits.""" + config = PluginConfig( + name="RateLimiter", + kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin", + hooks=["prompt_pre_fetch", "tool_pre_invoke"], + priority=100, + config={ + "by_user": "10/s", + "by_tenant": "5/s", + "by_tool": { + "restricted_tool": "1/s" + } + } + ) + return RateLimiterPlugin(config) + + +@pytest.fixture +def client(): + """FastAPI test client.""" + return TestClient(app) + + +class TestRateLimitBasics: + """Basic rate limit enforcement tests via plugin.""" + + @pytest.mark.asyncio + async def test_under_limit_allows_requests(self, rate_limit_plugin_2_per_second): + """Verify requests under limit are allowed.""" + plugin = rate_limit_plugin_2_per_second + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice")) + payload = PromptPrehookPayload(prompt_id="test", args={}) + + # First request - should succeed + result1 = await plugin.prompt_pre_fetch(payload, ctx) + assert result1.violation is None + assert result1.http_headers is not None + assert result1.http_headers["X-RateLimit-Remaining"] == "1" + + # Second request - should succeed + result2 = await plugin.prompt_pre_fetch(payload, ctx) + assert result2.violation is None + assert result2.http_headers["X-RateLimit-Remaining"] == "0" + + @pytest.mark.asyncio + async def 
test_exceeding_limit_returns_violation(self, rate_limit_plugin_2_per_second): + """Verify exceeding limit returns violation with HTTP 429.""" + plugin = rate_limit_plugin_2_per_second + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice")) + payload = PromptPrehookPayload(prompt_id="test", args={}) + + # Exhaust rate limit + await plugin.prompt_pre_fetch(payload, ctx) + await plugin.prompt_pre_fetch(payload, ctx) + + # Third request should be rate limited + result = await plugin.prompt_pre_fetch(payload, ctx) + assert result.violation is not None + assert result.violation.http_status_code == 429 + assert result.violation.code == "RATE_LIMIT" + assert "rate limit exceeded" in result.violation.description.lower() + + @pytest.mark.asyncio + async def test_rate_limit_headers_present(self, rate_limit_plugin_2_per_second): + """Verify all rate limit headers are present.""" + plugin = rate_limit_plugin_2_per_second + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice")) + payload = PromptPrehookPayload(prompt_id="test", args={}) + + result = await plugin.prompt_pre_fetch(payload, ctx) + + assert result.http_headers is not None + assert "X-RateLimit-Limit" in result.http_headers + assert "X-RateLimit-Remaining" in result.http_headers + assert "X-RateLimit-Reset" in result.http_headers + + limit = int(result.http_headers["X-RateLimit-Limit"]) + remaining = int(result.http_headers["X-RateLimit-Remaining"]) + reset = int(result.http_headers["X-RateLimit-Reset"]) + + assert limit == 2 + assert remaining == 1 + assert reset > int(time.time()) + + @pytest.mark.asyncio + async def test_retry_after_header_on_violation(self, rate_limit_plugin_2_per_second): + """Verify Retry-After header is present on violations.""" + plugin = rate_limit_plugin_2_per_second + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice")) + payload = PromptPrehookPayload(prompt_id="test", args={}) + + # Exhaust rate limit + 
await plugin.prompt_pre_fetch(payload, ctx)
+        await plugin.prompt_pre_fetch(payload, ctx)
+
+        # Get violation
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result.violation is not None
+        assert result.violation.http_headers is not None
+        assert "Retry-After" in result.violation.http_headers
+
+        retry_after = int(result.violation.http_headers["Retry-After"])
+        assert 0 < retry_after <= 1  # 1 second window
+
+    @pytest.mark.asyncio
+    async def test_success_response_no_retry_after(self, rate_limit_plugin_2_per_second):
+        """Verify successful responses don't include Retry-After header."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+
+        assert result.violation is None
+        assert result.http_headers is not None
+        assert "Retry-After" not in result.http_headers
+
+
+class TestRateLimitAlgorithm:
+    """Window-based rate limiting algorithm tests."""
+
+    @pytest.mark.asyncio
+    async def test_remaining_count_decrements(self, rate_limit_plugin_2_per_second):
+        """Verify X-RateLimit-Remaining reads "1" and then "0" on two successive requests."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # First request
+        result1 = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result1.http_headers["X-RateLimit-Remaining"] == "1"
+
+        # Second request
+        result2 = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result2.http_headers["X-RateLimit-Remaining"] == "0"
+
+    @pytest.mark.asyncio
+    async def test_rate_limit_resets_after_window(self, rate_limit_plugin_2_per_second):
+        """Verify rate limit resets after the window expires."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Exhaust rate limit
+        await plugin.prompt_pre_fetch(payload, ctx)
+        await plugin.prompt_pre_fetch(payload, ctx)
+
+        # Verify rate limited
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result.violation is not None
+
+        # Wait for window to reset (1 second + buffer)
+        await asyncio.sleep(1.1)
+
+        # Verify rate limit reset
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result.violation is None
+        assert result.http_headers["X-RateLimit-Remaining"] == "1"
+
+    @pytest.mark.asyncio
+    async def test_reset_timestamp_accuracy(self, rate_limit_plugin_2_per_second):
+        """Verify X-RateLimit-Reset is within 2 seconds of the current time plus one second."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+        reset_time = int(result.http_headers["X-RateLimit-Reset"])
+        current_time = int(time.time())
+
+        # Reset should be current time + 1 second (with small tolerance)
+        expected_reset = current_time + 1
+        assert abs(reset_time - expected_reset) <= 2
+
+
+class TestMultiDimensionalRateLimiting:
+    """Multi-dimensional rate limiting tests (user, tenant, tool)."""
+
+    @pytest.mark.asyncio
+    async def test_user_rate_limit_enforced(self):
+        """Verify user rate limits are enforced independently per user."""
+        # Configure with ONLY user limits (no tenant limit)
+        config = PluginConfig(
+            name="RateLimiter",
+            kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin",
+            hooks=["prompt_pre_fetch"],
+            priority=100,
+            config={
+                "by_user": "10/s",
+                "by_tenant": None,  # No tenant limit
+                "by_tool": {}
+            }
+        )
+        plugin = RateLimiterPlugin(config)
+
+        ctx_alice = PluginContext(global_context=GlobalContext(request_id="r1", user="alice", tenant_id="team1"))
+        ctx_bob = PluginContext(global_context=GlobalContext(request_id="r2", user="bob", tenant_id="team1"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Alice makes 10 requests (her limit)
+        for _ in range(10):
+            result = await plugin.prompt_pre_fetch(payload, ctx_alice)
+            assert result.violation is None
+
+        # Alice's 11th request should be rate limited
+        result = await plugin.prompt_pre_fetch(payload, ctx_alice)
+        assert result.violation is not None
+
+        # Bob should still have his own limit (not affected by Alice)
+        result = await plugin.prompt_pre_fetch(payload, ctx_bob)
+        assert result.violation is None
+
+    @pytest.mark.asyncio
+    async def test_tenant_rate_limit_enforced(self, rate_limit_plugin_multi_dimensional):
+        """Verify tenant rate limits are enforced across users."""
+        plugin = rate_limit_plugin_multi_dimensional
+        ctx_alice = PluginContext(global_context=GlobalContext(request_id="r1", user="alice", tenant_id="team1"))
+        ctx_bob = PluginContext(global_context=GlobalContext(request_id="r2", user="bob", tenant_id="team1"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Alice makes 3 requests
+        for _ in range(3):
+            result = await plugin.prompt_pre_fetch(payload, ctx_alice)
+            assert result.violation is None
+
+        # Bob makes 2 requests (total 5 for team1)
+        for _ in range(2):
+            result = await plugin.prompt_pre_fetch(payload, ctx_bob)
+            assert result.violation is None
+
+        # Next request from either user should be rate limited (tenant limit reached)
+        result = await plugin.prompt_pre_fetch(payload, ctx_alice)
+        assert result.violation is not None
+        assert result.violation.http_status_code == 429
+
+    @pytest.mark.asyncio
+    async def test_per_tool_rate_limiting(self, rate_limit_plugin_multi_dimensional):
+        """Verify per-tool rate limits are enforced."""
+        plugin = rate_limit_plugin_multi_dimensional
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+
+        restricted_payload = ToolPreInvokePayload(name="restricted_tool", arguments={})
+        unrestricted_payload = ToolPreInvokePayload(name="other_tool", arguments={})
+
+        # First call to restricted tool succeeds
+        result = await plugin.tool_pre_invoke(restricted_payload, ctx)
+        assert result.violation is None
+
+        # Second call to restricted tool should be rate limited
+        result = await plugin.tool_pre_invoke(restricted_payload, ctx)
+        assert result.violation is not None
+        assert result.violation.http_status_code == 429
+
+        # Other tool should still work
+        result = await plugin.tool_pre_invoke(unrestricted_payload, ctx)
+        assert result.violation is None
+
+    @pytest.mark.asyncio
+    async def test_most_restrictive_dimension_selected(self):
+        """Verify the violation headers report the tenant limit (2) rather than the user limit (10)."""
+        # Configure with different limits
+        config = PluginConfig(
+            name="RateLimiter",
+            kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin",
+            hooks=["prompt_pre_fetch"],
+            priority=100,
+            config={
+                "by_user": "10/s",  # More permissive
+                "by_tenant": "2/s",  # More restrictive
+            }
+        )
+        plugin = RateLimiterPlugin(config)
+
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice", tenant_id="team1"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Make 2 requests (tenant limit)
+        await plugin.prompt_pre_fetch(payload, ctx)
+        await plugin.prompt_pre_fetch(payload, ctx)
+
+        # Third request should be rate limited by tenant limit
+        result = await plugin.prompt_pre_fetch(payload, ctx)
+        assert result.violation is not None
+        # Headers should show tenant limit (2), not user limit (10)
+        assert result.violation.http_headers["X-RateLimit-Limit"] == "2"
+
+
+class TestToolPreInvoke:
+    """Tests for tool_pre_invoke hook."""
+
+    @pytest.mark.asyncio
+    async def test_tool_invoke_rate_limiting(self, rate_limit_plugin_2_per_second):
+        """Verify tool invocations are rate limited."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = ToolPreInvokePayload(name="test_tool", arguments={})
+
+        # First two requests succeed
+        result1 = await plugin.tool_pre_invoke(payload, ctx)
+        assert result1.violation is None
+
+        result2 = await plugin.tool_pre_invoke(payload, ctx)
+        assert result2.violation is None
+
+        # Third request should be rate limited
+        result3 = await plugin.tool_pre_invoke(payload, ctx)
+        assert result3.violation is not None
+        assert result3.violation.http_status_code == 429
+
+    @pytest.mark.asyncio
+    async def test_tool_invoke_headers_present(self, rate_limit_plugin_2_per_second):
+        """Verify headers are present on tool invocations."""
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = ToolPreInvokePayload(name="test_tool", arguments={})
+
+        result = await plugin.tool_pre_invoke(payload, ctx)
+
+        assert result.http_headers is not None
+        assert "X-RateLimit-Limit" in result.http_headers
+        assert "X-RateLimit-Remaining" in result.http_headers
+        assert "X-RateLimit-Reset" in result.http_headers
+        assert "Retry-After" not in result.http_headers  # Not on success
+
+
+class TestStoreCleanup:
+    """Tests for rate limit store cleanup."""
+
+    @pytest.mark.asyncio
+    async def test_store_cleanup_between_tests(self, rate_limit_plugin_2_per_second):
+        """Verify the store is empty when the test starts and has entries after one request."""
+        # Store should be empty at start (autouse fixture)
+        assert len(_store) == 0
+
+        plugin = rate_limit_plugin_2_per_second
+        ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Make a request
+        await plugin.prompt_pre_fetch(payload, ctx)
+
+        # Store should have entries
+        assert len(_store) > 0
+
+    @pytest.mark.asyncio
+    async def test_multiple_users_create_separate_windows(self, rate_limit_plugin_2_per_second):
+        """Verify multiple users create separate windows in store."""
+        plugin = rate_limit_plugin_2_per_second
+
+        ctx_alice = PluginContext(global_context=GlobalContext(request_id="r1", user="alice"))
+        ctx_bob = PluginContext(global_context=GlobalContext(request_id="r2", user="bob"))
+        payload = PromptPrehookPayload(prompt_id="test", args={})
+
+        # Make requests from both users
+        await plugin.prompt_pre_fetch(payload, ctx_alice)
+        await plugin.prompt_pre_fetch(payload, ctx_bob)
+
+        # Store should have entries for both users
+        assert len(_store) >= 2
diff --git a/tests/unit/mcpgateway/plugins/framework/test_constants.py b/tests/unit/mcpgateway/plugins/framework/test_constants.py
new file mode 100644
index 0000000000..2022e2b1c7
--- /dev/null
+++ b/tests/unit/mcpgateway/plugins/framework/test_constants.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+# pylint: disable=wrong-import-position, import-outside-toplevel, no-name-in-module
+"""Location: ./tests/unit/mcpgateway/plugins/framework/test_constants.py
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+Authors: Mihai Criveti
+
+Tests for plugin framework constants.
+"""
+
+import pytest
+from types import MappingProxyType
+
+from mcpgateway.plugins.framework.constants import (
+    PLUGIN_VIOLATION_CODE_MAPPING,
+    PluginViolationCode,
+)
+
+
+class TestPluginViolationCode:
+    """Test suite for PluginViolationCode dataclass."""
+
+    def test_dataclass_immutability(self):
+        """Test that PluginViolationCode dataclass is frozen (immutable)."""
+        code = PluginViolationCode(429, "RATE_LIMIT", "Rate limit exceeded")
+        with pytest.raises(AttributeError):
+            code.code = 500
+
+    def test_dataclass_structure(self):
+        """Test that PluginViolationCode has correct fields."""
+        code = PluginViolationCode(429, "RATE_LIMIT", "Rate limit exceeded")
+        assert code.code == 429
+        assert code.name == "RATE_LIMIT"
+        assert code.message == "Rate limit exceeded"
+
+    def test_dataclass_equality(self):
+        """Test that dataclass instances with same values are equal."""
+        code1 = PluginViolationCode(429, "RATE_LIMIT", "Rate limit exceeded")
+        code2 = PluginViolationCode(429, "RATE_LIMIT", "Rate limit exceeded")
+        assert code1 == code2
+
+    def test_dataclass_hashable(self):
+        """Test that frozen dataclass instances are hashable."""
+        code1 = PluginViolationCode(429, "RATE_LIMIT", "Rate limit exceeded")
+        code2 = PluginViolationCode(422, "PII_DETECTED", "PII detected")
+        codes_dict = {code1: "rate_limit", code2: "pii"}
+        assert codes_dict[code1] == "rate_limit"
+        assert codes_dict[code2] == "pii"
+
+
+class TestPluginViolationCodeMapping:
+    """Test suite for PLUGIN_VIOLATION_CODE_MAPPING."""
+
+    def test_mapping_is_mappingproxy(self):
+        """Test that mapping is wrapped in MappingProxyType for immutability."""
+        assert isinstance(PLUGIN_VIOLATION_CODE_MAPPING, MappingProxyType)
+
+    def test_mapping_is_immutable(self):
+        """Test that mapping cannot be modified."""
+        with pytest.raises(TypeError):
+            PLUGIN_VIOLATION_CODE_MAPPING["NEW_CODE"] = PluginViolationCode(999, "NEW", "New code")
+
+    def test_mapping_values_are_dataclass_instances(self):
+        """Test that all mapping values are PluginViolationCode instances."""
+        for key, value in PLUGIN_VIOLATION_CODE_MAPPING.items():
+            assert isinstance(value, PluginViolationCode)
+            assert isinstance(key, str)
+
+    def test_rate_limiting_codes(self):
+        """Test rate limiting violation codes."""
+        rate_limit = PLUGIN_VIOLATION_CODE_MAPPING["RATE_LIMIT"]
+        assert rate_limit.code == 429
+        assert rate_limit.name == "RATE_LIMIT"
+        assert "rate limit" in rate_limit.message.lower()
+
+    def test_resource_validation_codes(self):
+        """Test resource and URI validation codes."""
+        invalid_uri = PLUGIN_VIOLATION_CODE_MAPPING["INVALID_URI"]
+        assert invalid_uri.code == 400
+        assert invalid_uri.name == "INVALID_URI"
+
+        protocol_blocked = PLUGIN_VIOLATION_CODE_MAPPING["PROTOCOL_BLOCKED"]
+        assert protocol_blocked.code == 403
+        assert protocol_blocked.name == "PROTOCOL_BLOCKED"
+
+        domain_blocked = PLUGIN_VIOLATION_CODE_MAPPING["DOMAIN_BLOCKED"]
+        assert domain_blocked.code == 403
+        assert domain_blocked.name == "DOMAIN_BLOCKED"
+
+        content_too_large = PLUGIN_VIOLATION_CODE_MAPPING["CONTENT_TOO_LARGE"]
+        assert content_too_large.code == 413
+        assert content_too_large.name == "CONTENT_TOO_LARGE"
+
+    def test_content_moderation_codes(self):
+        """Test content moderation and safety codes."""
+        content_moderation = PLUGIN_VIOLATION_CODE_MAPPING["CONTENT_MODERATION"]
+        assert content_moderation.code == 422
+        assert content_moderation.name == "CONTENT_MODERATION"
+
+        moderation_error = PLUGIN_VIOLATION_CODE_MAPPING["MODERATION_ERROR"]
+        assert moderation_error.code == 503
+        assert moderation_error.name == "MODERATION_ERROR"
+
+        pii_detected = PLUGIN_VIOLATION_CODE_MAPPING["PII_DETECTED"]
+        assert pii_detected.code == 422
+        assert pii_detected.name == "PII_DETECTED"
+
+        sensitive_content = PLUGIN_VIOLATION_CODE_MAPPING["SENSITIVE_CONTENT"]
+        assert sensitive_content.code == 422
+        assert sensitive_content.name == "SENSITIVE_CONTENT"
+
+    def test_authentication_codes(self):
+        """Test authentication and authorization codes."""
+        invalid_token = PLUGIN_VIOLATION_CODE_MAPPING["INVALID_TOKEN"]
+        assert invalid_token.code == 401
+        assert invalid_token.name == "INVALID_TOKEN"
+
+        api_key_revoked = PLUGIN_VIOLATION_CODE_MAPPING["API_KEY_REVOKED"]
+        assert api_key_revoked.code == 401
+        assert api_key_revoked.name == "API_KEY_REVOKED"
+
+        auth_required = PLUGIN_VIOLATION_CODE_MAPPING["AUTH_REQUIRED"]
+        assert auth_required.code == 401
+        assert auth_required.name == "AUTH_REQUIRED"
+
+    def test_generic_violation_codes(self):
+        """Test generic violation codes."""
+        prohibited_content = PLUGIN_VIOLATION_CODE_MAPPING["PROHIBITED_CONTENT"]
+        assert prohibited_content.code == 422
+        assert prohibited_content.name == "PROHIBITED_CONTENT"
+
+        blocked_content = PLUGIN_VIOLATION_CODE_MAPPING["BLOCKED_CONTENT"]
+        assert blocked_content.code == 403
+        assert blocked_content.name == "BLOCKED_CONTENT"
+
+        blocked = PLUGIN_VIOLATION_CODE_MAPPING["BLOCKED"]
+        assert blocked.code == 403
+        assert blocked.name == "BLOCKED"
+
+        execution_error = PLUGIN_VIOLATION_CODE_MAPPING["EXECUTION_ERROR"]
+        assert execution_error.code == 500
+        assert execution_error.name == "EXECUTION_ERROR"
+
+        processing_error = PLUGIN_VIOLATION_CODE_MAPPING["PROCESSING_ERROR"]
+        assert processing_error.code == 500
+        assert processing_error.name == "PROCESSING_ERROR"
+
+    def test_mapping_contains_expected_keys(self):
+        """Test that mapping contains all expected violation code keys."""
+        expected_keys = {
+            "RATE_LIMIT",
+            "INVALID_URI",
+            "PROTOCOL_BLOCKED",
+            "DOMAIN_BLOCKED",
+            "CONTENT_TOO_LARGE",
+            "CONTENT_MODERATION",
+            "MODERATION_ERROR",
+            "PII_DETECTED",
+            "SENSITIVE_CONTENT",
+            "INVALID_TOKEN",
+            "API_KEY_REVOKED",
+            "AUTH_REQUIRED",
+            "PROHIBITED_CONTENT",
+            "BLOCKED_CONTENT",
+            "BLOCKED",
+            "EXECUTION_ERROR",
+            "PROCESSING_ERROR",
+        }
+        assert set(PLUGIN_VIOLATION_CODE_MAPPING.keys()) == expected_keys
+
+    def test_mapping_count(self):
+        """Test that mapping contains exactly 17 codes."""
+        assert len(PLUGIN_VIOLATION_CODE_MAPPING) == 17
+
+    def test_code_ranges(self):
+        """Test that every mapped code lies in the HTTP error range 400-599."""
+        for key, violation_code in PLUGIN_VIOLATION_CODE_MAPPING.items():
+            # Every code must be a 4xx or 5xx HTTP status
+            assert 400 <= violation_code.code <= 599, f"{key} has invalid code {violation_code.code}"
+
+    def test_dataclass_repr(self):
+        """Test that dataclass has a useful string representation."""
+        code = PLUGIN_VIOLATION_CODE_MAPPING["RATE_LIMIT"]
+        repr_str = repr(code)
+        assert "429" in repr_str
+        assert "RATE_LIMIT" in repr_str
+
+    def test_all_codes_have_messages(self):
+        """Test that all violation codes have non-empty messages."""
+        for key, violation_code in PLUGIN_VIOLATION_CODE_MAPPING.items():
+            assert violation_code.message, f"{key} has empty message"
+            assert len(violation_code.message) > 0, f"{key} has empty message"
+
+    def test_code_name_consistency(self):
+        """Test that code names match their dictionary keys."""
+        for key, violation_code in PLUGIN_VIOLATION_CODE_MAPPING.items():
+            assert violation_code.name == key, f"Key {key} doesn't match code name {violation_code.name}"
+
+    def test_http_status_code_categories(self):
+        """Test that each listed code falls in its expected class: 4xx client error or 5xx server error."""
+        # 4xx Client Errors
+        client_error_codes = [
+            "INVALID_URI",
+            "PROTOCOL_BLOCKED",
+            "DOMAIN_BLOCKED",
+            "CONTENT_TOO_LARGE",
+            "CONTENT_MODERATION",
+            "PII_DETECTED",
+            "SENSITIVE_CONTENT",
+            "INVALID_TOKEN",
+            "API_KEY_REVOKED",
+            "AUTH_REQUIRED",
+            "PROHIBITED_CONTENT",
+            "BLOCKED_CONTENT",
+            "BLOCKED",
+            "RATE_LIMIT",
+        ]
+        for key in client_error_codes:
+            code = PLUGIN_VIOLATION_CODE_MAPPING[key]
+            assert 400 <= code.code < 500, f"{key} should be 4xx client error"
+
+        # 5xx Server Errors
+        server_error_codes = ["MODERATION_ERROR", "EXECUTION_ERROR", "PROCESSING_ERROR"]
+        for key in server_error_codes:
+            code = PLUGIN_VIOLATION_CODE_MAPPING[key]
+            assert 500 <= code.code < 600, f"{key} should be 5xx server error"
diff --git a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py
index 2ee6d0db34..9fba1054f0 100644
--- a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py
+++ b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py
@@ -17,7 +17,15 @@
     PromptPrehookPayload,
     ToolHookType
 )
-from plugins.rate_limiter.rate_limiter import RateLimiterPlugin
+from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _make_headers, _select_most_restrictive, _store
+
+
+@pytest.fixture(autouse=True)
+def clear_rate_limit_store():
+    """Clear the rate limiter store before each test to ensure test isolation."""
+    _store.clear()
+    yield
+    _store.clear()
 def _mk(rate: str) -> RateLimiterPlugin:
@@ -42,3 +50,462 @@ async def test_rate_limit_blocks_on_third_call():
     assert r2.violation is None
     r3 = await plugin.prompt_pre_fetch(payload, ctx)
     assert r3.violation is not None
+
+
+# ============================================================================
+# HTTP 429 Status Code Tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_prompt_pre_fetch_violation_returns_http_429():
+    """Test that rate limit violations return HTTP 429 status code."""
+    plugin = _mk("1/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = PromptPrehookPayload(prompt_id="p", args={})
+
+    # First request succeeds
+    r1 = await plugin.prompt_pre_fetch(payload, ctx)
+    assert r1.violation is None
+
+    # Second request should be rate limited
+    r2 = await plugin.prompt_pre_fetch(payload, ctx)
+    assert r2.violation is not None
+    assert r2.violation.http_status_code == 429
+    assert r2.violation.code == "RATE_LIMIT"
+
+
+@pytest.mark.asyncio
+async def test_prompt_pre_fetch_violation_includes_all_headers():
+    """Test that violations carry X-RateLimit-Limit/Remaining/Reset and Retry-After headers."""
+    plugin = _mk("2/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = PromptPrehookPayload(prompt_id="p", args={})
+
+    # Trigger rate limit
+    await plugin.prompt_pre_fetch(payload, ctx)  # 1st
+    await plugin.prompt_pre_fetch(payload, ctx)  # 2nd
+    result = await plugin.prompt_pre_fetch(payload, ctx)  # 3rd - exceeds limit
+
+    assert result.violation is not None
+    headers = result.violation.http_headers
+    assert headers is not None
+
+    # Verify all required headers
+    assert "X-RateLimit-Limit" in headers
+    assert headers["X-RateLimit-Limit"] == "2"
+
+    assert "X-RateLimit-Remaining" in headers
+    assert headers["X-RateLimit-Remaining"] == "0"
+
+    assert "X-RateLimit-Reset" in headers
+    assert int(headers["X-RateLimit-Reset"]) > 0
+
+    assert "Retry-After" in headers
+    assert int(headers["Retry-After"]) > 0
+
+
+@pytest.mark.asyncio
+async def test_prompt_pre_fetch_success_includes_headers_without_retry_after():
+    """Test that successful requests include headers but not Retry-After."""
+    plugin = _mk("10/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = PromptPrehookPayload(prompt_id="p", args={})
+
+    result = await plugin.prompt_pre_fetch(payload, ctx)
+
+    assert result.violation is None
+    assert result.http_headers is not None
+
+    headers = result.http_headers
+    assert "X-RateLimit-Limit" in headers
+    assert headers["X-RateLimit-Limit"] == "10"
+
+    assert "X-RateLimit-Remaining" in headers
+    assert headers["X-RateLimit-Remaining"] == "9"  # 1 used, 9 remaining
+
+    assert "X-RateLimit-Reset" in headers
+    assert int(headers["X-RateLimit-Reset"]) > 0
+
+    assert "Retry-After" not in headers  # Should NOT be present on success
+
+
+# ============================================================================
+# tool_pre_invoke Tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_tool_pre_invoke_violation_returns_http_429():
+    """Test that tool_pre_invoke violations return HTTP 429 status code."""
+    from mcpgateway.plugins.framework import ToolPreInvokePayload
+
+    plugin = _mk("1/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = ToolPreInvokePayload(name="test_tool", arguments={})
+
+    # First request succeeds
+    r1 = await plugin.tool_pre_invoke(payload, ctx)
+    assert r1.violation is None
+
+    # Second request should be rate limited
+    r2 = await plugin.tool_pre_invoke(payload, ctx)
+    assert r2.violation is not None
+    assert r2.violation.http_status_code == 429
+    assert r2.violation.code == "RATE_LIMIT"
+
+
+@pytest.mark.asyncio
+async def test_tool_pre_invoke_violation_includes_headers():
+    """Test that tool_pre_invoke violations include rate limit headers."""
+    from mcpgateway.plugins.framework import ToolPreInvokePayload
+
+    plugin = _mk("2/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = ToolPreInvokePayload(name="test_tool", arguments={})
+
+    # Trigger rate limit
+    await plugin.tool_pre_invoke(payload, ctx)  # 1st
+    await plugin.tool_pre_invoke(payload, ctx)  # 2nd
+    result = await plugin.tool_pre_invoke(payload, ctx)  # 3rd - exceeds limit
+
+    assert result.violation is not None
+    headers = result.violation.http_headers
+    assert headers is not None
+
+    # Verify headers are present
+    assert "X-RateLimit-Limit" in headers
+    assert "X-RateLimit-Remaining" in headers
+    assert headers["X-RateLimit-Remaining"] == "0"
+    assert "X-RateLimit-Reset" in headers
+    assert "Retry-After" in headers
+
+
+@pytest.mark.asyncio
+async def test_tool_pre_invoke_success_includes_headers_without_retry_after():
+    """Test that successful tool invocations include headers but not Retry-After."""
+    from mcpgateway.plugins.framework import ToolPreInvokePayload
+
+    plugin = _mk("10/s")
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    payload = ToolPreInvokePayload(name="test_tool", arguments={})
+
+    result = await plugin.tool_pre_invoke(payload, ctx)
+
+    assert result.violation is None
+    assert result.http_headers is not None
+
+    headers = result.http_headers
+    assert "X-RateLimit-Limit" in headers
+    assert "X-RateLimit-Remaining" in headers
+    assert "X-RateLimit-Reset" in headers
+    assert "Retry-After" not in headers
+
+
+@pytest.mark.asyncio
+async def test_tool_pre_invoke_per_tool_rate_limiting():
+    """Test that a by_tool limit throttles only the named tool while another tool stays allowed."""
+    from mcpgateway.plugins.framework import ToolPreInvokePayload
+
+    plugin = RateLimiterPlugin(
+        PluginConfig(
+            name="rl",
+            kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin",
+            hooks=[ToolHookType.TOOL_PRE_INVOKE],
+            config={
+                "by_user": "100/s",  # High user limit
+                "by_tool": {
+                    "restricted_tool": "1/s"  # Low tool-specific limit
+                }
+            },
+        )
+    )
+
+    ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
+    restricted_payload = ToolPreInvokePayload(name="restricted_tool", arguments={})
+    unrestricted_payload = ToolPreInvokePayload(name="other_tool", arguments={})
+
+    # First call to restricted tool succeeds
+    r1 = await plugin.tool_pre_invoke(restricted_payload, ctx)
+    assert r1.violation is None
+
+    # Second call to same tool should be rate limited
+    r2 = await plugin.tool_pre_invoke(restricted_payload, ctx)
+    assert r2.violation is not None
+    assert r2.violation.http_status_code == 429
+
+    # But other tool should still work (only user limit applies)
+    r3 = await plugin.tool_pre_invoke(unrestricted_payload, ctx)
+    assert r3.violation is None
+
+
+# ============================================================================
+# Helper Function Tests
+# ============================================================================
+
+
+def test_make_headers_with_retry_after():
+    """Test header generation with Retry-After."""
+    headers = _make_headers(limit=60, remaining=0, reset_timestamp=1737394800, retry_after=35, include_retry_after=True)
+
+    assert headers["X-RateLimit-Limit"] == "60"
+    assert headers["X-RateLimit-Remaining"] == "0"
+    assert headers["X-RateLimit-Reset"] == "1737394800"
+    assert headers["Retry-After"] == "35"
+
+
+def test_make_headers_without_retry_after():
+    """Test header generation without Retry-After."""
+    headers = _make_headers(limit=60, remaining=45, reset_timestamp=1737394800, retry_after=35, include_retry_after=False)
+
+    assert headers["X-RateLimit-Limit"] == "60"
+    assert headers["X-RateLimit-Remaining"] == "45"
+    assert headers["X-RateLimit-Reset"] == "1737394800"
+    assert "Retry-After" not in headers
+
+
+# ============================================================================
+# _select_most_restrictive TESTS
+# ============================================================================
+
+class TestSelectMostRestrictive:
+    """Comprehensive tests for _select_most_restrictive function."""
+
+    # Test Category 1: Edge Cases & Empty Handling
+
+    def test_empty_list_returns_unlimited(self):
+        """Empty list should return unlimited result."""
+        allowed, limit, remaining, reset_ts, meta = _select_most_restrictive([])
+        assert allowed is True
+        assert limit == 0
+        assert remaining == 0
+        assert reset_ts == 0
+        assert meta == {"limited": False}
+
+    def test_single_unlimited_result(self):
+        """Single unlimited result (limit=0) should return unlimited."""
+        results = [(True, 0, 0, {"limited": False})]
+        allowed, limit, _remaining, _reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 0
+        assert meta["limited"] is False
+
+    def test_all_unlimited_results(self):
+        """All unlimited results should return unlimited."""
+        results = [
+            (True, 0, 0, {"limited": False}),
+            (True, 0, 0, {"limited": False}),
+            (True, 0, 0, {"limited": False}),
+        ]
+        allowed, limit, _remaining, _reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 0
+        assert meta["limited"] is False
+
+    # Test Category 2: Single Dimension
+
+    def test_single_violated_dimension(self):
+        """Single violated dimension should be returned with remaining=0."""
+        now = 1000
+        results = [(False, 10, now + 60, {"limited": True, "remaining": 0, "reset_in": 60})]
+        allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 10
+        assert remaining == 0
+        assert reset_ts == now + 60
+        assert meta["reset_in"] == 60
+
+    def test_single_allowed_dimension(self):
+        """Single allowed dimension should be returned with correct remaining."""
+        now = 1000
+        results = [(True, 100, now + 60, {"limited": True, "remaining": 95, "reset_in": 60})]
+        allowed, limit, remaining, reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 100
+        assert remaining == 95
+        assert reset_ts == now + 60
+
+    # Test Category 3: Multiple Violated Dimensions - Select Shortest Reset
+
+    def test_multiple_violated_shortest_reset_wins(self):
+        """When multiple violated, select the one with shortest reset time."""
+        now = 1000
+        results = [
+            (False, 10, now + 30, {"limited": True, "remaining": 0, "reset_in": 30}),  # Resets sooner
+            (False, 20, now + 60, {"limited": True, "remaining": 0, "reset_in": 60}),
+            (False, 30, now + 120, {"limited": True, "remaining": 0, "reset_in": 120}),
+        ]
+        allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 10  # Shortest reset_in (30)
+        assert remaining == 0
+        assert reset_ts == now + 30
+        assert meta["reset_in"] == 30
+
+    def test_violated_with_allowed_dimensions(self):
+        """When some violated and some allowed, violated takes precedence."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 90, "reset_in": 60}),  # Allowed
+            (False, 50, now + 30, {"limited": True, "remaining": 0, "reset_in": 30}),  # Violated (shortest)
+            (False, 75, now + 90, {"limited": True, "remaining": 0, "reset_in": 90}),  # Violated
+        ]
+        allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 50  # Violated with shortest reset
+        assert remaining == 0
+        assert reset_ts == now + 30
+        assert "dimensions" in meta
+        assert "violated" in meta["dimensions"]
+        assert "allowed" in meta["dimensions"]
+
+    def test_multiple_violated_equal_reset_times(self):
+        """When multiple violated with equal reset times, first one wins (stable)."""
+        now = 1000
+        results = [
+            (False, 10, now + 60, {"limited": True, "remaining": 0, "reset_in": 60}),
+            (False, 20, now + 60, {"limited": True, "remaining": 0, "reset_in": 60}),
+        ]
+        allowed, limit, remaining, _reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 10  # First one with shortest reset
+        assert remaining == 0
+        assert meta["reset_in"] == 60
+
+    # Test Category 4: Multiple Allowed Dimensions - Select Lowest Remaining
+
+    def test_multiple_allowed_lowest_remaining_wins(self):
+        """When all allowed, select the one with lowest remaining."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 50, "reset_in": 60}),
+            (True, 200, now + 60, {"limited": True, "remaining": 10, "reset_in": 60}),  # Lowest remaining
+            (True, 150, now + 60, {"limited": True, "remaining": 75, "reset_in": 60}),
+        ]
+        allowed, limit, remaining, reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 200  # Has lowest remaining (10)
+        assert remaining == 10
+        assert reset_ts == now + 60
+
+    def test_allowed_with_equal_remaining(self):
+        """When remaining is equal, first one wins (stable sort)."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 25, "reset_in": 60}),
+            (True, 200, now + 30, {"limited": True, "remaining": 25, "reset_in": 30}),
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert remaining == 25
+        assert limit == 100  # First one when remaining is equal
+
+    def test_two_allowed_different_remaining(self):
+        """Two allowed dimensions with different remaining."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 80, "reset_in": 60}),
+            (True, 50, now + 60, {"limited": True, "remaining": 40, "reset_in": 60}),  # Lower remaining
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 50
+        assert remaining == 40
+
+    # Test Category 5: Mixed Limited and Unlimited
+
+    def test_limited_more_restrictive_than_unlimited(self):
+        """Limited dimension should be selected over unlimited."""
+        now = 1000
+        results = [
+            (True, 0, 0, {"limited": False}),  # Unlimited
+            (True, 100, now + 60, {"limited": True, "remaining": 95, "reset_in": 60}),  # Limited
+        ]
+        allowed, limit, remaining, _reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 100  # Limited dimension selected
+        assert remaining == 95
+        assert meta["limited"] is True
+
+    def test_violated_limited_with_unlimited(self):
+        """Violated limited dimension should be selected over unlimited."""
+        now = 1000
+        results = [
+            (True, 0, 0, {"limited": False}),  # Unlimited
+            (False, 50, now + 30, {"limited": True, "remaining": 0, "reset_in": 30}),  # Violated
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 50
+        assert remaining == 0
+
+    def test_multiple_unlimited_with_one_limited(self):
+        """Multiple unlimited with one limited should select limited."""
+        now = 1000
+        results = [
+            (True, 0, 0, {"limited": False}),
+            (True, 0, 0, {"limited": False}),
+            (True, 75, now + 60, {"limited": True, "remaining": 60, "reset_in": 60}),
+            (True, 0, 0, {"limited": False}),
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 75
+        assert remaining == 60
+
+    # Test Category 6: Realistic Scenarios
+
+    def test_user_tenant_tool_all_allowed(self):
+        """Realistic scenario: user, tenant, tool all allowed."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 80, "reset_in": 60}),  # User
+            (True, 1000, now + 60, {"limited": True, "remaining": 950, "reset_in": 60}),  # Tenant
+            (True, 50, now + 60, {"limited": True, "remaining": 40, "reset_in": 60}),  # Tool (most restrictive)
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 50  # Tool has lowest remaining (40)
+        assert remaining == 40
+
+    def test_user_violated_tenant_tool_allowed(self):
+        """Realistic scenario: user violated, others allowed."""
+        now = 1000
+        results = [
+            (False, 100, now + 30, {"limited": True, "remaining": 0, "reset_in": 30}),  # User violated
+            (True, 1000, now + 60, {"limited": True, "remaining": 950, "reset_in": 60}),  # Tenant allowed
+            (True, 50, now + 60, {"limited": True, "remaining": 40, "reset_in": 60}),  # Tool allowed
+        ]
+        allowed, limit, remaining, reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 100  # User's violated limit
+        assert remaining == 0
+        assert reset_ts == now + 30
+
+    def test_multiple_violated_different_reset_times(self):
+        """Realistic scenario: multiple violated with different reset times."""
+        now = 1000
+        results = [
+            (False, 100, now + 60, {"limited": True, "remaining": 0, "reset_in": 60}),  # User
+            (False, 1000, now + 10, {"limited": True, "remaining": 0, "reset_in": 10}),  # Tenant (soonest)
+            (False, 50, now + 30, {"limited": True, "remaining": 0, "reset_in": 30}),  # Tool
+        ]
+        allowed, limit, remaining, reset_ts, meta = _select_most_restrictive(results)
+        assert allowed is False
+        assert limit == 1000  # Tenant resets soonest
+        assert remaining == 0
+        assert reset_ts == now + 10
+        assert meta["reset_in"] == 10
+
+    def test_tenant_unlimited_user_tool_limited(self):
+        """Realistic scenario: tenant unlimited, user and tool have limits."""
+        now = 1000
+        results = [
+            (True, 100, now + 60, {"limited": True, "remaining": 80, "reset_in": 60}),  # User
+            (True, 0, 0, {"limited": False}),  # Tenant unlimited
+            (True, 50, now + 60, {"limited": True, "remaining": 30, "reset_in": 60}),  # Tool (most restrictive)
+        ]
+        allowed, limit, remaining, _reset_ts, _meta = _select_most_restrictive(results)
+        assert allowed is True
+        assert limit == 50  # Tool is most restrictive
+        assert remaining == 30
diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py
index f20e078684..9c18df9817 100644
--- a/tests/unit/mcpgateway/test_main.py
+++ b/tests/unit/mcpgateway/test_main.py
@@ -31,7 +31,20 @@
 from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities
 from mcpgateway.config import settings
 import mcpgateway.db as db_mod
-from mcpgateway.schemas import A2AAgentAggregateMetrics, GatewayRead, PromptMetrics, PromptRead, ResourceMetrics, ResourceRead, ServerMetrics, ServerRead, ToolMetrics, ToolRead
+from mcpgateway.schemas import (
+    A2AAgentAggregateMetrics,
+    GatewayRead,
+    PromptMetrics,
+    PromptRead,
+    ResourceMetrics,
+    ResourceRead,
+    ServerMetrics,
+    ServerRead,
+    ToolMetrics,
+    ToolRead,
+)
+from mcpgateway.plugins.framework.constants import PLUGIN_VIOLATION_CODE_MAPPING
+
 # --------------------------------------------------------------------------- #
 # Constants #
@@ -3126,7 +3139,11 @@ class TestPluginExceptionHandlers:
     """Tests for plugin exception handlers: PluginViolationError and PluginError."""
     def test_plugin_violation_exception_handler_with_full_violation(self):
-        """Test plugin_violation_exception_handler with complete violation details."""
+        """Test plugin_violation_exception_handler with complete violation details.
+
+        Updated to verify backward compatibility with new http_status_code and http_headers fields.
+        This test verifies that the code mapping (PROHIBITED_CONTENT -> 422) still works.
+        """
         # Standard
         import asyncio
@@ -3146,7 +3163,8 @@ def test_plugin_violation_exception_handler_with_full_violation(self):
         result = asyncio.run(plugin_violation_exception_handler(None, exc))
-        assert result.status_code == 200
+        # Verify code mapping works (PROHIBITED_CONTENT -> 422 in PLUGIN_VIOLATION_CODE_MAPPING)
+        assert result.status_code == PLUGIN_VIOLATION_CODE_MAPPING["PROHIBITED_CONTENT"].code  # Uses mapping
         content = json.loads(result.body.decode())
         assert "error" in content
         assert content["error"]["code"] == -32602
@@ -3179,7 +3197,7 @@ def test_plugin_violation_exception_handler_with_custom_mcp_error_code(self):
         result = asyncio.run(plugin_violation_exception_handler(None, exc))
-        assert result.status_code == 200
+        assert result.status_code == 429
         content = json.loads(result.body.decode())
         assert content["error"]["code"] == -32000
         assert "Too many requests from this client" in content["error"]["message"]
@@ -3223,13 +3241,324 @@ def test_plugin_violation_exception_handler_without_violation_object(self):
         exc = PluginViolationError(message="Generic plugin violation", violation=None)
+    def test_plugin_violation_exception_handler_with_http_status_code(self):
+        """Test that violation HTTP status code is used in response."""
+        # Standard
+        import asyncio
+
+        # First-Party
+        from mcpgateway.main import plugin_violation_exception_handler
+        from mcpgateway.plugins.framework.errors import PluginViolationError
+        from mcpgateway.plugins.framework.models import PluginViolation
+
+        violation = PluginViolation(
+            reason="Rate limit exceeded",
+            description="Too many requests",
+            code="RATE_LIMIT",
+            http_status_code=429,  # explicit HTTP status set on the violation
+        )
+        exc = PluginViolationError(message="Rate limited", violation=violation)
+
+        result = asyncio.run(plugin_violation_exception_handler(None, exc))
+
+        assert result.status_code == 429  # Should use violation's HTTP status
+        content =
json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + assert "Too many requests" in content["error"]["message"] # Uses description + + def test_plugin_violation_exception_handler_with_http_headers(self): + """Test that violation HTTP headers are included in response.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Rate limit exceeded", + description="Too many requests", + code="RATE_LIMIT", + http_status_code=429, + http_headers={"Retry-After": "60", "X-RateLimit-Limit": "100"}, # NEW FIELD + ) + exc = PluginViolationError(message="Rate limited", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 429 + assert "Retry-After" in result.headers + assert result.headers["Retry-After"] == "60" + assert result.headers["X-RateLimit-Limit"] == "100" + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_exception_handler_with_code_mapping_fallback(self): + """Test that PLUGIN_VIOLATION_CODE_MAPPING is used when no explicit HTTP status.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + # Assumes PLUGIN_VIOLATION_CODE_MAPPING has {"RATE_LIMIT": 429} + violation = PluginViolation( + reason="Rate limit exceeded", + description="Too many requests", + code="RATE_LIMIT", + # No http_status_code field + ) + exc = PluginViolationError(message="Rate limited", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert 
result.status_code == 429 # Should use mapping + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_exception_handler_defaults_to_200(self): + """Test that response defaults to 200 when no HTTP status is provided.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Invalid input", + description="Bad data", + code="UNKNOWN_CODE", # Not in mapping + # No http_status_code + ) + exc = PluginViolationError(message="Violation", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 200 # JSON-RPC default + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_exception_handler_no_headers_when_none(self): + """Test that no headers are added when violation has none.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Error", + description="Something failed", + code="ERROR", + http_status_code=400, + # No http_headers + ) + exc = PluginViolationError(message="Failed", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 400 + # Should not crash when headers is None + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_http_status_takes_precedence_over_mapping(self): + """Verify that explicit http_status_code takes precedence over code mapping.""" + # Standard + 
import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + # PLUGIN_VIOLATION_CODE_MAPPING has "RATE_LIMIT": 429 + violation = PluginViolation( + reason="Rate limit", + description="Service unavailable", + code="RATE_LIMIT", + http_status_code=503, # Explicit status should win + ) + exc = PluginViolationError(message="Limited", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 503 # Not 429 from mapping + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_with_multiple_rate_limit_headers(self): + """Verify all rate limit headers are properly included.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Rate limit", + description="Too many requests", + code="RATE_LIMIT", + http_status_code=429, + http_headers={ + "X-RateLimit-Limit": "60", + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": "1737394800", + "Retry-After": "35", + }, + ) + exc = PluginViolationError(message="Limited", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 429 + assert result.headers["X-RateLimit-Limit"] == "60" + assert result.headers["X-RateLimit-Remaining"] == "0" + assert result.headers["X-RateLimit-Reset"] == "1737394800" + assert result.headers["Retry-After"] == "35" + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_unknown_code_defaults_to_200(self): + """Verify unknown codes not in 
mapping default to 200.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Unknown error", + description="Something unexpected happened", + code="UNKNOWN_CODE_NOT_IN_MAPPING", + # No explicit http_status_code + ) + exc = PluginViolationError(message="Error", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + assert result.status_code == 200 # Default for JSON-RPC + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + assert "Something unexpected happened" in content["error"]["message"] # Uses description + + def test_plugin_violation_invalid_http_status_code_below_range(self): + """Test that invalid HTTP status code below 100 defaults to None and uses mapping.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Invalid status", + description="Status code below valid range", + code="RATE_LIMIT", # Has mapping to 429 + http_status_code=99, # Invalid: below 100 + ) + exc = PluginViolationError(message="Invalid status", violation=violation) + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + # Should fall back to code mapping (RATE_LIMIT -> 429) + assert result.status_code == 429 + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_invalid_http_status_code_above_range(self): + """Test that invalid HTTP status code above 511 defaults to None and uses mapping.""" + # Standard + import asyncio + + # First-Party + from 
mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Invalid status", + description="Status code above valid range", + code="RATE_LIMIT", # Has mapping to 429 + http_status_code=512 # Invalid: above 511 + ) + exc = PluginViolationError(message="Invalid status", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + # Should fall back to code mapping (RATE_LIMIT -> 429) + assert result.status_code == 429 + content = json.loads(result.body.decode()) + assert content["error"]["code"] == -32602 + + def test_plugin_violation_invalid_http_status_code_no_mapping_fallback(self): + """Test that invalid HTTP status code with no mapping defaults to 200.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + violation = PluginViolation( + reason="Invalid status", + description="Status code invalid, no mapping", + code="UNKNOWN_CODE", # Not in mapping + http_status_code=1000, # Invalid: way above 511 + ) + exc = PluginViolationError(message="Invalid status", violation=violation) + + result = asyncio.run(plugin_violation_exception_handler(None, exc)) + + # Should default to 200 (no mapping available) assert result.status_code == 200 content = json.loads(result.body.decode()) assert content["error"]["code"] == -32602 - assert "A plugin violation occurred" in content["error"]["message"] - assert content["error"]["data"] == {} + + def test_plugin_violation_valid_http_status_code_edge_cases(self): + """Test that valid edge case HTTP status codes (400, 511) are accepted.""" + # Standard + import asyncio + + # First-Party + from mcpgateway.main import 
plugin_violation_exception_handler + from mcpgateway.plugins.framework.errors import PluginViolationError + from mcpgateway.plugins.framework.models import PluginViolation + + # Test lower boundary (400) + violation_400 = PluginViolation( + reason="Continue", + description="Valid status 400", + code="INFO", + http_status_code=400, # Valid: exactly 400 + ) + exc_400 = PluginViolationError(message="Status 400", violation=violation_400) + result_400 = asyncio.run(plugin_violation_exception_handler(None, exc_400)) + assert result_400.status_code == 400 + + # Test upper boundary (511) + violation_511 = PluginViolation( + reason="Network error", + description="Valid status 511", + code="ERROR", + http_status_code=511, # Valid: exactly 511 + ) + exc_511 = PluginViolationError(message="Status 511", violation=violation_511) + result_511 = asyncio.run(plugin_violation_exception_handler(None, exc_511)) + assert result_511.status_code == 511 def test_plugin_exception_handler_with_full_error(self): """Test plugin_exception_handler with complete error details.""" diff --git a/tests/unit/mcpgateway/test_main_helpers.py b/tests/unit/mcpgateway/test_main_helpers.py index f1fa644de1..18caca2a84 100644 --- a/tests/unit/mcpgateway/test_main_helpers.py +++ b/tests/unit/mcpgateway/test_main_helpers.py @@ -184,6 +184,128 @@ async def test_invalidate_resource_cache_clears_entries(): assert main.resource_cache.get("/resource2") is None +def test_validate_http_headers_valid(): + """Test _validate_http_headers with valid headers.""" + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer token123", + "X-Custom-Header": "value with spaces", + } + result = main._validate_http_headers(headers) + assert result == headers + + +def test_validate_http_headers_invalid_name(): + """Test _validate_http_headers rejects invalid header names (line 1435-1436).""" + # Invalid header name with space + headers = {"Invalid Name": "value"} + result = main._validate_http_headers(headers) 
+ assert result is None + + # Invalid header name with special characters not in RFC 9110 token + headers = {"Invalid@Header": "value"} + result = main._validate_http_headers(headers) + assert result is None + + # Mix of valid and invalid headers + headers = { + "Valid-Header": "value1", + "Invalid Name": "value2", + "Another-Valid": "value3", + } + result = main._validate_http_headers(headers) + assert result == {"Valid-Header": "value1", "Another-Valid": "value3"} + + +def test_validate_http_headers_crlf_in_value(): + """Test _validate_http_headers rejects CRLF in header values (line 1439-1440).""" + # Header value with carriage return + headers = {"Content-Type": "application/json\rinjection"} + result = main._validate_http_headers(headers) + assert result is None + + # Header value with newline + headers = {"Authorization": "Bearer token\ninjection"} + result = main._validate_http_headers(headers) + assert result is None + + # Header value with both CRLF + headers = {"X-Custom": "value\r\ninjection"} + result = main._validate_http_headers(headers) + assert result is None + + # Mix of valid and invalid headers + headers = { + "Valid-Header": "clean value", + "Invalid-Header": "value\r\ninjection", + "Another-Valid": "another clean value", + } + result = main._validate_http_headers(headers) + assert result == {"Valid-Header": "clean value", "Another-Valid": "another clean value"} + + +def test_validate_http_headers_ctl_characters(): + """Test _validate_http_headers rejects CTL characters in values (line 1447-1448, 1450).""" + # Header value with null byte (0x00) + headers = {"Content-Type": "application/json\x00"} + result = main._validate_http_headers(headers) + assert result is None + + # Header value with control character (0x01) + headers = {"Authorization": "Bearer\x01token"} + result = main._validate_http_headers(headers) + assert result is None + + # Header value with DEL character (0x7F) + headers = {"X-Custom": "value\x7f"} + result = 
main._validate_http_headers(headers) + assert result is None + + # Header value with various CTL characters (0x00-0x1F except tab and space) + for code in range(0, 32): + if code in (9, 32): # Skip tab and space (allowed) + continue + headers = {"Test-Header": f"value{chr(code)}end"} + result = main._validate_http_headers(headers) + assert result is None, f"Should reject CTL character 0x{code:02x}" + + # Header value with tab (0x09) - should be allowed + headers = {"Content-Type": "application/json\tcharset=utf-8"} + result = main._validate_http_headers(headers) + assert result == headers + + # Header value with space (0x20) - should be allowed + headers = {"Authorization": "Bearer token with spaces"} + result = main._validate_http_headers(headers) + assert result == headers + + # Mix of valid and invalid headers + headers = { + "Valid-Header": "clean value", + "Invalid-Header": "value\x01injection", + "Another-Valid": "another clean value", + } + result = main._validate_http_headers(headers) + assert result == {"Valid-Header": "clean value", "Another-Valid": "another clean value"} + + +def test_validate_http_headers_empty_dict(): + """Test _validate_http_headers with empty dictionary.""" + result = main._validate_http_headers({}) + assert result is None + + +def test_validate_http_headers_all_invalid(): + """Test _validate_http_headers when all headers are invalid.""" + headers = { + "Invalid Name": "value1", + "Valid-But-Bad-Value": "value\r\ninjection", + "Another-Invalid": "value\x00", + } + result = main._validate_http_headers(headers) + assert result is None + + # --------------------------------------------------------------------------- # tojson_attr filter tests # --------------------------------------------------------------------------- From d158a5e6f8ccdef397290ca16410841845ff8c05 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Tue, 10 Mar 2026 00:41:56 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix:=20review=20fixes=20=E2=80=94=20100%=20?= 
=?UTF-8?q?coverage,=20typo,=20regex=20hoist,=20docstring=20path?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix typo: validatated_headers → validated_headers - Hoist RFC 9110 token regex to module-level compiled constant - Consolidate redundant CRLF check into unified CTL validation - Fix integration test docstring path (was unit test path) - Add tests for _parse_rate minute/hour/error branches - Add tests for unlimited (no-limit) prompt and tool paths - Achieve 100% differential test coverage on rate_limiter.py Signed-off-by: Mihai Criveti --- mcpgateway/main.py | 34 ++++--- tests/integration/test_rate_limiter.py | 2 +- .../plugins/rate_limiter/test_rate_limiter.py | 92 ++++++++++++++++++- 3 files changed, 108 insertions(+), 20 deletions(-) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 24ba896caf..4fa8a2292b 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -1452,6 +1452,14 @@ async def database_exception_handler(_request: Request, exc: IntegrityError): return ORJSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc)) +# RFC 9110 §5.6.2 'token' pattern for header field names: +# token = 1*tchar +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +_RFC9110_TOKEN_RE = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") + + def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]: """Validate headers according to RFC 9110. @@ -1464,27 +1472,16 @@ def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]: Rules enforced: - Header name must match RFC 9110 'token'. - No whitespace before colon (enforced by dictionary usage). - - Header value must not contain CTL characters (0x00–0x1F, 0x7F). + - Header value must not contain CTL characters (0x00–0x1F, 0x7F), + except SP (0x20) and HTAB (0x09) which are allowed. 
""" - - # RFC 9110 'token' definition: - # token = 1*tchar - # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" - # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" - # / DIGIT / ALPHA - header_key = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") validated: dict[str, str] = {} for key, value in headers.items(): - # Validate header name (RFC 9110) - if not re.match(header_key, key): + # Validate header name (RFC 9110 token) + if not _RFC9110_TOKEN_RE.match(key): logger.warning(f"Invalid header name: {key}") continue - # Validate header value (no CRLF) - if "\r" in value or "\n" in value: - logger.warning(f"Header value contains CRLF: {key}") - continue # RFC 9110: Reject CTLs (0x00–0x1F, 0x7F). Allow SP (0x20) and HTAB (0x09). - # Further structure (quoted-string, lists, parameters) is left to higher-level parsers. valid = True for ch in value: code = ord(ch) @@ -1492,6 +1489,7 @@ def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]: valid = False break if not valid: + logger.warning(f"Header value contains invalid characters: {key}") continue validated[key] = value return validated if validated else None @@ -1576,9 +1574,9 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola response = ORJSONResponse(status_code=http_status, content={"error": json_rpc_error.model_dump()}) if headers: - validatated_headers = _validate_http_headers(headers) - if validatated_headers: - response.headers.update(validatated_headers) + validated_headers = _validate_http_headers(headers) + if validated_headers: + response.headers.update(validated_headers) return response diff --git a/tests/integration/test_rate_limiter.py b/tests/integration/test_rate_limiter.py index 765a4be0f1..40de40f49e 100644 --- a/tests/integration/test_rate_limiter.py +++ b/tests/integration/test_rate_limiter.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Location: ./tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py +"""Location: 
./tests/integration/test_rate_limiter.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 Authors: Mihai Criveti diff --git a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py index 9fba1054f0..e3b98b8208 100644 --- a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py +++ b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py @@ -17,7 +17,7 @@ PromptPrehookPayload, ToolHookType ) -from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _make_headers, _select_most_restrictive, _store +from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _make_headers, _parse_rate, _select_most_restrictive, _store @pytest.fixture(autouse=True) @@ -509,3 +509,93 @@ def test_tenant_unlimited_user_tool_limited(self): assert allowed is True assert limit == 50 # Tool is most restrictive assert remaining == 30 + + +# ============================================================================ +# _parse_rate Tests +# ============================================================================ + + +class TestParseRate: + """Tests for _parse_rate helper covering all time units.""" + + def test_seconds_short(self): + assert _parse_rate("10/s") == (10, 1) + + def test_seconds_medium(self): + assert _parse_rate("10/sec") == (10, 1) + + def test_seconds_long(self): + assert _parse_rate("10/second") == (10, 1) + + def test_minutes_short(self): + assert _parse_rate("60/m") == (60, 60) + + def test_minutes_medium(self): + assert _parse_rate("60/min") == (60, 60) + + def test_minutes_long(self): + assert _parse_rate("60/minute") == (60, 60) + + def test_hours_short(self): + assert _parse_rate("100/h") == (100, 3600) + + def test_hours_medium(self): + assert _parse_rate("100/hr") == (100, 3600) + + def test_hours_long(self): + assert _parse_rate("100/hour") == (100, 3600) + + def test_unsupported_unit_raises(self): + with 
pytest.raises(ValueError, match="Unsupported rate unit"): + _parse_rate("10/d") + + def test_whitespace_stripped(self): + assert _parse_rate("5/ M ") == (5, 60) + + +# ============================================================================ +# Unlimited (no-limit) path tests +# ============================================================================ + + +def _mk_unlimited() -> RateLimiterPlugin: + """Create a plugin with no rate limits configured.""" + return RateLimiterPlugin( + PluginConfig( + name="rl", + kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin", + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], + config={}, # No limits + ) + ) + + +@pytest.mark.asyncio +async def test_prompt_pre_fetch_unlimited_returns_no_headers(): + """When no limits are configured, prompt_pre_fetch returns metadata without http_headers.""" + plugin = _mk_unlimited() + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1")) + payload = PromptPrehookPayload(prompt_id="p", args={}) + + result = await plugin.prompt_pre_fetch(payload, ctx) + assert result.violation is None + assert result.http_headers is None + assert result.metadata is not None + assert result.metadata.get("limited") is False + + +@pytest.mark.asyncio +async def test_tool_pre_invoke_unlimited_returns_no_headers(): + """When no limits are configured, tool_pre_invoke returns metadata without http_headers.""" + from mcpgateway.plugins.framework import ToolPreInvokePayload + + plugin = _mk_unlimited() + ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1")) + payload = ToolPreInvokePayload(name="test_tool", arguments={}) + + result = await plugin.tool_pre_invoke(payload, ctx) + assert result.violation is None + assert result.http_headers is None + assert result.metadata is not None + assert result.metadata.get("limited") is False From fac99b59e342c102a7fedbebe3fb3e1d796eeb12 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Tue, 10 Mar 
2026 07:09:04 +0000 Subject: [PATCH 3/3] chore: gitignore per-crate target/ dirs under plugins_rust After the restructure into independent crates (#3147), each crate has its own target/ directory. The existing pattern only covered the workspace-level plugins_rust/target/. Signed-off-by: Mihai Criveti --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index bc06607019..7882531aea 100644 --- a/.gitignore +++ b/.gitignore @@ -258,6 +258,7 @@ uv.lock # Rust (plugins_rust) # ======================================== plugins_rust/target/ +plugins_rust/*/target/ *.rs.bk plugins_rust/Cargo.lock