Skip to content

Commit 1e47984

Browse files
author
Olivier Gintrand
committed
fix(security): enforce per-user token isolation in gateway proxy
CRITICAL: build_gateway_auth_headers() only used gateway-level default credentials, ignoring per-user tokens entirely. When User A stored a per-user API key for a gateway, User B (without their own key) would use User A's gateway-level credentials — a cross-user token leak. Add resolve_gateway_auth_headers() that checks per-user credentials (UserGatewayCredential, then OAuthToken) before falling back to gateway defaults. Update all 6 call sites: - streamablehttp_transport: _proxy_list_tools, _proxy_list_resources, _proxy_read_resource - tool_service: invoke_tool_direct - resource_service: direct_proxy resource read - prompt_service: _fetch_gateway_prompt_result
1 parent b878ae2 commit 1e47984

File tree

5 files changed

+101
-21
lines changed

5 files changed

+101
-21
lines changed

mcpgateway/services/prompt_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from mcpgateway.services.structured_logger import get_structured_logger
5959
from mcpgateway.services.team_management_service import TeamManagementService
6060
from mcpgateway.utils.create_slug import slugify
61-
from mcpgateway.utils.gateway_access import build_gateway_auth_headers
61+
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, resolve_gateway_auth_headers
6262
from mcpgateway.utils.metrics_common import build_top_performers
6363
from mcpgateway.utils.pagination import unified_paginate
6464
from mcpgateway.utils.services_auth import decode_auth
@@ -320,7 +320,10 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option
320320
raise PromptError(f"Prompt '{prompt.name}' is gateway-backed but missing gateway metadata")
321321

322322
gateway_url = str(gateway.url)
323-
headers = build_gateway_auth_headers(gateway)
323+
# Resolve per-user credentials (falls back to gateway defaults)
324+
from mcpgateway.db import SessionLocal # pylint: disable=import-outside-toplevel
325+
with SessionLocal() as db:
326+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_identity, db=db)
324327
auth_query_params_decrypted: Optional[Dict[str, str]] = None
325328

326329
if getattr(gateway, "auth_type", None) == "query_param" and getattr(gateway, "auth_query_params", None):

mcpgateway/services/resource_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from mcpgateway.services.oauth_manager import OAuthManager
7272
from mcpgateway.services.observability_service import current_trace_id, ObservabilityService
7373
from mcpgateway.services.structured_logger import get_structured_logger
74-
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access
74+
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, resolve_gateway_auth_headers
7575
from mcpgateway.utils.metrics_common import build_top_performers
7676
from mcpgateway.utils.pagination import unified_paginate
7777
from mcpgateway.utils.services_auth import decode_auth
@@ -2306,8 +2306,8 @@ async def read_resource(
23062306

23072307
gateway = resource_db.gateway
23082308

2309-
# Prepare headers with gateway auth
2310-
headers = build_gateway_auth_headers(gateway)
2309+
# Prepare headers with per-user credentials (falls back to gateway defaults)
2310+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user, db=db)
23112311

23122312
# Use MCP SDK to connect and read resource
23132313
async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):

mcpgateway/services/tool_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
from mcpgateway.utils.correlation_id import get_correlation_id
8989
from mcpgateway.utils.create_slug import slugify
9090
from mcpgateway.utils.display_name import generate_display_name
91-
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers
91+
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, resolve_gateway_auth_headers
9292
from mcpgateway.utils.metrics_common import build_top_performers
9393
from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate
9494
from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
@@ -2985,8 +2985,8 @@ async def invoke_tool_direct(
29852985
if not await check_gateway_access(db, gateway, user_email, token_teams):
29862986
raise ToolNotFoundError(f"Tool not found: {name}")
29872987

2988-
# Prepare headers with gateway auth
2989-
headers = build_gateway_auth_headers(gateway)
2988+
# Prepare headers with per-user credentials (falls back to gateway defaults)
2989+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_email, db=db)
29902990

29912991
# Forward passthrough headers if configured
29922992
if gateway.passthrough_headers and request_headers:

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
from mcpgateway.services.resource_service import ResourceService
7979
from mcpgateway.services.tool_service import ToolService
8080
from mcpgateway.transports.redis_event_store import RedisEventStore
81-
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, GATEWAY_ID_HEADER
81+
from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, GATEWAY_ID_HEADER, resolve_gateway_auth_headers
8282
from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify
8383
from mcpgateway.utils.log_sanitizer import sanitize_for_log
8484
from mcpgateway.utils.orjson_response import ORJSONResponse
@@ -1157,21 +1157,23 @@ async def _validate_streamable_session_access(
11571157
return False, HTTP_403_FORBIDDEN, "Session owner metadata unavailable"
11581158

11591159

1160-
async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Tool]: # pylint: disable=unused-argument
1160+
async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Tool]:
11611161
"""Proxy tools/list request directly to remote MCP gateway using MCP SDK.
11621162
11631163
Args:
11641164
gateway: Gateway ORM instance
11651165
request_headers: Request headers from client
1166-
user_context: User context (not used - _meta comes from MCP SDK)
1166+
user_context: User context dict with email for per-user credential lookup
11671167
meta: Request metadata (_meta) from the original request
11681168
11691169
Returns:
11701170
List of Tool objects from remote server
11711171
"""
11721172
try:
1173-
# Prepare headers with gateway auth
1174-
headers = build_gateway_auth_headers(gateway)
1173+
# Prepare headers with per-user credentials (falls back to gateway defaults)
1174+
user_email = user_context.get("email") if user_context else None
1175+
with SessionLocal() as db:
1176+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_email, db=db)
11751177

11761178
# Forward passthrough headers using shared utility (includes X-Upstream-Authorization rename)
11771179
if request_headers:
@@ -1209,21 +1211,23 @@ async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user
12091211
return []
12101212

12111213

1212-
async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Resource]: # pylint: disable=unused-argument
1214+
async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Resource]:
12131215
"""Proxy resources/list request directly to remote MCP gateway using MCP SDK.
12141216
12151217
Args:
12161218
gateway: Gateway ORM instance
12171219
request_headers: Request headers from client
1218-
user_context: User context (not used - _meta comes from MCP SDK)
1220+
user_context: User context dict with email for per-user credential lookup
12191221
meta: Request metadata (_meta) from the original request
12201222
12211223
Returns:
12221224
List of Resource objects from remote server
12231225
"""
12241226
try:
1225-
# Prepare headers with gateway auth
1226-
headers = build_gateway_auth_headers(gateway)
1227+
# Prepare headers with per-user credentials (falls back to gateway defaults)
1228+
user_email = user_context.get("email") if user_context else None
1229+
with SessionLocal() as db:
1230+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_email, db=db)
12271231

12281232
# Forward passthrough headers using shared utility (includes X-Upstream-Authorization rename)
12291233
if request_headers:
@@ -1267,21 +1271,23 @@ async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict,
12671271
return []
12681272

12691273

1270-
async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_context: dict, meta: Optional[Any] = None) -> List[Any]: # pylint: disable=unused-argument
1274+
async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_context: dict, meta: Optional[Any] = None) -> List[Any]:
12711275
"""Proxy resources/read request directly to remote MCP gateway using MCP SDK.
12721276
12731277
Args:
12741278
gateway: Gateway ORM instance
12751279
resource_uri: URI of the resource to read
1276-
user_context: User context (not used - auth comes from gateway config)
1280+
user_context: User context dict with email for per-user credential lookup
12771281
meta: Request metadata (_meta) from the original request
12781282
12791283
Returns:
12801284
List of content objects (TextResourceContents or BlobResourceContents) from remote server
12811285
"""
12821286
try:
1283-
# Prepare headers with gateway auth
1284-
headers = build_gateway_auth_headers(gateway)
1287+
# Prepare headers with per-user credentials (falls back to gateway defaults)
1288+
user_email = user_context.get("email") if user_context else None
1289+
with SessionLocal() as db:
1290+
headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_email, db=db)
12851291

12861292
# Get request headers
12871293
request_headers = request_headers_var.get()

mcpgateway/utils/gateway_access.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
# Standard
13+
import logging
1314
from typing import Dict, List, Optional
1415

1516
# Third-Party
@@ -19,6 +20,8 @@
1920
from mcpgateway.db import Gateway as DbGateway
2021
from mcpgateway.utils.services_auth import decode_auth
2122

23+
logger = logging.getLogger(__name__)
24+
2225
# Header name used by clients to target a specific gateway for direct_proxy mode.
2326
# Defined once here to avoid string literal repetition across the codebase.
2427
GATEWAY_ID_HEADER = "X-Context-Forge-Gateway-Id"
@@ -159,3 +162,71 @@ def build_gateway_auth_headers(gateway: DbGateway) -> Dict[str, str]:
159162
headers["Authorization"] = auth_header
160163

161164
return headers
165+
166+
167+
async def resolve_gateway_auth_headers(
168+
gateway: DbGateway,
169+
app_user_email: Optional[str] = None,
170+
db: Optional[Session] = None,
171+
) -> Dict[str, str]:
172+
"""Resolve auth headers for gateway requests, preferring per-user credentials.
173+
174+
Security-critical: ensures per-user credential isolation. When a user has
175+
stored personal credentials (API key, bearer token, basic auth) or OAuth
176+
tokens for a gateway, those MUST be used instead of the gateway defaults.
177+
178+
Lookup order:
179+
1. Per-user personal credential (UserGatewayCredential table)
180+
2. Per-user OAuth token (OAuthToken table)
181+
3. Gateway-level default credentials (gateway.auth_value)
182+
183+
Args:
184+
gateway: Gateway ORM object with auth configuration.
185+
app_user_email: Email of the requesting user (None skips per-user lookup).
186+
db: Database session for per-user credential lookup.
187+
188+
Returns:
189+
Dictionary of HTTP headers with Authorization header.
190+
"""
191+
if app_user_email and db:
192+
# 1. Check per-user personal credentials (API keys, bearer tokens, basic auth)
193+
try:
194+
from mcpgateway.services.credential_storage_service import CredentialStorageService # pylint: disable=import-outside-toplevel
195+
196+
cred_service = CredentialStorageService(db)
197+
record = await cred_service.get_credential_record(gateway.id, app_user_email)
198+
if record:
199+
decrypted = await cred_service.get_credential(gateway.id, app_user_email)
200+
if decrypted:
201+
if record.credential_type in ("bearer_token", "api_key"):
202+
logger.debug(
203+
"Using per-user %s credential for gateway %s, user %s",
204+
record.credential_type, gateway.id, app_user_email,
205+
)
206+
return {"Authorization": f"Bearer {decrypted}"}
207+
elif record.credential_type == "basic_auth":
208+
logger.debug(
209+
"Using per-user basic_auth credential for gateway %s, user %s",
210+
gateway.id, app_user_email,
211+
)
212+
return {"Authorization": f"Basic {decrypted}"}
213+
except Exception:
214+
logger.exception("Error looking up per-user credential for gateway %s", gateway.id)
215+
216+
# 2. Check per-user OAuth tokens
217+
try:
218+
from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
219+
220+
token_service = TokenStorageService(db)
221+
oauth_token = await token_service.get_user_token(gateway.id, app_user_email)
222+
if oauth_token:
223+
logger.debug(
224+
"Using per-user OAuth token for gateway %s, user %s",
225+
gateway.id, app_user_email,
226+
)
227+
return {"Authorization": f"Bearer {oauth_token}"}
228+
except Exception:
229+
logger.exception("Error looking up per-user OAuth token for gateway %s", gateway.id)
230+
231+
# 3. Fall back to gateway-level default credentials
232+
return build_gateway_auth_headers(gateway)

0 commit comments

Comments
 (0)