8585from mcpgateway.db import A2ATask as DbA2ATask
8686from mcpgateway.db import refresh_slugs_on_startup, SessionLocal
8787from mcpgateway.db import Tool as DbTool
88- from mcpgateway.handlers.sampling import SamplingHandler
88+ from mcpgateway.handlers.sampling import SamplingError, SamplingHandler
8989from mcpgateway.middleware.compression import SSEAwareCompressMiddleware
9090from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware
9191from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware, run_pre_request_hooks
152152from mcpgateway.services.a2a_server_service import A2AServerService
153153from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService
154154from mcpgateway.services.cancellation_service import cancellation_service
155- from mcpgateway.services.completion_service import CompletionService
155+ from mcpgateway.services.completion_service import CompletionError, CompletionService
156156from mcpgateway.services.content_security import ContentSizeError, ContentTypeError
157157from mcpgateway.services.email_auth_service import EmailAuthService
158158from mcpgateway.services.export_service import ExportError, ExportService
@@ -3540,6 +3540,11 @@ async def require_valid_server(server_id: str, db: Session = Depends(get_db)) ->
35403540 return server_id
35413541
35423542
3543+ def _jsonrpc_invalid_request(req_id: Optional[Union[int, str]] = None) -> dict:
3544+ """Build a JSON-RPC 2.0 ``Invalid Request`` error envelope."""
3545+ return {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": req_id}
3546+
3547+
35433548async def _read_request_json(request: Request) -> Any:
35443549 """Read JSON payload using orjson.
35453550
@@ -3821,23 +3826,14 @@ async def ping(request: Request, user=Depends(get_current_user)) -> JSONResponse
38213826 Raises:
38223827 HTTPException: If the request method is not "ping".
38233828 """
3824- req_id: Optional[str] = None
3825- try:
3826- body: dict = await _read_request_json(request)
3827- if body.get("method") != "ping":
3828- raise HTTPException(status_code=400, detail="Invalid method")
3829- req_id = body.get("id")
3830- logger.debug(f"Authenticated user {SecurityValidator.sanitize_log_message(str(user))} sent ping request.")
3831- # Return an empty result per the MCP ping specification.
3832- response: dict = {"jsonrpc": "2.0", "id": req_id, "result": {}}
3833- return ORJSONResponse(content=response)
3834- except Exception as e:
3835- error_response: dict = {
3836- "jsonrpc": "2.0",
3837- "id": req_id, # Now req_id is always defined
3838- "error": {"code": -32603, "message": "Internal error", "data": str(e)},
3839- }
3840- return ORJSONResponse(status_code=500, content=error_response)
3829+ body = await _read_request_json(request)
3830+ req_id = body.get("id") if isinstance(body, dict) else None
3831+ if not isinstance(body, dict) or body.get("method") != "ping":
3832+ return ORJSONResponse(status_code=400, content=_jsonrpc_invalid_request(req_id))
3833+
3834+ logger.debug(f"Authenticated user {SecurityValidator.sanitize_log_message(str(user))} sent ping request.")
3835+ response: dict = {"jsonrpc": "2.0", "id": req_id, "result": {}}
3836+ return ORJSONResponse(content=response)
38413837
38423838
38433839@protocol_router.post("/notifications")
@@ -3887,6 +3883,9 @@ async def handle_completion(request: Request, db: Session = Depends(get_db), use
38873883
38883884 Returns:
38893885 The result of the completion process.
3886+
3887+ Raises:
3888+ HTTPException: If completion request validation fails.
38903889 """
38913890 body = await _read_request_json(request)
38923891 logger.debug(f"User {SecurityValidator.sanitize_log_message(user['email'])} sent a completion request")
@@ -3895,7 +3894,10 @@ async def handle_completion(request: Request, db: Session = Depends(get_db), use
38953894 user_email = None
38963895 elif token_teams is None:
38973896 token_teams = []
3898- return await completion_service.handle_completion(db, body, user_email=user_email, token_teams=token_teams)
3897+ try:
3898+ return await completion_service.handle_completion(db, body, user_email=user_email, token_teams=token_teams)
3899+ except CompletionError as exc:
3900+ raise HTTPException(status_code=400, detail=str(exc)) from exc
38993901
39003902
39013903@protocol_router.post("/sampling/createMessage")
@@ -3910,10 +3912,16 @@ async def handle_sampling(request: Request, db: Session = Depends(get_db), user=
39103912
39113913 Returns:
39123914 The result of the message creation process.
3915+
3916+ Raises:
3917+ HTTPException: If sampling request validation fails.
39133918 """
39143919 logger.debug(f"User {SecurityValidator.sanitize_log_message(user['email'])} sent a sampling request")
39153920 body = await _read_request_json(request)
3916- return await sampling_handler.create_message(db, body)
3921+ try:
3922+ return await sampling_handler.create_message(db, body)
3923+ except SamplingError as exc:
3924+ raise HTTPException(status_code=400, detail=str(exc)) from exc
39173925
39183926
39193927###############
@@ -10121,7 +10129,7 @@ async def _handle_rpc_authenticated(request: Request, db: Session, user):
1012110129 PluginError: If encounters issue with plugin
1012210130 PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy.
1012310131 """
10124- req_id = None
10132+ req_id: Optional[Union[int, str]] = None
1012510133 try:
1012610134 # Extract user identifier from either RBAC user object or JWT payload
1012710135 if hasattr(user, "email"):
@@ -10143,6 +10151,22 @@ async def _handle_rpc_authenticated(request: Request, db: Session, user):
1014310151 "id": None,
1014410152 },
1014510153 )
10154+ if not isinstance(body, dict):
10155+ return _jsonrpc_invalid_request()
10156+
10157+ req_id = body.get("id")
10158+ if req_id is not None and not isinstance(req_id, (str, int)):
10159+ return _jsonrpc_invalid_request()
10160+
10161+ method = body.get("method")
10162+ params = body.get("params", {})
10163+ if params is None:
10164+ params = {}
10165+ jsonrpc_version = body.get("jsonrpc")
10166+
10167+ if jsonrpc_version != "2.0" or not isinstance(method, str) or not method.strip() or not isinstance(params, dict):
10168+ return _jsonrpc_invalid_request(req_id)
10169+
1014610170 request_headers = request.headers
1014710171 lowered_headers: Optional[Dict[str, str]] = None
1014810172
@@ -10160,13 +10184,14 @@ def _lowered_request_headers() -> Dict[str, str]:
1016010184 _trusted_internal_mcp_dispatch = _get_internal_mcp_auth_context(request) is not None
1016110185 _internal_runtime_server_id = request_headers.get("x-contextforge-server-id") if request_headers.get("x-contextforge-mcp-runtime") == "rust" else None
1016210186
10163- method = body["method"]
10164- req_id = body.get("id")
10187+ if not _trusted_internal_mcp_dispatch:
10188+ try:
10189+ RPCRequest(jsonrpc=jsonrpc_version, method=method, params=params, id=req_id)
10190+ except (ValidationError, ValueError):
10191+ return _jsonrpc_invalid_request(req_id)
10192+
1016510193 if req_id is None:
1016610194 req_id = str(uuid.uuid4())
10167- params = body.get("params", {})
10168- if not isinstance(params, dict):
10169- params = {}
1017010195 if _internal_runtime_server_id:
1017110196 params["server_id"] = _internal_runtime_server_id
1017210197 server_id = params.get("server_id", None)
@@ -10202,9 +10227,6 @@ def _lowered_request_headers() -> Dict[str, str]:
1020210227 elif _token_server_id is not None:
1020310228 server_id = _token_server_id
1020410229
10205- if not _trusted_internal_mcp_dispatch:
10206- RPCRequest(jsonrpc="2.0", method=method, params=params) # Validate the request body against the RPCRequest model
10207-
1020810230 forwarded_response = await _maybe_forward_affinitized_rpc_request(
1020910231 request,
1021010232 method=method,
@@ -10215,6 +10237,14 @@ def _lowered_request_headers() -> Dict[str, str]:
1021510237 if forwarded_response is not None:
1021610238 return forwarded_response
1021710239
10240+ if settings.use_stateful_sessions and mcp_session_id and method != "initialize":
10241+ try:
10242+ await _assert_session_owner_or_admin(request, user, mcp_session_id)
10243+ except HTTPException as exc:
10244+ if exc.status_code == status.HTTP_404_NOT_FOUND:
10245+ raise JSONRPCError(-32002, "Session not found", {"method": method}) from exc
10246+ raise JSONRPCError(-32003, str(exc.detail), {"method": method}) from exc
10247+
1021810248 if method == "initialize":
1021910249 result = await _execute_rpc_initialize(
1022010250 request,
@@ -10560,7 +10590,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1056010590 result = {}
1056110591 elif method == "sampling/createMessage":
1056210592 # MCP spec-compliant sampling endpoint
10563- result = await sampling_handler.create_message(db, params)
10593+ try:
10594+ result = await sampling_handler.create_message(db, params)
10595+ except SamplingError as e:
10596+ raise JSONRPCError(-32602, str(e)) from e
1056410597 elif method.startswith("sampling/"):
1056510598 # Catch-all for other sampling/* methods (currently unsupported)
1056610599 result = {}
@@ -10656,7 +10689,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1065610689 user_email = None
1065710690 elif token_teams is None:
1065810691 token_teams = []
10659- result = await completion_service.handle_completion(db, params, user_email=user_email, token_teams=token_teams)
10692+ try:
10693+ result = await completion_service.handle_completion(db, params, user_email=user_email, token_teams=token_teams)
10694+ except CompletionError as e:
10695+ raise JSONRPCError(-32602, str(e)) from e
1066010696 elif method.startswith("completion/"):
1066110697 # Catch-all for other completion/* methods (currently unsupported)
1066210698 result = {}
@@ -10722,12 +10758,10 @@ def _lowered_request_headers() -> Dict[str, str]:
1072210758 error = e.to_dict()
1072310759 return {"jsonrpc": "2.0", "error": error["error"], "id": req_id}
1072410760 except Exception as e:
10725- if isinstance(e, ValueError):
10726- return ORJSONResponse(content={"message": "Method invalid"}, status_code=422)
1072710761 logger.error(f"RPC error: {str(e)}")
1072810762 return {
1072910763 "jsonrpc": "2.0",
10730- "error": {"code": -32000 , "message": "Internal error", "data": str(e) },
10764+ "error": {"code": -32603 , "message": "Internal error"},
1073110765 "id": req_id,
1073210766 }
1073310767
0 commit comments