Skip to content

Commit 264676c

Browse files
crivetimihai, jonpspri, sisyphus-dev-ai
authored
fix(transport): protocol and transport hardening for auth and lifecycle consistency (#3344)
* fix: protocol and transport hardening for auth and lifecycle consistency

  Improve RPC/protocol validation behavior and error mapping, enforce admin-only diagnostics access, and harden stateful streamable HTTP session teardown semantics. Align nginx upstream behavior for replica-aware balancing and remove overlapping proxy security headers. Update and expand unit coverage across main/version/oauth/streamable transport paths.

  Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>

* chore: docstring hardening for lint consistency

  Update endpoint/helper docstrings for explicit args/returns/raises coverage and keep doctest examples aligned with current admin diagnostics access semantics.

  Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>

* test: add differential coverage for protocol/transport hardening paths

  Cover RPC-path SamplingError and CompletionError mapping to JSON-RPC -32602, stateful session owner assertion (404/403 deny paths), ping endpoint non-dict body handling, and streamable HTTP session close error branches (registry None, remove_session failure).

  Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>

* fix: harden version endpoint auth and align test expectations

  - version.py: _has_version_admin_access now accepts string inputs from require_admin_auth (which returns email after verifying admin status)
  - test_version.py: reference require_admin_auth instead of nonexistent require_auth
  - test_main_apis.py: ping invalid-method test expects 400 (not 500)
  - test_main_extended.py: fix params list->dict, remove stale RPCRequest patch

  Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

  Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
  Signed-off-by: Jonathan Springer <jps@s390x.com>

* refactor: harden RPC error responses — extract helper, skip redundant validation, remove data leaks

  - Extract _jsonrpc_invalid_request() helper to deduplicate 5 identical error envelope constructions across ping() and _handle_rpc_authenticated()
  - Skip RPCRequest Pydantic validation for trusted internal Rust dispatch (restores intentional perf optimization; manual type checks still run)
  - Remove session_id from JSONRPCError data to avoid reflecting client input
  - Remove full params dict from SamplingError/CompletionError JSON-RPC errors to prevent request payload leakage in error responses

  Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

  Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
  Signed-off-by: Jonathan Springer <jps@s390x.com>

---------

Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>
Signed-off-by: Jonathan Springer <jps@s390x.com>
Co-authored-by: Jonathan Springer <jps@s390x.com>
Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
1 parent 1beafee commit 264676c

14 files changed

Lines changed: 534 additions & 99 deletions

.secrets.baseline

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"files": "(?x)( package-lock\\.json$ |Cargo\\.lock$ |uv\\.lock$ |go\\.sum$ |mcpgateway/sri_hashes\\.json$ )|^.secrets.baseline$",
44
"lines": null
55
},
6-
"generated_at": "2026-04-23T10:40:16Z",
6+
"generated_at": "2026-04-23T15:27:18Z",
77
"plugins_used": [
88
{
99
"name": "AWSKeyDetector"
@@ -5018,7 +5018,7 @@
50185018
"hashed_secret": "d3ecb0d890368d7659ee54010045b835dacb8efe",
50195019
"is_secret": false,
50205020
"is_verified": false,
5021-
"line_number": 611,
5021+
"line_number": 615,
50225022
"type": "Secret Keyword",
50235023
"verified_result": null
50245024
}
@@ -5484,15 +5484,15 @@
54845484
"hashed_secret": "9addbf544119efa4a64223b649750a510f0d463f",
54855485
"is_secret": false,
54865486
"is_verified": false,
5487-
"line_number": 850,
5487+
"line_number": 851,
54885488
"type": "Secret Keyword",
54895489
"verified_result": null
54905490
},
54915491
{
54925492
"hashed_secret": "9addbf544119efa4a64223b649750a510f0d463f",
54935493
"is_secret": false,
54945494
"is_verified": false,
5495-
"line_number": 912,
5495+
"line_number": 913,
54965496
"type": "Basic Auth Credentials",
54975497
"verified_result": null
54985498
}
@@ -5982,7 +5982,7 @@
59825982
"hashed_secret": "4dfad2d5130a5bcaf2716c495de92da13f4389e0",
59835983
"is_secret": false,
59845984
"is_verified": false,
5985-
"line_number": 764,
5985+
"line_number": 763,
59865986
"type": "Secret Keyword",
59875987
"verified_result": null
59885988
}
@@ -9510,7 +9510,7 @@
95109510
"hashed_secret": "516b9783fca517eecbd1d064da2d165310b19759",
95119511
"is_secret": false,
95129512
"is_verified": false,
9513-
"line_number": 183,
9513+
"line_number": 210,
95149514
"type": "Basic Auth Credentials",
95159515
"verified_result": null
95169516
}
@@ -9548,15 +9548,15 @@
95489548
"hashed_secret": "b4c9248600a42f8c38c01b632f392dbcb4c7b19a",
95499549
"is_secret": false,
95509550
"is_verified": false,
9551-
"line_number": 12820,
9551+
"line_number": 12928,
95529552
"type": "Hex High Entropy String",
95539553
"verified_result": null
95549554
},
95559555
{
95569556
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
95579557
"is_secret": false,
95589558
"is_verified": false,
9559-
"line_number": 13995,
9559+
"line_number": 14103,
95609560
"type": "Hex High Entropy String",
95619561
"verified_result": null
95629562
}

infra/nginx/nginx-performance.conf

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ http {
156156
# Load balancing: least_conn distributes to backend with fewest active connections
157157
# Fixes imbalance caused by keepalive connections sticking to one backend
158158
least_conn;
159+
zone gateway_backend 64k;
159160

160-
server gateway:4444 max_fails=0; # Disable failure tracking (always retry)
161+
server gateway:4444 resolve max_fails=0; # Re-resolve Docker DNS for replica-aware balancing
161162

162163
# Keepalive pool sizing for 10,000 capacity:
163164
# - Each nginx worker maintains its own pool
@@ -199,11 +200,7 @@ http {
199200
listen [::]:80 backlog=65535 reuseport;
200201
server_name localhost;
201202

202-
# Security headers
203-
add_header X-Frame-Options "SAMEORIGIN" always;
204-
add_header X-Content-Type-Options "nosniff" always;
205-
add_header X-XSS-Protection "1; mode=block" always;
206-
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
203+
# Security headers are authored by the gateway app middleware.
207204

208205
# Cache status header (for debugging)
209206
add_header X-Cache-Status $upstream_cache_status always;

infra/nginx/nginx-tls.conf

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ http {
156156
# Load balancing: least_conn distributes to backend with fewest active connections
157157
# Fixes imbalance caused by keepalive connections sticking to one backend
158158
least_conn;
159+
zone gateway_backend 64k;
159160

160-
server gateway:4444 max_fails=0; # Disable failure tracking (always retry)
161+
server gateway:4444 resolve max_fails=0; # Re-resolve Docker DNS for replica-aware balancing
161162

162163
# Keepalive pool sizing for 4000 capacity:
163164
# - Each nginx worker maintains its own pool
@@ -213,11 +214,7 @@ http {
213214

214215
server_name localhost;
215216

216-
# Security headers
217-
add_header X-Frame-Options "SAMEORIGIN" always;
218-
add_header X-Content-Type-Options "nosniff" always;
219-
add_header X-XSS-Protection "1; mode=block" always;
220-
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
217+
# Security headers are authored by the gateway app middleware.
221218

222219
# Cache status header (for debugging)
223220
add_header X-Cache-Status $upstream_cache_status always;
@@ -296,11 +293,8 @@ http {
296293

297294
server_name localhost;
298295

299-
# Security headers (with HSTS for HTTPS)
300-
add_header X-Frame-Options "SAMEORIGIN" always;
301-
add_header X-Content-Type-Options "nosniff" always;
302-
add_header X-XSS-Protection "1; mode=block" always;
303-
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
296+
# Security headers are authored by the gateway app middleware.
297+
# HSTS can still be configured here if edge-only enforcement is desired.
304298
# HSTS - Only enable if you want to force HTTPS (15768000 = 6 months)
305299
# add_header Strict-Transport-Security "max-age=15768000; includeSubDomains" always;
306300

infra/nginx/nginx.conf

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,7 @@ http {
235235

236236
server_name localhost;
237237

238-
# Security headers
239-
add_header X-Frame-Options "SAMEORIGIN" always;
240-
add_header X-Content-Type-Options "nosniff" always;
241-
add_header X-XSS-Protection "1; mode=block" always;
242-
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
238+
# Security headers are authored by the gateway app middleware.
243239

244240
# Cache status header (for debugging)
245241
add_header X-Cache-Status $upstream_cache_status always;

mcpgateway/main.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
from mcpgateway.db import A2ATask as DbA2ATask
8686
from mcpgateway.db import refresh_slugs_on_startup, SessionLocal
8787
from mcpgateway.db import Tool as DbTool
88-
from mcpgateway.handlers.sampling import SamplingHandler
88+
from mcpgateway.handlers.sampling import SamplingError, SamplingHandler
8989
from mcpgateway.middleware.compression import SSEAwareCompressMiddleware
9090
from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware
9191
from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware, run_pre_request_hooks
@@ -152,7 +152,7 @@
152152
from mcpgateway.services.a2a_server_service import A2AServerService
153153
from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService
154154
from mcpgateway.services.cancellation_service import cancellation_service
155-
from mcpgateway.services.completion_service import CompletionService
155+
from mcpgateway.services.completion_service import CompletionError, CompletionService
156156
from mcpgateway.services.content_security import ContentSizeError, ContentTypeError
157157
from mcpgateway.services.email_auth_service import EmailAuthService
158158
from mcpgateway.services.export_service import ExportError, ExportService
@@ -3540,6 +3540,11 @@ async def require_valid_server(server_id: str, db: Session = Depends(get_db)) ->
35403540
return server_id
35413541

35423542

3543+
def _jsonrpc_invalid_request(req_id: Optional[Union[int, str]] = None) -> dict:
3544+
"""Build a JSON-RPC 2.0 ``Invalid Request`` error envelope."""
3545+
return {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": req_id}
3546+
3547+
35433548
async def _read_request_json(request: Request) -> Any:
35443549
"""Read JSON payload using orjson.
35453550

@@ -3821,23 +3826,14 @@ async def ping(request: Request, user=Depends(get_current_user)) -> JSONResponse
38213826
Raises:
38223827
HTTPException: If the request method is not "ping".
38233828
"""
3824-
req_id: Optional[str] = None
3825-
try:
3826-
body: dict = await _read_request_json(request)
3827-
if body.get("method") != "ping":
3828-
raise HTTPException(status_code=400, detail="Invalid method")
3829-
req_id = body.get("id")
3830-
logger.debug(f"Authenticated user {SecurityValidator.sanitize_log_message(str(user))} sent ping request.")
3831-
# Return an empty result per the MCP ping specification.
3832-
response: dict = {"jsonrpc": "2.0", "id": req_id, "result": {}}
3833-
return ORJSONResponse(content=response)
3834-
except Exception as e:
3835-
error_response: dict = {
3836-
"jsonrpc": "2.0",
3837-
"id": req_id, # Now req_id is always defined
3838-
"error": {"code": -32603, "message": "Internal error", "data": str(e)},
3839-
}
3840-
return ORJSONResponse(status_code=500, content=error_response)
3829+
body = await _read_request_json(request)
3830+
req_id = body.get("id") if isinstance(body, dict) else None
3831+
if not isinstance(body, dict) or body.get("method") != "ping":
3832+
return ORJSONResponse(status_code=400, content=_jsonrpc_invalid_request(req_id))
3833+
3834+
logger.debug(f"Authenticated user {SecurityValidator.sanitize_log_message(str(user))} sent ping request.")
3835+
response: dict = {"jsonrpc": "2.0", "id": req_id, "result": {}}
3836+
return ORJSONResponse(content=response)
38413837

38423838

38433839
@protocol_router.post("/notifications")
@@ -3887,6 +3883,9 @@ async def handle_completion(request: Request, db: Session = Depends(get_db), use
38873883

38883884
Returns:
38893885
The result of the completion process.
3886+
3887+
Raises:
3888+
HTTPException: If completion request validation fails.
38903889
"""
38913890
body = await _read_request_json(request)
38923891
logger.debug(f"User {SecurityValidator.sanitize_log_message(user['email'])} sent a completion request")
@@ -3895,7 +3894,10 @@ async def handle_completion(request: Request, db: Session = Depends(get_db), use
38953894
user_email = None
38963895
elif token_teams is None:
38973896
token_teams = []
3898-
return await completion_service.handle_completion(db, body, user_email=user_email, token_teams=token_teams)
3897+
try:
3898+
return await completion_service.handle_completion(db, body, user_email=user_email, token_teams=token_teams)
3899+
except CompletionError as exc:
3900+
raise HTTPException(status_code=400, detail=str(exc)) from exc
38993901

39003902

39013903
@protocol_router.post("/sampling/createMessage")
@@ -3910,10 +3912,16 @@ async def handle_sampling(request: Request, db: Session = Depends(get_db), user=
39103912

39113913
Returns:
39123914
The result of the message creation process.
3915+
3916+
Raises:
3917+
HTTPException: If sampling request validation fails.
39133918
"""
39143919
logger.debug(f"User {SecurityValidator.sanitize_log_message(user['email'])} sent a sampling request")
39153920
body = await _read_request_json(request)
3916-
return await sampling_handler.create_message(db, body)
3921+
try:
3922+
return await sampling_handler.create_message(db, body)
3923+
except SamplingError as exc:
3924+
raise HTTPException(status_code=400, detail=str(exc)) from exc
39173925

39183926

39193927
###############
@@ -10121,7 +10129,7 @@ async def _handle_rpc_authenticated(request: Request, db: Session, user):
1012110129
PluginError: If encounters issue with plugin
1012210130
PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy.
1012310131
"""
10124-
req_id = None
10132+
req_id: Optional[Union[int, str]] = None
1012510133
try:
1012610134
# Extract user identifier from either RBAC user object or JWT payload
1012710135
if hasattr(user, "email"):
@@ -10143,6 +10151,22 @@ async def _handle_rpc_authenticated(request: Request, db: Session, user):
1014310151
"id": None,
1014410152
},
1014510153
)
10154+
if not isinstance(body, dict):
10155+
return _jsonrpc_invalid_request()
10156+
10157+
req_id = body.get("id")
10158+
if req_id is not None and not isinstance(req_id, (str, int)):
10159+
return _jsonrpc_invalid_request()
10160+
10161+
method = body.get("method")
10162+
params = body.get("params", {})
10163+
if params is None:
10164+
params = {}
10165+
jsonrpc_version = body.get("jsonrpc")
10166+
10167+
if jsonrpc_version != "2.0" or not isinstance(method, str) or not method.strip() or not isinstance(params, dict):
10168+
return _jsonrpc_invalid_request(req_id)
10169+
1014610170
request_headers = request.headers
1014710171
lowered_headers: Optional[Dict[str, str]] = None
1014810172

@@ -10160,13 +10184,14 @@ def _lowered_request_headers() -> Dict[str, str]:
1016010184
_trusted_internal_mcp_dispatch = _get_internal_mcp_auth_context(request) is not None
1016110185
_internal_runtime_server_id = request_headers.get("x-contextforge-server-id") if request_headers.get("x-contextforge-mcp-runtime") == "rust" else None
1016210186

10163-
method = body["method"]
10164-
req_id = body.get("id")
10187+
if not _trusted_internal_mcp_dispatch:
10188+
try:
10189+
RPCRequest(jsonrpc=jsonrpc_version, method=method, params=params, id=req_id)
10190+
except (ValidationError, ValueError):
10191+
return _jsonrpc_invalid_request(req_id)
10192+
1016510193
if req_id is None:
1016610194
req_id = str(uuid.uuid4())
10167-
params = body.get("params", {})
10168-
if not isinstance(params, dict):
10169-
params = {}
1017010195
if _internal_runtime_server_id:
1017110196
params["server_id"] = _internal_runtime_server_id
1017210197
server_id = params.get("server_id", None)
@@ -10202,9 +10227,6 @@ def _lowered_request_headers() -> Dict[str, str]:
1020210227
elif _token_server_id is not None:
1020310228
server_id = _token_server_id
1020410229

10205-
if not _trusted_internal_mcp_dispatch:
10206-
RPCRequest(jsonrpc="2.0", method=method, params=params) # Validate the request body against the RPCRequest model
10207-
1020810230
forwarded_response = await _maybe_forward_affinitized_rpc_request(
1020910231
request,
1021010232
method=method,
@@ -10215,6 +10237,14 @@ def _lowered_request_headers() -> Dict[str, str]:
1021510237
if forwarded_response is not None:
1021610238
return forwarded_response
1021710239

10240+
if settings.use_stateful_sessions and mcp_session_id and method != "initialize":
10241+
try:
10242+
await _assert_session_owner_or_admin(request, user, mcp_session_id)
10243+
except HTTPException as exc:
10244+
if exc.status_code == status.HTTP_404_NOT_FOUND:
10245+
raise JSONRPCError(-32002, "Session not found", {"method": method}) from exc
10246+
raise JSONRPCError(-32003, str(exc.detail), {"method": method}) from exc
10247+
1021810248
if method == "initialize":
1021910249
result = await _execute_rpc_initialize(
1022010250
request,
@@ -10560,7 +10590,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1056010590
result = {}
1056110591
elif method == "sampling/createMessage":
1056210592
# MCP spec-compliant sampling endpoint
10563-
result = await sampling_handler.create_message(db, params)
10593+
try:
10594+
result = await sampling_handler.create_message(db, params)
10595+
except SamplingError as e:
10596+
raise JSONRPCError(-32602, str(e)) from e
1056410597
elif method.startswith("sampling/"):
1056510598
# Catch-all for other sampling/* methods (currently unsupported)
1056610599
result = {}
@@ -10656,7 +10689,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1065610689
user_email = None
1065710690
elif token_teams is None:
1065810691
token_teams = []
10659-
result = await completion_service.handle_completion(db, params, user_email=user_email, token_teams=token_teams)
10692+
try:
10693+
result = await completion_service.handle_completion(db, params, user_email=user_email, token_teams=token_teams)
10694+
except CompletionError as e:
10695+
raise JSONRPCError(-32602, str(e)) from e
1066010696
elif method.startswith("completion/"):
1066110697
# Catch-all for other completion/* methods (currently unsupported)
1066210698
result = {}
@@ -10722,12 +10758,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1072210758
error = e.to_dict()
1072310759
return {"jsonrpc": "2.0", "error": error["error"], "id": req_id}
1072410760
except Exception as e:
10725-
if isinstance(e, ValueError):
10726-
return ORJSONResponse(content={"message": "Method invalid"}, status_code=422)
1072710761
logger.error(f"RPC error: {str(e)}")
1072810762
return {
1072910763
"jsonrpc": "2.0",
10730-
"error": {"code": -32000, "message": "Internal error", "data": str(e)},
10764+
"error": {"code": -32603, "message": "Internal error"},
1073110765
"id": req_id,
1073210766
}
1073310767

mcpgateway/routers/oauth_router.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ async def initiate_oauth_flow(
412412
@oauth_router.get("/callback")
413413
async def oauth_callback(
414414
code: Annotated[str | None, Query(description="Authorization code from OAuth provider")] = None,
415-
state: Annotated[str, Query(description="State parameter for CSRF protection")] = ...,
415+
state: Annotated[str | None, Query(description="State parameter for CSRF protection")] = None,
416416
error: Annotated[str | None, Query(description="OAuth provider error code")] = None,
417417
error_description: Annotated[str | None, Query(description="OAuth provider error description")] = None,
418418
# Remove the gateway_id parameter requirement
@@ -510,6 +510,10 @@ def _invalid_state_response() -> HTMLResponse:
510510
status_code=400,
511511
)
512512

513+
if not state:
514+
logger.warning("OAuth callback missing state parameter")
515+
return _invalid_state_response()
516+
513517
oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
514518
gateway_id = await oauth_manager.resolve_gateway_id_from_state(state, allow_legacy_fallback=False)
515519
if not gateway_id:

0 commit comments

Comments
 (0)