diff --git a/mcpgateway/admin_ui/admin.js b/mcpgateway/admin_ui/admin.js index 7fdb825605..738c247df9 100644 --- a/mcpgateway/admin_ui/admin.js +++ b/mcpgateway/admin_ui/admin.js @@ -149,11 +149,14 @@ Admin.handleSubmitWithConfirmation = handleSubmitWithConfirmation; Admin.handleDeleteSubmit = handleDeleteSubmit; // Gateways -import { editGateway, refreshGatewayTools, refreshToolsForSelectedGateways, testGateway, viewGateway } from "./gateways.js"; +import { editGateway, openCredentialModal, refreshGatewayTools, refreshToolsForSelectedGateways, revokeCredential, submitCredential, testGateway, viewGateway } from "./gateways.js"; Admin.editGateway = editGateway; +Admin.openCredentialModal = openCredentialModal; Admin.refreshGatewayTools = refreshGatewayTools; Admin.refreshToolsForSelectedGateways = refreshToolsForSelectedGateways; +Admin.revokeCredential = revokeCredential; +Admin.submitCredential = submitCredential; Admin.testGateway = testGateway; Admin.viewGateway = viewGateway; diff --git a/mcpgateway/admin_ui/gateways.js b/mcpgateway/admin_ui/gateways.js index c8028dcf6e..6d6adf37a5 100644 --- a/mcpgateway/admin_ui/gateways.js +++ b/mcpgateway/admin_ui/gateways.js @@ -1936,3 +1936,106 @@ export const refreshToolsForSelectedGateways = async function(buttonEl) { reloadAssociatedItems(); } } + +// --------------------------------------------------------------------------- +// Personal Credential Management +// --------------------------------------------------------------------------- + +/** + * Open the credential modal for a gateway, checking current credential status. + */ +export const openCredentialModal = async function (gatewayId, gatewayName) { + document.getElementById("credential-gateway-id").value = gatewayId; + document.getElementById("credential-gateway-name").textContent = gatewayName || gatewayId; + document.getElementById("credential-value").value = ""; + document.getElementById("credential-label").value = ""; + document.getElementById("credential-type").value = "api_key"; + const statusEl = document.getElementById("credential-status"); + statusEl.classList.add("hidden"); + + // Check if user already has a credential for this gateway + try { + const res = await fetchWithTimeout(`${window.ROOT_PATH}/credentials/${gatewayId}`, { + headers: { Accept: "application/json" }, + }); + if (res.ok) { + const data = await res.json(); + if (data.has_credential) { + statusEl.innerHTML = `✓ You have a stored ${data.credential_type} credential${data.label ? ` (${data.label})` : ""}. Submitting will replace it.`; + statusEl.classList.remove("hidden"); + if (data.credential_type) { + document.getElementById("credential-type").value = data.credential_type; + } + if (data.label) { + document.getElementById("credential-label").value = data.label; + } + } + } + } catch (_) { + // Silently ignore — modal still opens + } + + openModal("credential-modal"); +}; + +/** + * Submit the credential form to store a personal credential. + */ +export const submitCredential = async function () { + const gatewayId = document.getElementById("credential-gateway-id").value; + const credentialType = document.getElementById("credential-type").value; + const credentialValue = document.getElementById("credential-value").value; + const label = document.getElementById("credential-label").value || null; + + if (!credentialValue) { + showErrorMessage("Credential value is required"); + return; + } + + try { + const res = await fetchWithTimeout(`${window.ROOT_PATH}/credentials/${gatewayId}`, { + method: "POST", + headers: { "Content-Type": "application/json", Accept: "application/json" }, + body: JSON.stringify({ + credential_type: credentialType, + credential_value: credentialValue, + label: label, + }), + }); + const data = await res.json(); + if (res.ok && data.success) { + showSuccessMessage("Personal credential stored successfully"); + closeModal("credential-modal"); + } else { + showErrorMessage(data.detail || data.message || "Failed to store credential"); + } + } catch (err) { + showErrorMessage(`Failed to store credential: ${err.message}`); + } +}; + +/** + * Revoke the stored credential for the current gateway. + */ +export const revokeCredential = async function () { + const gatewayId = document.getElementById("credential-gateway-id").value; + if (!confirm("Are you sure you want to revoke your personal credential for this gateway?")) { + return; + } + + try { + const res = await fetchWithTimeout(`${window.ROOT_PATH}/credentials/${gatewayId}`, { + method: "DELETE", + headers: { Accept: "application/json" }, + }); + const data = await res.json(); + if (res.ok && data.success) { + showSuccessMessage("Personal credential revoked"); + closeModal("credential-modal"); + } else { + showErrorMessage(data.message || "No credential found to revoke"); + } + } catch (err) { + showErrorMessage(`Failed to revoke credential: ${err.message}`); + } +}; diff --git a/mcpgateway/alembic/versions/a1b2c3d4e5f6_add_user_gateway_credentials_table.py b/mcpgateway/alembic/versions/a1b2c3d4e5f6_add_user_gateway_credentials_table.py new file mode 100644 index 0000000000..85847b5447 --- /dev/null +++ b/mcpgateway/alembic/versions/a1b2c3d4e5f6_add_user_gateway_credentials_table.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +"""Add user_gateway_credentials table for per-user personal credentials + +Revision ID: a1b2c3d4e5f6 +Revises: z1a2b3c4d5e6 +Create Date: 2026-04-02 10:00:00.000000 + +""" + +# Third-Party +from alembic import op +import sqlalchemy as sa +from sqlalchemy import inspect + +# revision identifiers, used by Alembic. +revision = "a1b2c3d4e5f6" +down_revision = "z1a2b3c4d5e6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = inspect(conn) + existing_tables = inspector.get_table_names() + + if "user_gateway_credentials" not in existing_tables: + op.create_table( + "user_gateway_credentials", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("gateway_id", sa.String(36), sa.ForeignKey("gateways.id", ondelete="CASCADE"), nullable=False), + sa.Column("app_user_email", sa.String(255), sa.ForeignKey("email_users.email", ondelete="CASCADE"), nullable=False), + sa.Column("credential_type", sa.String(50), nullable=False), + sa.Column("credential_value", sa.Text(), nullable=False), + sa.Column("label", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now()), + sa.UniqueConstraint("gateway_id", "app_user_email", name="uq_credential_gateway_user"), + ) + + existing_indexes = {idx["name"] for idx in inspector.get_indexes("user_gateway_credentials")} + if "idx_user_credentials_gateway" not in existing_indexes: + op.create_index("idx_user_credentials_gateway", "user_gateway_credentials", ["gateway_id"]) + if "idx_user_credentials_email" not in existing_indexes: + op.create_index("idx_user_credentials_email", "user_gateway_credentials", ["app_user_email"]) + + +def downgrade() -> None: + op.drop_index("idx_user_credentials_email", table_name="user_gateway_credentials") + op.drop_index("idx_user_credentials_gateway", table_name="user_gateway_credentials") + op.drop_table("user_gateway_credentials") diff --git a/mcpgateway/db.py b/mcpgateway/db.py index f3dac09f5e..7024ec101c 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -4671,6 +4671,9 @@ def team(self) -> Optional[str]: # Relationship with OAuth tokens oauth_tokens: Mapped[List["OAuthToken"]] = relationship("OAuthToken", back_populates="gateway", cascade="all, delete-orphan") + # Relationship with per-user personal credentials + user_credentials: Mapped[List["UserGatewayCredential"]] = relationship("UserGatewayCredential", back_populates="gateway", cascade="all, delete-orphan") + # Relationship with registered OAuth clients (DCR) registered_oauth_clients: Mapped[List["RegisteredOAuthClient"]] = relationship("RegisteredOAuthClient", back_populates="gateway", cascade="all, delete-orphan") @@ -5208,6 +5211,33 @@ class OAuthToken(Base): __table_args__ = (UniqueConstraint("gateway_id", "app_user_email", name="uq_oauth_gateway_user"),) +class UserGatewayCredential(Base): + """ORM model for per-user personal credentials (API keys, PATs, basic auth) for gateways. + + Unlike OAuthToken which stores tokens obtained via OAuth flows, this model stores + credentials that users manually provide for gateways where OAuth is not supported + (e.g., API keys, personal access tokens, basic auth credentials). + """ + + __tablename__ = "user_gateway_credentials" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + gateway_id: Mapped[str] = mapped_column(String(36), ForeignKey("gateways.id", ondelete="CASCADE"), nullable=False) + app_user_email: Mapped[str] = mapped_column(String(255), ForeignKey("email_users.email", ondelete="CASCADE"), nullable=False) + credential_type: Mapped[str] = mapped_column(String(50), nullable=False) # "api_key", "bearer_token", "basic_auth" + credential_value: Mapped[str] = mapped_column(EncryptedText(), nullable=False) + label: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) + + # Relationships + gateway: Mapped["Gateway"] = relationship("Gateway", back_populates="user_credentials") + app_user: Mapped["EmailUser"] = relationship("EmailUser", foreign_keys=[app_user_email]) + + # Unique constraint: one credential per user per gateway + __table_args__ = (UniqueConstraint("gateway_id", "app_user_email", name="uq_credential_gateway_user"),) + + class OAuthState(Base): """ORM model for OAuth authorization states with TTL for CSRF protection.""" diff --git a/mcpgateway/main.py b/mcpgateway/main.py index a07b60482b..f7cec96ee6 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -11837,6 +11837,16 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr except ImportError: logger.debug("OAuth router not available") +# Include personal credential router +try: + # First-Party + from mcpgateway.routers.credential_router import credential_router + + app.include_router(credential_router) + logger.info("Credential router included") +except ImportError: + logger.debug("Credential router not available") + # Include reverse proxy router if enabled if settings.mcpgateway_reverse_proxy_enabled: try: diff --git a/mcpgateway/routers/credential_router.py b/mcpgateway/routers/credential_router.py new file mode 100644 index 0000000000..f177594e34 --- /dev/null +++ b/mcpgateway/routers/credential_router.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/credential_router.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Personal Credential Router for ContextForge. + +This module provides REST endpoints for users to manage their personal +credentials (API keys, PATs, basic auth) for gateways where OAuth is not +supported or not sufficient for all endpoints. + +Endpoints: +- POST /credentials/{gateway_id} — store a personal credential +- GET /credentials/{gateway_id} — get credential status for current user +- DELETE /credentials/{gateway_id} — revoke stored credential +- GET /credentials — list all stored credentials for current user +""" + +# Standard +import logging +from typing import Any, Dict, List, Optional + +# Third-Party +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.common.validators import SecurityValidator +from mcpgateway.db import Gateway, get_db +from mcpgateway.middleware.rbac import get_current_user_with_permissions +from mcpgateway.schemas import EmailUserResponse +from mcpgateway.services.credential_storage_service import VALID_CREDENTIAL_TYPES, CredentialStorageService + +logger = logging.getLogger(__name__) + +credential_router = APIRouter(prefix="/credentials", tags=["credentials"]) + + +# --------------------------------------------------------------------------- +# Helpers (reused patterns from oauth_router) +# --------------------------------------------------------------------------- + + +def _extract_user_email(current_user: EmailUserResponse | dict) -> str | None: + """Extract requester email from typed or dict user contexts.""" + if hasattr(current_user, "email"): + email = getattr(current_user, "email", None) + if isinstance(email, str) and email.strip(): + return email.strip().lower() + if isinstance(current_user, dict): + email = current_user.get("email") or current_user.get("user", {}).get("email") + if isinstance(email, str) and email.strip(): + return email.strip().lower() + return None + + +def _extract_is_admin(current_user: EmailUserResponse | dict) -> bool: + """Extract admin flag from typed or dict user contexts.""" + if hasattr(current_user, "is_admin"): + return bool(getattr(current_user, "is_admin", False)) + if isinstance(current_user, dict): + return bool(current_user.get("is_admin", False) or current_user.get("user", {}).get("is_admin", False)) + return False + + +async def _enforce_gateway_access(gateway_id: str, gateway: Gateway, current_user: EmailUserResponse, db: Session) -> None: + """Enforce gateway visibility and ownership checks.""" + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + requester_is_admin = _extract_is_admin(current_user) + if requester_is_admin: + return + + visibility = str(getattr(gateway, "visibility", "team") or "team").lower() + gateway_owner = getattr(gateway, "owner_email", None) + gateway_team_id = getattr(gateway, "team_id", None) + + if visibility == "public": + return + + if visibility == "team": + if not gateway_team_id: + raise HTTPException(status_code=403, detail="You don't have access to this gateway") + from mcpgateway.services.email_auth_service import EmailAuthService + auth_service = EmailAuthService(db) + user = await auth_service.get_user_by_email(requester_email) + if not user or not user.is_team_member(gateway_team_id): + raise HTTPException(status_code=403, detail="You don't have access to this gateway") + return + + if visibility in {"private", "user"}: + if gateway_owner and gateway_owner.strip().lower() == requester_email: + return + raise HTTPException(status_code=403, detail="You don't have access to this gateway") + + if gateway_owner and gateway_owner.strip().lower() == requester_email: + return + if gateway_team_id: + from mcpgateway.services.email_auth_service import EmailAuthService + auth_service = EmailAuthService(db) + user = await auth_service.get_user_by_email(requester_email) + if user and user.is_team_member(gateway_team_id): + return + + raise HTTPException(status_code=403, detail="You don't have access to this gateway") + + +# --------------------------------------------------------------------------- +# Request / Response Models +# --------------------------------------------------------------------------- + + +class CredentialStoreRequest(BaseModel): + """Request body for storing a personal credential.""" + + credential_type: str = Field( + ..., + description=f"Type of credential. Must be one of: {', '.join(sorted(VALID_CREDENTIAL_TYPES))}", + ) + credential_value: str = Field(..., min_length=1, description="The credential value (API key, token, or username:password)") + label: Optional[str] = Field(None, max_length=255, description="Optional user-friendly label for the credential") + + +class CredentialInfoResponse(BaseModel): + """Response for credential info queries.""" + + gateway_id: str + gateway_name: Optional[str] = None + app_user_email: str + credential_type: str + label: Optional[str] = None + created_at: Optional[str] = None + updated_at: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@credential_router.post("/{gateway_id}") +async def store_credential( + gateway_id: str, + body: CredentialStoreRequest, + current_user: EmailUserResponse = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> Dict[str, Any]: + """Store a personal credential for a gateway. + + Allows authenticated users to store a personal API key, bearer token, or basic auth + credential for a specific gateway. This credential is used instead of the gateway's + shared credentials when the user invokes tools on this gateway. + + The credential is encrypted at rest using the platform encryption service. + """ + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none() + if not gateway: + raise HTTPException(status_code=404, detail="Gateway not found") + + await _enforce_gateway_access(gateway_id, gateway, current_user, db) + + if body.credential_type not in VALID_CREDENTIAL_TYPES: + raise HTTPException( + status_code=400, + detail=f"Invalid credential_type '{body.credential_type}'. Must be one of: {', '.join(sorted(VALID_CREDENTIAL_TYPES))}", + ) + + try: + credential_service = CredentialStorageService(db) + record = await credential_service.store_credential( + gateway_id=gateway_id, + app_user_email=requester_email, + credential_type=body.credential_type, + credential_value=body.credential_value, + label=body.label, + ) + + logger.info( + f"Credential stored via API for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, " + f"user {SecurityValidator.sanitize_log_message(requester_email)}" + ) + + return { + "success": True, + "gateway_id": gateway_id, + "app_user_email": requester_email, + "credential_type": body.credential_type, + "label": body.label, + "message": "Personal credential stored successfully", + } + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error( + f"Failed to store credential for gateway {SecurityValidator.sanitize_log_message(gateway_id)}: {e}" + ) + raise HTTPException(status_code=500, detail=f"Failed to store credential: {str(e)}") + + +@credential_router.get("/{gateway_id}") +async def get_credential_status( + gateway_id: str, + current_user: EmailUserResponse = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> Dict[str, Any]: + """Get credential status for the authenticated user on a gateway. + + Returns metadata about the stored credential without exposing the secret value. + """ + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none() + if not gateway: + raise HTTPException(status_code=404, detail="Gateway not found") + + await _enforce_gateway_access(gateway_id, gateway, current_user, db) + + try: + credential_service = CredentialStorageService(db) + info = await credential_service.get_credential_info(gateway_id, requester_email) + + if not info: + return {"has_credential": False, "gateway_id": gateway_id, "message": "No personal credential stored for this gateway"} + + return { + "has_credential": True, + "gateway_id": gateway_id, + **info, + } + + except Exception as e: + logger.error( + f"Failed to get credential status for gateway {SecurityValidator.sanitize_log_message(gateway_id)}: {e}" + ) + raise HTTPException(status_code=500, detail=f"Failed to get credential status: {str(e)}") + + +@credential_router.delete("/{gateway_id}") +async def revoke_credential( + gateway_id: str, + current_user: EmailUserResponse = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> Dict[str, Any]: + """Revoke stored personal credential for the authenticated user on a gateway.""" + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + gateway = db.execute(select(Gateway).where(Gateway.id == gateway_id)).scalar_one_or_none() + if not gateway: + raise HTTPException(status_code=404, detail="Gateway not found") + + await _enforce_gateway_access(gateway_id, gateway, current_user, db) + + try: + credential_service = CredentialStorageService(db) + revoked = await credential_service.revoke_credential(gateway_id, requester_email) + + if revoked: + logger.info( + f"Credential revoked via API for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, " + f"user {SecurityValidator.sanitize_log_message(requester_email)}" + ) + return {"success": True, "gateway_id": gateway_id, "message": "Personal credential revoked successfully"} + + return {"success": False, "gateway_id": gateway_id, "message": "No personal credential found for this gateway"} + + except Exception as e: + logger.error( + f"Failed to revoke credential for gateway {SecurityValidator.sanitize_log_message(gateway_id)}: {e}" + ) + raise HTTPException(status_code=500, detail=f"Failed to revoke credential: {str(e)}") + + +@credential_router.get("") +async def list_credentials( + current_user: EmailUserResponse = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> Dict[str, Any]: + """List all gateways where the authenticated user has stored personal credentials. + + Returns metadata about each credential without exposing secret values. + """ + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + try: + credential_service = CredentialStorageService(db) + credentials = await credential_service.list_user_credentials(requester_email) + + # Enrich with gateway names + for cred in credentials: + gateway = db.execute(select(Gateway).where(Gateway.id == cred["gateway_id"])).scalar_one_or_none() + cred["gateway_name"] = gateway.name if gateway else cred["gateway_id"] + + return { + "credentials": credentials, + "count": len(credentials), + } + + except Exception as e: + logger.error(f"Failed to list credentials: {e}") + raise HTTPException(status_code=500, detail=f"Failed to list credentials: {str(e)}") diff --git a/mcpgateway/services/credential_storage_service.py b/mcpgateway/services/credential_storage_service.py new file mode 100644 index 0000000000..da269be545 --- /dev/null +++ b/mcpgateway/services/credential_storage_service.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/credential_storage_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Personal Credential Storage Service for ContextForge. + +This module handles the storage, retrieval, and management of per-user personal +credentials (API keys, bearer tokens, basic auth) for gateways where OAuth is not +supported or not sufficient. +""" + +# Standard +import base64 +import logging +from typing import Any, Dict, List, Optional + +# Third-Party +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.common.validators import SecurityValidator +from mcpgateway.config import get_settings +from mcpgateway.db import UserGatewayCredential +from mcpgateway.services.encryption_service import get_encryption_service + +logger = logging.getLogger(__name__) + +# Valid credential types +VALID_CREDENTIAL_TYPES = {"api_key", "bearer_token", "basic_auth"} + + +class CredentialStorageService: + """Manages per-user personal credential storage and retrieval for gateways.""" + + def __init__(self, db: Session): + self.db = db + try: + settings = get_settings() + self.encryption = get_encryption_service(settings.auth_encryption_secret) + except (ImportError, AttributeError): + logger.warning("Encryption not available for credential storage, using plain text") + self.encryption = None + + async def store_credential( + self, + gateway_id: str, + app_user_email: str, + credential_type: str, + credential_value: str, + label: Optional[str] = None, + ) -> UserGatewayCredential: + """Store or update a personal credential for a gateway-user combination. + + Args: + gateway_id: ID of the gateway + app_user_email: ContextForge user email + credential_type: Type of credential ("api_key", "bearer_token", "basic_auth") + credential_value: The secret credential value + label: Optional user-friendly label + + Returns: + UserGatewayCredential record + + Raises: + ValueError: If credential_type is invalid + Exception: If storage fails + """ + if credential_type not in VALID_CREDENTIAL_TYPES: + raise ValueError(f"Invalid credential_type '{credential_type}'. Must be one of: {VALID_CREDENTIAL_TYPES}") + + try: + encrypted_value = credential_value + if self.encryption: + encrypted_value = await self.encryption.encrypt_secret_async(credential_value) + + record = self.db.execute( + select(UserGatewayCredential).where( + UserGatewayCredential.gateway_id == gateway_id, + UserGatewayCredential.app_user_email == app_user_email, + ) + ).scalar_one_or_none() + + if record: + record.credential_type = credential_type + record.credential_value = encrypted_value + record.label = label + logger.info( + f"Updated credential for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, " + f"user {SecurityValidator.sanitize_log_message(app_user_email)}" + ) + else: + record = UserGatewayCredential( + gateway_id=gateway_id, + app_user_email=app_user_email, + credential_type=credential_type, + credential_value=encrypted_value, + label=label, + ) + self.db.add(record) + logger.info( + f"Stored new credential for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, " + f"user {SecurityValidator.sanitize_log_message(app_user_email)}" + ) + + self.db.commit() + return record + + except ValueError: + raise + except Exception as e: + self.db.rollback() + logger.error(f"Failed to store credential: {str(e)}") + raise + + async def get_credential(self, gateway_id: str, app_user_email: str) -> Optional[str]: + """Get a decrypted credential value for a specific user and gateway. + + Args: + gateway_id: ID of the gateway + app_user_email: ContextForge user email + + Returns: + Decrypted credential value or None if not found + """ + try: + record = self.db.execute( + select(UserGatewayCredential).where( + UserGatewayCredential.gateway_id == gateway_id, + UserGatewayCredential.app_user_email == app_user_email, + ) + ).scalar_one_or_none() + + if not record: + return None + + if self.encryption: + return await self.encryption.decrypt_secret_async(record.credential_value) + return record.credential_value + + except Exception as e: + logger.error(f"Failed to retrieve credential: {str(e)}") + return None + + async def get_credential_record(self, gateway_id: str, app_user_email: str) -> Optional[UserGatewayCredential]: + """Get the full credential record (without decrypting the value). + + Args: + gateway_id: ID of the gateway + app_user_email: ContextForge user email + + Returns: + UserGatewayCredential record or None + """ + try: + return self.db.execute( + select(UserGatewayCredential).where( + UserGatewayCredential.gateway_id == gateway_id, + UserGatewayCredential.app_user_email == app_user_email, + ) + ).scalar_one_or_none() + except Exception as e: + logger.error(f"Failed to get credential record: {str(e)}") + return None + + async def get_credential_info(self, gateway_id: str, app_user_email: str) -> Optional[Dict[str, Any]]: + """Get credential metadata without the secret value. + + Args: + gateway_id: ID of the gateway + app_user_email: ContextForge user email + + Returns: + Dict with credential info or None + """ + record = await self.get_credential_record(gateway_id, app_user_email) + if not record: + return None + + return { + "gateway_id": record.gateway_id, + "app_user_email": record.app_user_email, + "credential_type": record.credential_type, + "label": record.label, + "created_at": record.created_at.isoformat() if record.created_at else None, + "updated_at": record.updated_at.isoformat() if record.updated_at else None, + } + + async def revoke_credential(self, gateway_id: str, app_user_email: str) -> bool: + """Delete stored credential for a gateway-user combination. + + Args: + gateway_id: ID of the gateway + app_user_email: ContextForge user email + + Returns: + True if a credential was deleted, False if none existed + """ + try: + result = self.db.execute( + delete(UserGatewayCredential).where( + UserGatewayCredential.gateway_id == gateway_id, + UserGatewayCredential.app_user_email == app_user_email, + ) + ) + self.db.commit() + deleted = result.rowcount > 0 + if deleted: + logger.info( + f"Revoked credential for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, " + f"user {SecurityValidator.sanitize_log_message(app_user_email)}" + ) + return deleted + except Exception as e: + self.db.rollback() + logger.error(f"Failed to revoke credential: {str(e)}") + return False + + async def list_user_credentials(self, app_user_email: str) -> List[Dict[str, Any]]: + """List all gateway credentials for a user (metadata only, no secrets). + + Args: + app_user_email: ContextForge user email + + Returns: + List of credential info dicts + """ + try: + records = ( + self.db.execute( + select(UserGatewayCredential).where( + UserGatewayCredential.app_user_email == app_user_email, + ) + ) + .scalars() + .all() + ) + + return [ + { + "gateway_id": r.gateway_id, + "credential_type": r.credential_type, + "label": r.label, + "created_at": r.created_at.isoformat() if r.created_at else None, + "updated_at": r.updated_at.isoformat() if r.updated_at else None, + } + for r in records + ] + except Exception as e: + logger.error(f"Failed to list credentials: {str(e)}") + return [] + + @staticmethod + def build_auth_headers(credential_type: str, credential_value: str, gateway_auth_type: Optional[str] = None) -> Dict[str, str]: + """Build HTTP Authorization headers from a credential. + + Args: + credential_type: Type of credential ("api_key", "bearer_token", "basic_auth") + credential_value: Decrypted credential value + gateway_auth_type: Gateway's auth_type for context-aware header construction + + Returns: + Dict of HTTP headers to use for authentication + """ + if credential_type == "bearer_token": + return {"Authorization": f"Bearer {credential_value}"} + elif credential_type == "api_key": + # API key sent as Basic auth (common pattern: api_key as username, 'X' as password) + encoded = base64.b64encode(f"{credential_value}:X".encode()).decode() + return {"Authorization": f"Basic {encoded}"} + elif credential_type == "basic_auth": + # credential_value is "username:password" + encoded = base64.b64encode(credential_value.encode()).decode() + return {"Authorization": f"Basic {encoded}"} + else: + logger.warning(f"Unknown credential type: {credential_type}") + return {} diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 88e59d2a79..8e22f07c83 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -25,10 +25,9 @@ # Third-Party from jinja2 import Environment, meta, select_autoescape, Template -from mcp import ClientSession, types +from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.types import GetPromptRequest, GetPromptRequestParams import orjson from pydantic import ValidationError from sqlalchemy import and_, delete, desc, not_, or_, select @@ -37,7 +36,6 @@ # First-Party from mcpgateway.common.models import Message, PromptResult, Role, TextContent -from mcpgateway.common.validators import validate_meta_data as _validate_meta_data from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import EmailTeamMember as DbEmailTeamMember @@ -53,15 +51,14 @@ from mcpgateway.services.content_security import ContentSizeError, get_content_security_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService -from mcpgateway.services.upstream_session_registry import downstream_session_id_from_request_context as _downstream_session_id_from_request -from mcpgateway.services.upstream_session_registry import get_upstream_session_registry, RegistryNotInitializedError, TransportType from mcpgateway.utils.create_slug import slugify -from mcpgateway.utils.gateway_access import build_gateway_auth_headers +from mcpgateway.utils.gateway_access import build_gateway_auth_headers, resolve_gateway_auth_headers from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import unified_paginate from mcpgateway.utils.services_auth import decode_auth @@ -131,55 +128,6 @@ def _get_registry_cache(): metrics_buffer = get_metrics_buffer_service() -def _build_get_prompt_request(name: str, arguments: Optional[Dict[str, str]], meta_data: Dict[str, Any]) -> "types.ClientRequest": - """Build a GetPrompt ClientRequest that carries _meta (CWE-20, CWE-284). - - Using ``by_alias=True`` ensures the Pydantic alias ``_meta`` is the only - key written into the dict so the subsequent ``model_validate`` call - resolves it correctly regardless of ``populate_by_name`` settings. - - ``send_request`` is used instead of ``session.get_prompt()`` because the - MCP SDK helper does not expose a ``_meta`` parameter; this wrapper must be - updated if the SDK later adds that capability. - - Args: - name: The prompt name. - arguments: Optional prompt arguments. - meta_data: Validated metadata dict to inject as ``_meta``. - - Returns: - A :class:`types.ClientRequest` ready to be passed to ``session.send_request``. - """ - _gp_dict = GetPromptRequestParams(name=name, arguments=arguments).model_dump(by_alias=True) - _gp_dict["_meta"] = meta_data - return types.ClientRequest(GetPromptRequest(params=GetPromptRequestParams.model_validate(_gp_dict))) - - -async def _get_prompt_with_meta(session: "ClientSession", name: str, arguments: Optional[Dict[str, str]], meta_data: Optional[Dict[str, Any]]) -> Any: - """Dispatch a get_prompt call, injecting ``_meta`` when meta_data is provided. - - Eliminates the repeated ``if meta_data: send_request … else: get_prompt`` - pattern across every transport/pool branch in this module. - - Args: - session: An active MCP :class:`ClientSession`. - name: The prompt name. - arguments: Optional prompt-rendering arguments. - meta_data: Optional validated metadata dict. When ``None`` the standard - SDK helper is used; when non-empty the low-level ``send_request`` - path is taken to carry ``_meta``. - - Returns: - The raw MCP result object (caller extracts ``.messages``). - """ - if meta_data: - return await session.send_request( - _build_get_prompt_request(name, arguments, meta_data), - types.GetPromptResult, - ) - return await session.get_prompt(name, arguments=arguments) - - class PromptError(Exception): """Base class for prompt-related errors.""" @@ -353,13 +301,13 @@ def _should_fetch_gateway_prompt(prompt: DbPrompt) -> bool: """ return bool(getattr(prompt, "gateway_id", None)) and not bool(getattr(prompt, "template", "")) - async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Optional[Dict[str, str]], meta_data: Optional[Dict[str, Any]] = None) -> PromptResult: + async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Optional[Dict[str, str]], user_identity: Optional[str]) -> PromptResult: """Fetch a rendered prompt from the upstream MCP gateway. Args: prompt: Gateway-backed prompt record from the catalog. arguments: Optional prompt-rendering arguments. - meta_data: Optional metadata dict forwarded as ``_meta`` in the upstream MCP request. + user_identity: Effective requester email for session-pool isolation. Returns: Prompt result normalized into ContextForge models. @@ -372,7 +320,10 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option raise PromptError(f"Prompt '{prompt.name}' is gateway-backed but missing gateway metadata") gateway_url = str(gateway.url) - headers = build_gateway_auth_headers(gateway) + # Resolve per-user credentials (falls back to gateway defaults) + from mcpgateway.db import SessionLocal # pylint: disable=import-outside-toplevel + with SessionLocal() as db: + headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_identity, db=db) auth_query_params_decrypted: Optional[Dict[str, str]] = None if getattr(gateway, "auth_type", None) == "query_param" and getattr(gateway, "auth_query_params", None): @@ -387,32 +338,27 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted) remote_name = getattr(prompt, "original_name", None) or prompt.name + pool_user_identity = (user_identity or "anonymous").strip() or "anonymous" gateway_id = str(getattr(gateway, "id", "")) transport = str(getattr(gateway, "transport", "streamable_http") or "streamable_http").lower() - registry_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP + pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP prompt_arguments = arguments or None - # CWE-400: Validate meta_data limits before forwarding to upstream - _validate_meta_data(meta_data) try: - # #4205: Use the upstream session registry when a downstream Mcp-Session-Id - # is in scope; this binds the upstream session 1:1 to the downstream - # session and preserves connection reuse across its tool/prompt calls. - downstream_session_id = _downstream_session_id_from_request() - if downstream_session_id and gateway_id: + if settings.mcp_session_pool_enabled: try: - registry = get_upstream_session_registry() - except RegistryNotInitializedError: - registry = None - if registry is not None: - async with registry.acquire( - downstream_session_id=downstream_session_id, - gateway_id=gateway_id, + pool = get_mcp_session_pool() + except RuntimeError: + pool = None + if pool is not None: + async with pool.session( url=gateway_url, headers=headers, - transport_type=registry_transport_type, - ) as upstream: - remote_result = await _get_prompt_with_meta(upstream.session, remote_name, prompt_arguments, meta_data) + transport_type=pool_transport_type, + user_identity=pool_user_identity, + gateway_id=gateway_id, + ) as pooled: + remote_result = await pooled.session.get_prompt(remote_name, arguments=prompt_arguments) return PromptResult( messages=[ Message.model_validate(message.model_dump(by_alias=True, exclude_none=True) if hasattr(message, "model_dump") else message) @@ -425,12 +371,12 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option async with sse_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as streams: async with ClientSession(*streams) as session: await session.initialize() - remote_result = await _get_prompt_with_meta(session, remote_name, prompt_arguments, meta_data) + remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments) else: async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - remote_result = await _get_prompt_with_meta(session, remote_name, prompt_arguments, meta_data) + remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments) return PromptResult( messages=[ @@ -1843,7 +1789,7 @@ async def get_prompt( None = unrestricted admin, [] = public-only, [...] = team-scoped. plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. plugin_global_context: Optional global context from middleware for consistency across hooks. - _meta_data: Optional metadata forwarded as _meta to the upstream MCP gateway during prompt retrieval. + _meta_data: Optional metadata for prompt retrieval (not used currently). Returns: Prompt result with rendered messages @@ -2004,7 +1950,7 @@ async def get_prompt( if self._should_fetch_gateway_prompt(prompt): # Release the read transaction before any remote network I/O. db.commit() - result = await self._fetch_gateway_prompt_result(prompt, arguments, meta_data=_meta_data) + result = await self._fetch_gateway_prompt_result(prompt, arguments, user) elif not arguments: result = PromptResult( messages=[ diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 5cb6e9bcdf..be9524040f 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -35,10 +35,9 @@ # Third-Party import httpx -from mcp import ClientSession, types +from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.types import ReadResourceRequest, ReadResourceRequestParams import parse from pydantic import ValidationError from sqlalchemy import and_, delete, desc, not_, or_, select @@ -48,7 +47,6 @@ # First-Party from mcpgateway.common.models import ResourceContent, ResourceContents, ResourceTemplate, TextContent from mcpgateway.common.validators import SecurityValidator -from mcpgateway.common.validators import validate_meta_data as _validate_meta_data from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import EmailTeamMember as DbEmailTeamMember @@ -67,14 +65,13 @@ from mcpgateway.services.content_security import ContentSizeError, ContentTypeError, get_content_security_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.services.structured_logger import get_structured_logger -from mcpgateway.services.upstream_session_registry import downstream_session_id_from_request_context as _downstream_session_id_from_request -from mcpgateway.services.upstream_session_registry import get_upstream_session_registry, RegistryNotInitializedError, TransportType -from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access +from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, resolve_gateway_auth_headers from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import unified_paginate from mcpgateway.utils.services_auth import decode_auth @@ -114,53 +111,6 @@ def _get_registry_cache(): metrics_buffer = get_metrics_buffer_service() -def _build_read_resource_request(uri: Any, meta_data: Dict[str, Any]) -> "types.ClientRequest": - """Build a ReadResource ClientRequest that carries _meta (CWE-20, CWE-284). - - Using ``by_alias=True`` ensures the Pydantic alias ``_meta`` is the only - key written into the dict so the subsequent ``model_validate`` call - resolves it correctly regardless of ``populate_by_name`` settings. - - ``send_request`` is used instead of ``session.read_resource()`` because the - MCP SDK helper does not expose a ``_meta`` parameter; this wrapper must be - updated if the SDK later adds that capability. - - Args: - uri: The resource URI. - meta_data: Validated metadata dict to inject as ``_meta``. - - Returns: - A :class:`types.ClientRequest` ready to be passed to ``session.send_request``. - """ - _rp_dict = ReadResourceRequestParams(uri=uri).model_dump(by_alias=True) - _rp_dict["_meta"] = meta_data - return types.ClientRequest(ReadResourceRequest(params=ReadResourceRequestParams.model_validate(_rp_dict))) - - -async def _read_resource_with_meta(session: "ClientSession", uri: Any, meta_data: Optional[Dict[str, Any]]) -> Any: - """Dispatch a read_resource call, injecting ``_meta`` when meta_data is provided. - - Eliminates the repeated ``if meta_data: send_request … else: read_resource`` - pattern across every transport/pool branch in this module. - - Args: - session: An active MCP :class:`ClientSession`. - uri: The resource URI to read. - meta_data: Optional validated metadata dict. When ``None`` the standard - SDK helper is used; when non-empty the low-level ``send_request`` - path is taken to carry ``_meta``. - - Returns: - The raw MCP result object (caller extracts ``.contents``). - """ - if meta_data: - return await session.send_request( - _build_read_resource_request(uri, meta_data), - types.ReadResourceResult, - ) - return await session.read_resource(uri=uri) - - class ResourceError(Exception): """Base class for resource-related errors.""" @@ -1571,7 +1521,7 @@ async def invoke_resource( # pylint: disable=unused-argument resource_uri: str, resource_template_uri: Optional[str] = None, user_identity: Optional[Union[str, Dict[str, Any]]] = None, - meta_data: Optional[Dict[str, Any]] = None, # Forwarded as _meta in upstream MCP requests + meta_data: Optional[Dict[str, Any]] = None, # Reserved for future MCP SDK support resource_obj: Optional[Any] = None, gateway_obj: Optional[Any] = None, server_id: Optional[str] = None, @@ -1685,10 +1635,6 @@ async def invoke_resource( # pylint: disable=unused-argument 'using template: /template' """ - # CWE-400: Validate meta_data limits before any further processing; invoke_resource is - # a separate entry point that must enforce the same guards as read_resource. - _validate_meta_data(meta_data) - uri = None if resource_uri and resource_template_uri: uri = resource_template_uri @@ -1706,6 +1652,14 @@ async def invoke_resource( # pylint: disable=unused-argument # This is especially important when resource isn't found - we don't want to hold the transaction db.commit() + # Normalize user_identity to string for session pool isolation. + if isinstance(user_identity, dict): + pool_user_identity = user_identity.get("email") or "anonymous" + elif isinstance(user_identity, str): + pool_user_identity = user_identity + else: + pool_user_identity = "anonymous" + oauth_user_email: Optional[str] = None if isinstance(user_identity, dict): user_email_value = user_identity.get("email") @@ -1915,8 +1869,8 @@ async def connect_to_sse_session(server_url: str, uri: str, authentication: Opti ``None`` instead of raising. Note: - When meta_data is provided, the request is built using send_request - with _meta injected into ReadResourceRequestParams. + MCP SDK 1.25.0 read_resource() does not support meta parameter. + When the SDK adds support, meta_data can be added back here. Args: server_url (str): @@ -1943,38 +1897,39 @@ async def connect_to_sse_session(server_url: str, uri: str, authentication: Opti if authentication is None: authentication = {} try: - # #4205: Registry path is taken when the caller has a downstream - # Mcp-Session-Id; upstream state is then bound 1:1 to that - # downstream session and never shared across clients. - downstream_session_id = _downstream_session_id_from_request() - use_registry = bool(downstream_session_id) and bool(gateway_id) - registry = None - if use_registry: + # Use session pool if enabled for 10-20x latency improvement + use_pool = False + pool = None + if settings.mcp_session_pool_enabled: try: - registry = get_upstream_session_registry() - except RegistryNotInitializedError: - use_registry = False - - if use_registry and registry is not None: - async with registry.acquire( - downstream_session_id=downstream_session_id, - gateway_id=gateway_id, + pool = get_mcp_session_pool() + use_pool = True + except RuntimeError: + # Pool not initialized (e.g., in tests), fall back to per-call sessions + pass + + if use_pool and pool is not None: + async with pool.session( url=server_url, headers=authentication, transport_type=TransportType.SSE, httpx_client_factory=_get_httpx_client_factory, - ) as upstream: - resource_response = await _read_resource_with_meta(upstream.session, uri, meta_data) + user_identity=pool_user_identity, + gateway_id=gateway_id, + ) as pooled: + # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter + resource_response = await pooled.session.read_resource(uri=uri) return getattr(getattr(resource_response, "contents")[0], "text") else: - # Fallback: per-call session when no downstream session id is in scope. + # Fallback to per-call sessions when pool disabled or not initialized async with sse_client(url=server_url, headers=authentication, timeout=settings.health_check_timeout, httpx_client_factory=_get_httpx_client_factory) as ( read_stream, write_stream, ): async with ClientSession(read_stream, write_stream) as session: _ = await session.initialize() - resource_response = await _read_resource_with_meta(session, uri, meta_data) + # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter + resource_response = await session.read_resource(uri=uri) return getattr(getattr(resource_response, "contents")[0], "text") except Exception as e: # Sanitize error message to prevent URL secrets from leaking in logs @@ -1996,8 +1951,8 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica of propagating the exception. Note: - When meta_data is provided, the request is built using send_request - with _meta injected into ReadResourceRequestParams. + MCP SDK 1.25.0 read_resource() does not support meta parameter. + When the SDK adds support, meta_data can be added back here. Args: server_url (str): @@ -2023,29 +1978,31 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica if authentication is None: authentication = {} try: - # #4205: see SSE path above; same 1:1 binding rationale. - downstream_session_id = _downstream_session_id_from_request() - use_registry = bool(downstream_session_id) and bool(gateway_id) - registry = None - if use_registry: + # Use session pool if enabled for 10-20x latency improvement + use_pool = False + pool = None + if settings.mcp_session_pool_enabled: try: - registry = get_upstream_session_registry() - except RegistryNotInitializedError: - use_registry = False - - if use_registry and registry is not None: - async with registry.acquire( - downstream_session_id=downstream_session_id, - gateway_id=gateway_id, + pool = get_mcp_session_pool() + use_pool = True + except RuntimeError: + # Pool not initialized (e.g., in tests), fall back to per-call sessions + pass + + if use_pool and pool is not None: + async with pool.session( url=server_url, headers=authentication, transport_type=TransportType.STREAMABLE_HTTP, httpx_client_factory=_get_httpx_client_factory, - ) as upstream: - resource_response = await _read_resource_with_meta(upstream.session, uri, meta_data) + user_identity=pool_user_identity, + gateway_id=gateway_id, + ) as pooled: + # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter + resource_response = await pooled.session.read_resource(uri=uri) return getattr(getattr(resource_response, "contents")[0], "text") else: - # Fallback: per-call session when no downstream session id is in scope. + # Fallback to per-call sessions when pool disabled or not initialized async with streamablehttp_client(url=server_url, headers=authentication, timeout=settings.health_check_timeout, httpx_client_factory=_get_httpx_client_factory) as ( read_stream, write_stream, @@ -2053,7 +2010,8 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica ): async with ClientSession(read_stream, write_stream) as session: _ = await session.initialize() - resource_response = await _read_resource_with_meta(session, uri, meta_data) + # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter + resource_response = await session.read_resource(uri=uri) return getattr(getattr(resource_response, "contents")[0], "text") except Exception as e: # Sanitize error message to prevent URL secrets from leaking in logs @@ -2067,8 +2025,10 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica resource_text = "" if (gateway_transport).lower() == "sse": + # Note: meta_data not passed - MCP SDK 1.25.0 read_resource() doesn't support it resource_text = await connect_to_sse_session(server_url=gateway_url, authentication=headers, uri=uri) else: + # Note: meta_data not passed - MCP SDK 1.25.0 read_resource() doesn't support it resource_text = await connect_to_streamablehttp_server(server_url=gateway_url, authentication=headers, uri=uri) if span and resource_text is not None and is_output_capture_enabled("invoke.resource"): set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"content": resource_text})) @@ -2180,8 +2140,6 @@ async def read_resource( resource_db = None server_scoped = False resource_db_gateway = None # Only set when eager-loaded via Q2's joinedload - # CWE-400: Validate meta_data limits before any further processing - _validate_meta_data(meta_data) content = None uri = resource_uri or "unknown" if resource_id: @@ -2348,15 +2306,16 @@ async def read_resource( gateway = resource_db.gateway - # Prepare headers with gateway auth - headers = build_gateway_auth_headers(gateway) + # Prepare headers with per-user credentials (falls back to gateway defaults) + headers = await resolve_gateway_auth_headers(gateway, app_user_email=user, db=db) # Use MCP SDK to connect and read resource async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - result = await _read_resource_with_meta(session, uri, meta_data) + # Note: MCP SDK read_resource() only accepts uri; _meta is not supported + result = await session.read_resource(uri=uri) # Convert MCP result to MCP-compliant content models # result.contents is a list of TextResourceContents or BlobResourceContents diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 5b2dba3314..03c785fb80 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -91,7 +91,7 @@ from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name -from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers +from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, resolve_gateway_auth_headers from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached @@ -3418,8 +3418,8 @@ async def invoke_tool_direct( if not await check_gateway_access(db, gateway, user_email, token_teams): raise ToolNotFoundError(f"Tool not found: {name}") - # Prepare headers with gateway auth - headers = build_gateway_auth_headers(gateway) + # Prepare headers with per-user credentials (falls back to gateway defaults) + headers = await resolve_gateway_auth_headers(gateway, app_user_email=user_email, db=db) # Forward passthrough headers if configured if gateway.passthrough_headers and request_headers: @@ -3761,7 +3761,28 @@ async def prepare_rust_mcp_tool_execution( if not gateway_url: return {"eligible": False, "fallbackReason": "missing-gateway-url"} - if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: + # Per-user personal credentials always take priority over + # gateway-level OAuth tokens. + user_credential_headers = None + if has_gateway and app_user_email and gateway_id_str: + try: + from mcpgateway.services.credential_storage_service import CredentialStorageService # pylint: disable=import-outside-toplevel + + with fresh_db_session() as cred_db: + cred_service = CredentialStorageService(cred_db) + cred_record = await cred_service.get_credential_record(gateway_id_str, app_user_email) + if cred_record: + cred_value = await cred_service.get_credential(gateway_id_str, app_user_email) + if cred_value: + user_credential_headers = CredentialStorageService.build_auth_headers( + cred_record.credential_type, cred_value, gateway_auth_type + ) + except Exception as e: + logger.debug(f"Failed to check personal credentials for gateway {gateway_name}: {e}") + + if user_credential_headers: + headers = user_credential_headers + elif has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: grant_type = gateway_oauth_config.get("grant_type", "client_credentials") if grant_type == "authorization_code": try: @@ -3789,6 +3810,7 @@ async def prepare_rust_mcp_tool_execution( logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}") raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") else: + # No per-user credentials and no OAuth — fall back to shared gateway auth headers = decode_auth(gateway_auth_value) if gateway_auth_value else {} if request_headers: @@ -4841,7 +4863,28 @@ async def invoke_tool( # Handle OAuth authentication for the gateway (using local variables) # NOTE: Use has_gateway instead of gateway to avoid accessing detached ORM object - if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: + # Per-user personal credentials always take priority over + # gateway-level OAuth tokens. + user_credential_headers = None + if has_gateway and app_user_email and gateway_id_str: + try: + from mcpgateway.services.credential_storage_service import CredentialStorageService # pylint: disable=import-outside-toplevel + + with fresh_db_session() as cred_db: + cred_service = CredentialStorageService(cred_db) + cred_record = await cred_service.get_credential_record(gateway_id_str, app_user_email) + if cred_record: + cred_value = await cred_service.get_credential(gateway_id_str, app_user_email) + if cred_value: + user_credential_headers = CredentialStorageService.build_auth_headers( + cred_record.credential_type, cred_value, gateway_auth_type + ) + except Exception as e: + logger.debug(f"Failed to check personal credentials for gateway {gateway_name}: {e}") + + if user_credential_headers: + headers = user_credential_headers + elif has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: grant_type = gateway_oauth_config.get("grant_type", "client_credentials") if grant_type == "authorization_code": @@ -4877,6 +4920,7 @@ async def invoke_tool( logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}") raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") else: + # No per-user credentials and no OAuth — fall back to shared gateway auth headers = decode_auth(gateway_auth_value) if gateway_auth_value else {} # Use cached passthrough headers (no DB query needed) diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 98a05fd47f..204340785f 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -9712,6 +9712,55 @@

+ + +
🔐 Authorize {% endif %} + + + {% if can_modify %}