Skip to content

Commit 188d93b

Browse files
committed
feat: add rust validation middleware sidecar
Signed-off-by: lucarlig <luca.carlig@ibm.com>
1 parent 298b32d commit 188d93b

File tree

8 files changed

+7622
-7269
lines changed

8 files changed

+7622
-7269
lines changed

.secrets.baseline

Lines changed: 7268 additions & 7268 deletions
Large diffs are not rendered by default.

mcpgateway/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,10 @@ class Settings(BaseSettings):
349349

350350
# Security Validation & Sanitization
351351
experimental_validate_io: bool = Field(default=False, description="Enable experimental input validation and output sanitization")
352+
experimental_rust_validation_middleware_enabled: bool = Field(
353+
default=False,
354+
description="Enable experimental Rust sidecar for recursive validation middleware JSON checks",
355+
)
352356
validation_middleware_enabled: bool = Field(default=False, description="Enable validation middleware for all requests")
353357
validation_strict: bool = Field(default=True, description="Strict validation mode - reject on violations")
354358
sanitize_output: bool = Field(default=True, description="Sanitize output to remove control characters")

mcpgateway/middleware/validation_middleware.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
# Standard
20+
import importlib
2021
import logging
2122
from pathlib import Path
2223
import re
@@ -32,6 +33,8 @@
3233

3334
logger = logging.getLogger(__name__)
3435

36+
_RUST_VALIDATION_MODULE = None
37+
3538

3639
def is_path_traversal(uri: str) -> bool:
3740
"""Check if URI contains path traversal patterns.
@@ -165,6 +168,17 @@ def _validate_json_data(self, data: Any):
165168
Raises:
166169
HTTPException: If validation fails in strict mode
167170
"""
171+
if getattr(settings, "experimental_rust_validation_middleware_enabled", False) is True:
172+
result = self._load_rust_validation_module().validate_json_data(data, settings.max_param_length, list(settings.dangerous_patterns))
173+
if result is not None:
174+
key, error_type = result
175+
if error_type == "max_length":
176+
raise HTTPException(status_code=422, detail=f"Parameter {key} exceeds maximum length")
177+
if error_type == "dangerous_pattern":
178+
raise HTTPException(status_code=422, detail=f"Parameter {key} contains dangerous characters")
179+
raise HTTPException(status_code=422, detail=f"Parameter {key} failed validation")
180+
return
181+
168182
if isinstance(data, dict):
169183
for key, value in data.items():
170184
if isinstance(value, str):
@@ -175,6 +189,14 @@ def _validate_json_data(self, data: Any):
175189
for item in data:
176190
self._validate_json_data(item)
177191

192+
def _load_rust_validation_module(self):
    """Return the Rust validation sidecar module, importing it on first use.

    The imported module is stored in the module-level
    ``_RUST_VALIDATION_MODULE`` so later calls return the same object without
    importing again. Any exception raised by the import propagates to the
    caller unchanged.
    """
    global _RUST_VALIDATION_MODULE

    # Fast path: a previous call already imported the sidecar.
    if _RUST_VALIDATION_MODULE is not None:
        return _RUST_VALIDATION_MODULE

    _RUST_VALIDATION_MODULE = importlib.import_module("validation_middleware_sidecar")
    return _RUST_VALIDATION_MODULE
199+
178200
def validate_resource_path(self, path: str) -> str:
179201
"""Validate and normalize resource paths to prevent traversal attacks.
180202
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# -*- coding: utf-8 -*-
2+
"""Benchmark the validation middleware Rust sidecar against the Python path."""
3+
4+
# Standard
5+
from __future__ import annotations
6+
7+
import importlib
8+
import re
9+
import statistics
10+
import subprocess
11+
import time
12+
from pathlib import Path
13+
from typing import Any, Callable
14+
15+
# Third-Party
16+
from fastapi import HTTPException
17+
18+
# First-Party
19+
from mcpgateway.config import settings
20+
from mcpgateway.middleware.validation_middleware import ValidationMiddleware
21+
22+
REPO_ROOT = Path(__file__).resolve().parents[2]
23+
SIDECAR_MANIFEST = REPO_ROOT / "tools_rust" / "validation_middleware_sidecar" / "Cargo.toml"
24+
25+
26+
def _ensure_sidecar_installed() -> Any:
    """Build the Rust sidecar with maturin and return the imported module.

    Runs ``uv run maturin develop --release`` against ``SIDECAR_MANIFEST`` with
    ``REPO_ROOT`` as the working directory. ``check=True`` means a non-zero
    exit status raises ``subprocess.CalledProcessError`` before the import is
    attempted.
    """
    build_command = [
        "uv",
        "run",
        "maturin",
        "develop",
        "--release",
        "--manifest-path",
        str(SIDECAR_MANIFEST),
    ]
    subprocess.run(build_command, cwd=REPO_ROOT, check=True)
    return importlib.import_module("validation_middleware_sidecar")
29+
30+
31+
def _build_python_validator(max_param_length: int, dangerous_patterns: list[str]) -> Callable[[Any], None]:
    """Return a callable that validates a payload through the Python middleware path.

    Args:
        max_param_length: Value written to ``settings.max_param_length``.
        dangerous_patterns: Regex sources written to ``settings.dangerous_patterns``
            and compiled onto the middleware instance.

    Returns:
        A one-argument function that forwards its payload to
        ``ValidationMiddleware._validate_json_data``.
    """
    # Mutates the shared settings object; the Rust flag is switched off so the
    # middleware takes its Python branch.
    for option, value in (
        ("max_param_length", max_param_length),
        ("dangerous_patterns", dangerous_patterns),
        ("experimental_rust_validation_middleware_enabled", False),
        ("environment", "production"),
    ):
        setattr(settings, option, value)

    validator = ValidationMiddleware(app=None)
    # Overwrite whatever the constructor set with exactly the benchmark patterns.
    validator.dangerous_patterns = list(map(re.compile, dangerous_patterns))

    def _validate(payload: Any) -> None:
        validator._validate_json_data(payload)

    return _validate
43+
44+
45+
def _build_rust_validator(max_param_length: int, dangerous_patterns: list[str]) -> Callable[[Any], None]:
    """Return a callable that validates a payload by calling the Rust sidecar directly.

    The sidecar is built and imported first, then the shared settings are set
    to the same values the Python validator uses. The sidecar call itself
    receives the limits as explicit arguments.

    Args:
        max_param_length: Maximum string length passed to the sidecar.
        dangerous_patterns: Regex sources passed to the sidecar.

    Returns:
        A one-argument function that raises ``HTTPException`` (status 422) when
        the sidecar reports a violation and returns ``None`` otherwise.
    """
    sidecar = _ensure_sidecar_installed()
    settings.max_param_length = max_param_length
    settings.dangerous_patterns = dangerous_patterns
    settings.environment = "production"

    def _run(data: Any) -> None:
        result = sidecar.validate_json_data(data, max_param_length, dangerous_patterns)
        if result is None:
            return
        key, error_type = result
        # Mirror the middleware's three-way mapping so an unrecognised error
        # type is not silently reported as a dangerous-pattern hit, which would
        # hide a real divergence from the parity check.
        if error_type == "max_length":
            raise HTTPException(status_code=422, detail=f"Parameter {key} exceeds maximum length")
        if error_type == "dangerous_pattern":
            raise HTTPException(status_code=422, detail=f"Parameter {key} contains dangerous characters")
        raise HTTPException(status_code=422, detail=f"Parameter {key} failed validation")

    return _run
61+
62+
63+
def _measure(label: str, fn: Callable[[Any], None], payload: Any, iterations: int) -> tuple[float, float]:
64+
samples = []
65+
for _ in range(iterations):
66+
started = time.perf_counter_ns()
67+
try:
68+
fn(payload)
69+
except HTTPException:
70+
pass
71+
samples.append(time.perf_counter_ns() - started)
72+
73+
median_ms = statistics.median(samples) / 1_000_000
74+
p95_ms = statistics.quantiles(samples, n=100)[94] / 1_000_000
75+
print(f"{label}: median={median_ms:.3f}ms p95={p95_ms:.3f}ms")
76+
return median_ms, p95_ms
77+
78+
79+
def _assert_parity(python_fn: Callable[[Any], None], rust_fn: Callable[[Any], None], payloads: list[Any]) -> None:
80+
for payload in payloads:
81+
python_error = None
82+
rust_error = None
83+
84+
try:
85+
python_fn(payload)
86+
except HTTPException as exc:
87+
python_error = (exc.status_code, exc.detail)
88+
89+
try:
90+
rust_fn(payload)
91+
except HTTPException as exc:
92+
rust_error = (exc.status_code, exc.detail)
93+
94+
if python_error != rust_error:
95+
raise AssertionError(f"Parity mismatch for payload {payload!r}: python={python_error!r} rust={rust_error!r}")
96+
97+
98+
def main() -> None:
    """Check Python/Rust parity, then benchmark both validators per scenario.

    Both validators are built with the same length limit and patterns. Parity
    is asserted on three small payloads before any timing, and each scenario
    prints the median/p95 for both paths plus the Python-to-Rust median ratio.
    """
    length_limit = 1024
    patterns = [r"[;&|`$(){}\[\]<>]", r"\.\.[\\/]", r"[\x00-\x1f\x7f-\x9f]"]

    python_fn = _build_python_validator(length_limit, patterns)
    rust_fn = _build_rust_validator(length_limit, patterns)

    _assert_parity(
        python_fn,
        rust_fn,
        [
            {"name": "safe", "nested": {"description": "still safe"}},
            {"prompt": "<script>alert(1)</script>"},
            {"outer": {"inner": "a" * 2048}},
        ],
    )

    # Wide but shallow: one tool object carrying 256 small metadata entries.
    nested_safe = {
        "tool": {
            "name": "safe-tool",
            "description": "ok" * 32,
            "metadata": [{"field": "value" * 8} for _ in range(256)],
        }
    }
    # 512 nested records, all of which pass validation.
    deep_nested = {"batch": [{"payload": {"name": f"item-{index}", "content": ("alpha-beta-gamma-" * 16)}} for index in range(512)]}
    # 511 safe records followed by one record that contains a dangerous string.
    safe_records = [{"payload": {"name": f"item-{index}", "content": "safe-content"}} for index in range(511)]
    dangerous_string = {"batch": safe_records + [{"payload": {"name": "bad", "content": "<script>alert(1)</script>"}}]}

    scenarios = [
        ("nested_safe", nested_safe, 400),
        ("deep_nested", deep_nested, 250),
        ("dangerous_string", dangerous_string, 250),
    ]

    for name, payload, iterations in scenarios:
        print(f"\n{name} ({iterations} iterations)")
        py_median = _measure("python", python_fn, payload, iterations)[0]
        rs_median = _measure("rust", rust_fn, payload, iterations)[0]
        print(f"speedup={py_median / rs_median:.2f}x")
141+
142+
143+
if __name__ == "__main__":
144+
main()

tests/unit/mcpgateway/middleware/test_validation_middleware.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"""
99

1010
# Standard
11-
import re
1211
from unittest.mock import AsyncMock, MagicMock, patch
1312

1413
# Third-Party
@@ -350,6 +349,45 @@ def test_validate_json_data_list(self):
350349
# Should not raise for valid data
351350
middleware._validate_json_data([{"name": "item1"}, {"name": "item2"}])
352351

352+
def test_validate_json_data_uses_rust_sidecar_when_enabled(self):
    """With the Rust flag on, validation is delegated to the sidecar module."""
    with patch("mcpgateway.middleware.validation_middleware.settings") as mock_settings:
        mock_settings.configure_mock(
            experimental_validate_io=True,
            experimental_rust_validation_middleware_enabled=True,
            validation_strict=True,
            sanitize_output=False,
            allowed_roots=[],
            dangerous_patterns=[r"<script"],
            max_param_length=1000,
            environment="production",
        )

        middleware = ValidationMiddleware(app=None)
        # Stand-in sidecar that reports "no violation".
        fake_sidecar = MagicMock()
        fake_sidecar.validate_json_data.return_value = None

        with patch.object(middleware, "_load_rust_validation_module", return_value=fake_sidecar):
            middleware._validate_json_data({"name": "safe"})

        # The payload, length limit and pattern list are passed through.
        fake_sidecar.validate_json_data.assert_called_once_with({"name": "safe"}, 1000, [r"<script"])
372+
373+
def test_validate_json_data_missing_sidecar_is_hard_failure_when_enabled(self):
    """With the Rust flag on, a sidecar import failure propagates to the caller."""
    with patch("mcpgateway.middleware.validation_middleware.settings") as mock_settings:
        mock_settings.configure_mock(
            experimental_validate_io=True,
            experimental_rust_validation_middleware_enabled=True,
            validation_strict=True,
            sanitize_output=False,
            allowed_roots=[],
            dangerous_patterns=[r"<script"],
            max_param_length=1000,
            environment="production",
        )

        middleware = ValidationMiddleware(app=None)
        import_failure = ModuleNotFoundError("missing sidecar")

        # The loader error must surface as-is instead of falling back to the
        # Python validation path.
        with patch.object(middleware, "_load_rust_validation_module", side_effect=import_failure):
            with pytest.raises(ModuleNotFoundError, match="missing sidecar"):
                middleware._validate_json_data({"name": "<script>"})
390+
353391
def test_validate_resource_path_traversal(self):
354392
"""Test resource path validation for traversal."""
355393
with patch("mcpgateway.middleware.validation_middleware.settings") as mock_settings:
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[package]
2+
name = "validation_middleware_sidecar"
3+
version = "0.1.0"
4+
edition = "2021"
5+
license = "Apache-2.0"
6+
7+
[lib]
8+
name = "validation_middleware_sidecar"
9+
crate-type = ["cdylib"]
10+
11+
[dependencies]
12+
once_cell = "1.21.3"
13+
pyo3 = { version = "0.27.1", features = ["extension-module"] }
14+
regex = "1.12.2"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[build-system]
2+
requires = ["maturin>=1.8,<2.0"]
3+
build-backend = "maturin"
4+
5+
[project]
6+
name = "validation-middleware-sidecar"
7+
version = "0.1.0"
8+
requires-python = ">=3.11"
9+
10+
[tool.maturin]
11+
module-name = "validation_middleware_sidecar"
12+
bindings = "pyo3"

0 commit comments

Comments
 (0)