Skip to content

Commit d559ad9

Browse files
ja8zyjitscrivetimihai
authored andcommitted
fix(plugins): propagate HTTP status codes and headers from plugin violations
Add support for plugins to specify HTTP status codes (e.g., 429 for rate limiting) and custom headers (e.g., Retry-After) in PluginViolation responses. - Add http_status_code and http_headers fields to PluginViolation model - Implement PLUGIN_VIOLATION_CODE_MAPPING for common violation types - Update plugin_violation_exception_handler to use explicit status codes with fallback to code mapping, defaulting to 200 for JSON-RPC compliance - Add RFC 9110 header validation to prevent header injection - Enhance rate limiter plugin with multi-dimensional rate limiting, proper HTTP 429 responses, and X-RateLimit-* / Retry-After headers - Add comprehensive test coverage for status code precedence, header propagation, and header validation Fixes #2668 Signed-off-by: Jitesh Nair <jiteshnair@ibm.com> Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent 4420a7e commit d559ad9

File tree

9 files changed

+1934
-48
lines changed

9 files changed

+1934
-48
lines changed

mcpgateway/main.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from functools import lru_cache
3434
import hashlib
3535
import html
36+
import re
3637
import sys
3738
from typing import Any, AsyncIterator, Dict, List, Optional, Union
3839
from urllib.parse import urlparse, urlunparse
@@ -86,6 +87,7 @@
8687
from mcpgateway.middleware.validation_middleware import ValidationMiddleware
8788
from mcpgateway.observability import init_telemetry
8889
from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError
90+
from mcpgateway.plugins.framework.constants import PLUGIN_VIOLATION_CODE_MAPPING, PluginViolationCode, VALID_HTTP_STATUS_CODES
8991
from mcpgateway.routers.server_well_known import router as server_well_known_router
9092
from mcpgateway.routers.well_known import router as well_known_router
9193
from mcpgateway.schemas import (
@@ -1450,6 +1452,51 @@ async def database_exception_handler(_request: Request, exc: IntegrityError):
14501452
return ORJSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc))
14511453

14521454

1455+
def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]:
1456+
"""Validate headers according to RFC 9110.
1457+
1458+
Args:
1459+
headers: dict of headers
1460+
1461+
Returns:
1462+
Optional[dict[str, str]]: dictionary of valid headers
1463+
1464+
Rules enforced:
1465+
- Header name must match RFC 9110 'token'.
1466+
- No whitespace before colon (enforced by dictionary usage).
1467+
- Header value must not contain CTL characters (0x00–0x1F, 0x7F).
1468+
"""
1469+
1470+
# RFC 9110 'token' definition:
1471+
# token = 1*tchar
1472+
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
1473+
# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
1474+
# / DIGIT / ALPHA
1475+
header_key = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")
1476+
validated: dict[str, str] = {}
1477+
for key, value in headers.items():
1478+
# Validate header name (RFC 9110)
1479+
if not re.match(header_key, key):
1480+
logger.warning(f"Invalid header name: {key}")
1481+
continue
1482+
# Validate header value (no CRLF)
1483+
if "\r" in value or "\n" in value:
1484+
logger.warning(f"Header value contains CRLF: {key}")
1485+
continue
1486+
# RFC 9110: Reject CTLs (0x00–0x1F, 0x7F). Allow SP (0x20) and HTAB (0x09).
1487+
# Further structure (quoted-string, lists, parameters) is left to higher-level parsers.
1488+
valid = True
1489+
for ch in value:
1490+
code = ord(ch)
1491+
if (0 <= code <= 31 or code == 127) and code not in (9, 32):
1492+
valid = False
1493+
break
1494+
if not valid:
1495+
continue
1496+
validated[key] = value
1497+
return validated if validated else None
1498+
1499+
14531500
@app.exception_handler(PluginViolationError)
14541501
async def plugin_violation_exception_handler(_request: Request, exc: PluginViolationError):
14551502
"""Handle plugins violations globally.
@@ -1464,7 +1511,9 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
14641511
violation details.
14651512
14661513
Returns:
1467-
JSONResponse: A 200 response with error details in JSON-RPC format.
1514+
JSONResponse: A response with error details in JSON-RPC format.
1515+
Uses HTTP status code from violation if present (e.g., 429 for rate limiting),
1516+
otherwise defaults to 200 for JSON-RPC compliance.
14681517
14691518
Examples:
14701519
>>> from mcpgateway.plugins.framework import PluginViolationError
@@ -1482,7 +1531,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
14821531
... ))
14831532
>>> result = asyncio.run(plugin_violation_exception_handler(None, mock_error))
14841533
>>> result.status_code
1485-
200
1534+
422
14861535
>>> content = orjson.loads(result.body.decode())
14871536
>>> content["error"]["code"]
14881537
-32602
@@ -1496,6 +1545,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
14961545
policy_violation["message"] = exc.message
14971546
status_code = exc.violation.mcp_error_code if exc.violation and exc.violation.mcp_error_code else -32602
14981547
violation_details: dict[str, Any] = {}
1548+
http_status = 200
14991549
if exc.violation:
15001550
if exc.violation.description:
15011551
violation_details["description"] = exc.violation.description
@@ -1505,8 +1555,31 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
15051555
violation_details["plugin_error_code"] = exc.violation.code
15061556
if exc.violation.plugin_name:
15071557
violation_details["plugin_name"] = exc.violation.plugin_name
1558+
1559+
# Use HTTP status code from violation if present (e.g., 429 for rate limiting)
1560+
http_status = exc.violation.http_status_code if exc.violation.http_status_code else None
1561+
if http_status and not VALID_HTTP_STATUS_CODES.get(http_status):
1562+
logger.warning(f"Invalid HTTP status code {http_status} from violation, defaulting to 200")
1563+
http_status = None
1564+
if not http_status:
1565+
logger.debug("Using Plugin violation code mapping for lack of http_status_code")
1566+
mapping: Optional[PluginViolationCode] = PLUGIN_VIOLATION_CODE_MAPPING.get(exc.violation.code) if exc.violation.code else None
1567+
if not mapping:
1568+
http_status = 200
1569+
else:
1570+
http_status = mapping.code
1571+
15081572
json_rpc_error = PydanticJSONRPCError(code=status_code, message="Plugin Violation: " + message, data=violation_details)
1509-
return ORJSONResponse(status_code=200, content={"error": json_rpc_error.model_dump()})
1573+
1574+
# Collect HTTP headers from violation if present
1575+
headers = exc.violation.http_headers if exc.violation and exc.violation.http_headers else None
1576+
1577+
response = ORJSONResponse(status_code=http_status, content={"error": json_rpc_error.model_dump()})
1578+
if headers:
1579+
validatated_headers = _validate_http_headers(headers)
1580+
if validatated_headers:
1581+
response.headers.update(validatated_headers)
1582+
return response
15101583

15111584

15121585
@app.exception_handler(PluginError)

mcpgateway/plugins/framework/constants.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010

1111
# Standard
1212

13+
# Standard
14+
from dataclasses import dataclass
15+
from types import MappingProxyType
16+
from typing import Mapping
17+
1318
# Model constants.
1419
# Specialized plugin types.
1520
EXTERNAL_PLUGIN_TYPE = "external"
@@ -43,3 +48,90 @@
4348
GET_PLUGIN_CONFIG = "get_plugin_config"
4449
HOOK_TYPE = "hook_type"
4550
INVOKE_HOOK = "invoke_hook"
51+
52+
53+
@dataclass(frozen=True)
54+
class PluginViolationCode:
55+
"""
56+
Plugin violation codes as an immutable dataclass object.
57+
58+
Provide Mapping for violation codes to their corresponding HTTP status codes for proper error responses.
59+
"""
60+
61+
code: int
62+
name: str
63+
message: str
64+
65+
66+
PLUGIN_VIOLATION_CODE_MAPPING: Mapping[str, PluginViolationCode] = MappingProxyType(
67+
# MappingProxyType will make sure the resulting object is immutable and hence this will act as a constant.
68+
{
69+
# Rate Limiting
70+
"RATE_LIMIT": PluginViolationCode(429, "RATE_LIMIT", "Used when rate limit is exceeded (rate_limiter plugin)"),
71+
# Resource & URI Validation
72+
"INVALID_URI": PluginViolationCode(400, "INVALID_URI", "Used when URI cannot be parsed or has invalid format (resource_filter, cedar, opa)"),
73+
"PROTOCOL_BLOCKED": PluginViolationCode(403, "PROTOCOL_BLOCKED", "Used when protocol/scheme is not allowed (resource_filter)"),
74+
"DOMAIN_BLOCKED": PluginViolationCode(403, "DOMAIN_BLOCKED", "Used when domain is in blocklist (resource_filter)"),
75+
"CONTENT_TOO_LARGE": PluginViolationCode(413, "CONTENT_TOO_LARGE", "Used when resource content exceeds size limit (resource_filter)"),
76+
# Content Moderation & Safety
77+
"CONTENT_MODERATION": PluginViolationCode(422, "CONTENT_MODERATION", "Used when harmful content is detected (content_moderation plugin)"),
78+
"MODERATION_ERROR": PluginViolationCode(503, "MODERATION_ERROR", "Used when moderation service fails (content_moderation plugin)"),
79+
"PII_DETECTED": PluginViolationCode(422, "PII_DETECTED", "Used when PII is detected in content (pii_filter plugin)"),
80+
"SENSITIVE_CONTENT": PluginViolationCode(422, "SENSITIVE_CONTENT", "Used when sensitive information is detected"),
81+
# Authentication & Authorization
82+
"INVALID_TOKEN": PluginViolationCode(401, "INVALID_TOKEN", "Used for invalid/expired tokens (simple_token_auth example)"), # nosec B105 - Not a password; INVALID_TOKEN is a HTTP Status Code
83+
"API_KEY_REVOKED": PluginViolationCode(401, "API_KEY_REVOKED", "Used when API key has been revoked (custom_auth_example)"),
84+
"AUTH_REQUIRED": PluginViolationCode(401, "AUTH_REQUIRED", "Used when authentication is missing"),
85+
# Generic Violation Codes
86+
"PROHIBITED_CONTENT": PluginViolationCode(422, "PROHIBITED_CONTENT", "Used when content violates policy rules"),
87+
"BLOCKED_CONTENT": PluginViolationCode(403, "BLOCKED_CONTENT", "Used when content is explicitly blocked by policy"),
88+
"BLOCKED": PluginViolationCode(403, "BLOCKED", "Generic blocking violation"),
89+
"EXECUTION_ERROR": PluginViolationCode(500, "EXECUTION_ERROR", "Used when plugin execution fails"),
90+
"PROCESSING_ERROR": PluginViolationCode(500, "PROCESSING_ERROR", "Used when processing encounters an error"),
91+
}
92+
)
93+
94+
VALID_HTTP_STATUS_CODES: dict[int, str] = { # RFC 9110
95+
# 4xx — Client Error
96+
400: "Bad Request",
97+
401: "Unauthorized",
98+
402: "Payment Required",
99+
403: "Forbidden",
100+
404: "Not Found",
101+
405: "Method Not Allowed",
102+
406: "Not Acceptable",
103+
407: "Proxy Authentication Required",
104+
408: "Request Timeout",
105+
409: "Conflict",
106+
410: "Gone",
107+
411: "Length Required",
108+
412: "Precondition Failed",
109+
413: "Content Too Large", # (was "Payload Too Large" before RFC 9110)
110+
414: "URI Too Long",
111+
415: "Unsupported Media Type",
112+
416: "Range Not Satisfiable",
113+
417: "Expectation Failed",
114+
418: "(Unused)",
115+
421: "Misdirected Request",
116+
422: "Unprocessable Content", # (was "Unprocessable Entity")
117+
423: "Locked",
118+
424: "Failed Dependency",
119+
425: "Too Early",
120+
426: "Upgrade Required",
121+
428: "Precondition Required",
122+
429: "Too Many Requests",
123+
431: "Request Header Fields Too Large",
124+
451: "Unavailable For Legal Reasons",
125+
# 5xx — Server Error
126+
500: "Internal Server Error",
127+
501: "Not Implemented",
128+
502: "Bad Gateway",
129+
503: "Service Unavailable",
130+
504: "Gateway Timeout",
131+
505: "HTTP Version Not Supported",
132+
506: "Variant Also Negotiates",
133+
507: "Insufficient Storage",
134+
508: "Loop Detected",
135+
510: "Not Extended",
136+
511: "Network Authentication Required",
137+
}

mcpgateway/plugins/framework/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,8 @@ class PluginViolation(BaseModel):
12311231
details: (dict[str, Any]): additional violation details.
12321232
_plugin_name (str): the plugin name, private attribute set by the plugin manager.
12331233
mcp_error_code(Optional[int]): A valid mcp error code which will be sent back to the client if plugin enabled.
1234+
http_status_code (Optional[int]): HTTP status code to return (e.g., 429 for rate limiting).
1235+
http_headers (Optional[dict[str, str]]): HTTP headers to include in the response.
12341236
12351237
Examples:
12361238
>>> violation = PluginViolation(
@@ -1254,6 +1256,8 @@ class PluginViolation(BaseModel):
12541256
details: Optional[dict[str, Any]] = Field(default_factory=dict)
12551257
_plugin_name: str = PrivateAttr(default="")
12561258
mcp_error_code: Optional[int] = None
1259+
http_status_code: Optional[int] = None
1260+
http_headers: Optional[dict[str, str]] = None
12571261

12581262
@property
12591263
def plugin_name(self) -> str:
@@ -1327,6 +1331,7 @@ class PluginResult(BaseModel, Generic[T]):
13271331
modified_payload (Optional[Any]): The modified payload if the plugin is a transformer.
13281332
violation (Optional[PluginViolation]): violation object.
13291333
metadata (Optional[dict[str, Any]]): additional metadata.
1334+
http_headers (Optional[dict[str, str]]): HTTP headers to include in successful responses.
13301335
13311336
Examples:
13321337
>>> result = PluginResult()
@@ -1355,6 +1360,7 @@ class PluginResult(BaseModel, Generic[T]):
13551360
modified_payload: Optional[T] = None
13561361
violation: Optional[PluginViolation] = None
13571362
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
1363+
http_headers: Optional[dict[str, str]] = None
13581364

13591365

13601366
class GlobalContext(BaseModel):

0 commit comments

Comments
 (0)