Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ uv.lock
# Rust (plugins_rust)
# ========================================
plugins_rust/target/
plugins_rust/*/target/
*.rs.bk
plugins_rust/Cargo.lock

Expand Down
77 changes: 74 additions & 3 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from functools import lru_cache
import hashlib
import html
import re
import sys
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from urllib.parse import urlparse, urlunparse
Expand Down Expand Up @@ -86,6 +87,7 @@
from mcpgateway.middleware.validation_middleware import ValidationMiddleware
from mcpgateway.observability import init_telemetry
from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError
from mcpgateway.plugins.framework.constants import PLUGIN_VIOLATION_CODE_MAPPING, PluginViolationCode, VALID_HTTP_STATUS_CODES
from mcpgateway.routers.server_well_known import router as server_well_known_router
from mcpgateway.routers.well_known import router as well_known_router
from mcpgateway.schemas import (
Expand Down Expand Up @@ -1450,6 +1452,49 @@ async def database_exception_handler(_request: Request, exc: IntegrityError):
return ORJSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc))


# RFC 9110 Β§5.6.2 'token' pattern for header field names:
# token = 1*tchar
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
# / DIGIT / ALPHA
_RFC9110_TOKEN_RE = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")


def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]:
"""Validate headers according to RFC 9110.

Args:
headers: dict of headers

Returns:
Optional[dict[str, str]]: dictionary of valid headers

Rules enforced:
- Header name must match RFC 9110 'token'.
- No whitespace before colon (enforced by dictionary usage).
- Header value must not contain CTL characters (0x00–0x1F, 0x7F),
except SP (0x20) and HTAB (0x09) which are allowed.
"""
validated: dict[str, str] = {}
for key, value in headers.items():
# Validate header name (RFC 9110 token)
if not _RFC9110_TOKEN_RE.match(key):
logger.warning(f"Invalid header name: {key}")
continue
# RFC 9110: Reject CTLs (0x00–0x1F, 0x7F). Allow SP (0x20) and HTAB (0x09).
valid = True
for ch in value:
code = ord(ch)
if (0 <= code <= 31 or code == 127) and code not in (9, 32):
valid = False
break
if not valid:
logger.warning(f"Header value contains invalid characters: {key}")
continue
validated[key] = value
return validated if validated else None


@app.exception_handler(PluginViolationError)
async def plugin_violation_exception_handler(_request: Request, exc: PluginViolationError):
"""Handle plugins violations globally.
Expand All @@ -1464,7 +1509,9 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
violation details.

Returns:
JSONResponse: A 200 response with error details in JSON-RPC format.
JSONResponse: A response with error details in JSON-RPC format.
Uses HTTP status code from violation if present (e.g., 429 for rate limiting),
otherwise defaults to 200 for JSON-RPC compliance.

Examples:
>>> from mcpgateway.plugins.framework import PluginViolationError
Expand All @@ -1482,7 +1529,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
... ))
>>> result = asyncio.run(plugin_violation_exception_handler(None, mock_error))
>>> result.status_code
200
422
>>> content = orjson.loads(result.body.decode())
>>> content["error"]["code"]
-32602
Expand All @@ -1496,6 +1543,7 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
policy_violation["message"] = exc.message
status_code = exc.violation.mcp_error_code if exc.violation and exc.violation.mcp_error_code else -32602
violation_details: dict[str, Any] = {}
http_status = 200
if exc.violation:
if exc.violation.description:
violation_details["description"] = exc.violation.description
Expand All @@ -1505,8 +1553,31 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
violation_details["plugin_error_code"] = exc.violation.code
if exc.violation.plugin_name:
violation_details["plugin_name"] = exc.violation.plugin_name

# Use HTTP status code from violation if present (e.g., 429 for rate limiting)
http_status = exc.violation.http_status_code if exc.violation.http_status_code else None
if http_status and not VALID_HTTP_STATUS_CODES.get(http_status):
logger.warning(f"Invalid HTTP status code {http_status} from violation, defaulting to 200")
http_status = None
if not http_status:
logger.debug("Using Plugin violation code mapping for lack of http_status_code")
mapping: Optional[PluginViolationCode] = PLUGIN_VIOLATION_CODE_MAPPING.get(exc.violation.code) if exc.violation.code else None
if not mapping:
http_status = 200
else:
http_status = mapping.code

json_rpc_error = PydanticJSONRPCError(code=status_code, message="Plugin Violation: " + message, data=violation_details)
return ORJSONResponse(status_code=200, content={"error": json_rpc_error.model_dump()})

# Collect HTTP headers from violation if present
headers = exc.violation.http_headers if exc.violation and exc.violation.http_headers else None

response = ORJSONResponse(status_code=http_status, content={"error": json_rpc_error.model_dump()})
if headers:
validated_headers = _validate_http_headers(headers)
if validated_headers:
response.headers.update(validated_headers)
return response


@app.exception_handler(PluginError)
Expand Down
92 changes: 92 additions & 0 deletions mcpgateway/plugins/framework/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

# Standard

# Standard
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping

# Model constants.
# Specialized plugin types.
EXTERNAL_PLUGIN_TYPE = "external"
Expand Down Expand Up @@ -43,3 +48,90 @@
GET_PLUGIN_CONFIG = "get_plugin_config"
HOOK_TYPE = "hook_type"
INVOKE_HOOK = "invoke_hook"


@dataclass(frozen=True)
class PluginViolationCode:
"""
Plugin violation codes as an immutable dataclass object.

Provide Mapping for violation codes to their corresponding HTTP status codes for proper error responses.
"""

code: int
name: str
message: str


PLUGIN_VIOLATION_CODE_MAPPING: Mapping[str, PluginViolationCode] = MappingProxyType(
# MappingProxyType will make sure the resulting object is immutable and hence this will act as a constant.
{
# Rate Limiting
"RATE_LIMIT": PluginViolationCode(429, "RATE_LIMIT", "Used when rate limit is exceeded (rate_limiter plugin)"),
# Resource & URI Validation
"INVALID_URI": PluginViolationCode(400, "INVALID_URI", "Used when URI cannot be parsed or has invalid format (resource_filter, cedar, opa)"),
"PROTOCOL_BLOCKED": PluginViolationCode(403, "PROTOCOL_BLOCKED", "Used when protocol/scheme is not allowed (resource_filter)"),
"DOMAIN_BLOCKED": PluginViolationCode(403, "DOMAIN_BLOCKED", "Used when domain is in blocklist (resource_filter)"),
"CONTENT_TOO_LARGE": PluginViolationCode(413, "CONTENT_TOO_LARGE", "Used when resource content exceeds size limit (resource_filter)"),
# Content Moderation & Safety
"CONTENT_MODERATION": PluginViolationCode(422, "CONTENT_MODERATION", "Used when harmful content is detected (content_moderation plugin)"),
"MODERATION_ERROR": PluginViolationCode(503, "MODERATION_ERROR", "Used when moderation service fails (content_moderation plugin)"),
"PII_DETECTED": PluginViolationCode(422, "PII_DETECTED", "Used when PII is detected in content (pii_filter plugin)"),
"SENSITIVE_CONTENT": PluginViolationCode(422, "SENSITIVE_CONTENT", "Used when sensitive information is detected"),
# Authentication & Authorization
"INVALID_TOKEN": PluginViolationCode(401, "INVALID_TOKEN", "Used for invalid/expired tokens (simple_token_auth example)"), # nosec B105 - Not a password; INVALID_TOKEN is a HTTP Status Code
"API_KEY_REVOKED": PluginViolationCode(401, "API_KEY_REVOKED", "Used when API key has been revoked (custom_auth_example)"),
"AUTH_REQUIRED": PluginViolationCode(401, "AUTH_REQUIRED", "Used when authentication is missing"),
# Generic Violation Codes
"PROHIBITED_CONTENT": PluginViolationCode(422, "PROHIBITED_CONTENT", "Used when content violates policy rules"),
"BLOCKED_CONTENT": PluginViolationCode(403, "BLOCKED_CONTENT", "Used when content is explicitly blocked by policy"),
"BLOCKED": PluginViolationCode(403, "BLOCKED", "Generic blocking violation"),
"EXECUTION_ERROR": PluginViolationCode(500, "EXECUTION_ERROR", "Used when plugin execution fails"),
"PROCESSING_ERROR": PluginViolationCode(500, "PROCESSING_ERROR", "Used when processing encounters an error"),
}
)

VALID_HTTP_STATUS_CODES: dict[int, str] = { # RFC 9110
# 4xx β€” Client Error
400: "Bad Request",
401: "Unauthorized",
402: "Payment Required",
403: "Forbidden",
404: "Not Found",
405: "Method Not Allowed",
406: "Not Acceptable",
407: "Proxy Authentication Required",
408: "Request Timeout",
409: "Conflict",
410: "Gone",
411: "Length Required",
412: "Precondition Failed",
413: "Content Too Large", # (was "Payload Too Large" before RFC 9110)
414: "URI Too Long",
415: "Unsupported Media Type",
416: "Range Not Satisfiable",
417: "Expectation Failed",
418: "(Unused)",
421: "Misdirected Request",
422: "Unprocessable Content", # (was "Unprocessable Entity")
423: "Locked",
424: "Failed Dependency",
425: "Too Early",
426: "Upgrade Required",
428: "Precondition Required",
429: "Too Many Requests",
431: "Request Header Fields Too Large",
451: "Unavailable For Legal Reasons",
# 5xx β€” Server Error
500: "Internal Server Error",
501: "Not Implemented",
502: "Bad Gateway",
503: "Service Unavailable",
504: "Gateway Timeout",
505: "HTTP Version Not Supported",
506: "Variant Also Negotiates",
507: "Insufficient Storage",
508: "Loop Detected",
510: "Not Extended",
511: "Network Authentication Required",
}
6 changes: 6 additions & 0 deletions mcpgateway/plugins/framework/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,8 @@ class PluginViolation(BaseModel):
details: (dict[str, Any]): additional violation details.
_plugin_name (str): the plugin name, private attribute set by the plugin manager.
mcp_error_code(Optional[int]): A valid mcp error code which will be sent back to the client if plugin enabled.
http_status_code (Optional[int]): HTTP status code to return (e.g., 429 for rate limiting).
http_headers (Optional[dict[str, str]]): HTTP headers to include in the response.

Examples:
>>> violation = PluginViolation(
Expand All @@ -1254,6 +1256,8 @@ class PluginViolation(BaseModel):
details: Optional[dict[str, Any]] = Field(default_factory=dict)
_plugin_name: str = PrivateAttr(default="")
mcp_error_code: Optional[int] = None
http_status_code: Optional[int] = None
http_headers: Optional[dict[str, str]] = None

@property
def plugin_name(self) -> str:
Expand Down Expand Up @@ -1327,6 +1331,7 @@ class PluginResult(BaseModel, Generic[T]):
modified_payload (Optional[Any]): The modified payload if the plugin is a transformer.
violation (Optional[PluginViolation]): violation object.
metadata (Optional[dict[str, Any]]): additional metadata.
http_headers (Optional[dict[str, str]]): HTTP headers to include in successful responses.

Examples:
>>> result = PluginResult()
Expand Down Expand Up @@ -1355,6 +1360,7 @@ class PluginResult(BaseModel, Generic[T]):
modified_payload: Optional[T] = None
violation: Optional[PluginViolation] = None
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
http_headers: Optional[dict[str, str]] = None


class GlobalContext(BaseModel):
Expand Down
Loading
Loading