Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions mcpgateway/routers/well_known.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# First-Party
from mcpgateway.config import settings
from mcpgateway.db import get_db
from mcpgateway.db import Server as DbServer
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.services.server_service import ServerError, ServerNotFoundError, ServerService
from mcpgateway.utils.log_sanitizer import sanitize_for_log
Expand All @@ -38,6 +39,9 @@
# UUID validation pattern for RFC 9728 endpoint
UUID_PATTERN = re.compile(r"^[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}$", re.IGNORECASE)

# Server name pattern: alphanumeric, hyphens, underscores (prevents path traversal/injection)
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]{0,253}[a-zA-Z0-9]$|^[a-zA-Z0-9]$")

# Well-known URI registry with validation
WELL_KNOWN_REGISTRY = {
"robots.txt": {"content_type": "text/plain", "description": "Robot exclusion standard", "rfc": "RFC 9309"},
Expand Down Expand Up @@ -169,21 +173,27 @@ async def get_oauth_protected_resource_rfc9728(
logger.debug(f"Invalid RFC 9728 path format: {sanitize_for_log(path)}")
raise HTTPException(status_code=404, detail="Invalid resource path format. Expected: /.well-known/oauth-protected-resource/servers/{server_id}/mcp")

server_id = path_parts[1]

# Validate server_id is a valid UUID (prevents path traversal and injection)
if not UUID_PATTERN.match(server_id):
# Sanitize untrusted server_id before logging to prevent log injection
logger.warning(f"Invalid server_id format (not a UUID): {sanitize_for_log(server_id)}")
raise HTTPException(status_code=404, detail="Invalid server_id format. Must be a valid UUID.")
server_id_or_name = path_parts[1]

# Reject paths with extra segments after /mcp (e.g., servers/uuid/mcp/extra)
if len(path_parts) > 3:
# Sanitize untrusted path before logging to prevent log injection
logger.warning(f"RFC 9728 path has unexpected segments: {sanitize_for_log(path)}")
raise HTTPException(status_code=404, detail="Invalid resource path format. Expected: /.well-known/oauth-protected-resource/servers/{server_id}/mcp")

# Build resource URL with /mcp suffix per MCP specification
# Resolve server_id: accept UUID or server name
if UUID_PATTERN.match(server_id_or_name):
server_id = server_id_or_name
elif SERVER_NAME_PATTERN.match(server_id_or_name):
# Resolve server name to UUID via DB lookup
server = db.query(DbServer).filter(DbServer.name == server_id_or_name, DbServer.enabled.is_(True)).first()
if not server:
raise HTTPException(status_code=404, detail="Server not found")
server_id = server.id
else:
logger.warning(f"Invalid server identifier format: {sanitize_for_log(server_id_or_name)}")
raise HTTPException(status_code=404, detail="Invalid server identifier format. Must be a UUID or a valid server name.")

# Build resource URL with canonical UUID (ensures stable resource identifiers per RFC 9728)
base_url = get_base_url_with_protocol(request)
resource_url = f"{base_url}/servers/{server_id}/mcp"

Expand Down
28 changes: 27 additions & 1 deletion mcpgateway/services/server_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,28 @@ def __init__(self) -> None:
self._audit_trail = get_audit_trail_service()
self._performance_tracker = get_performance_tracker()

def resolve_server_id(self, db: Session, server_id_or_name: str) -> Optional[str]:
"""Resolve a server identifier (UUID or name) to a canonical UUID.

Accepts either a hex UUID or a server name. Returns the server's primary
key if found, or None if no matching server exists.

Args:
db: Database session.
server_id_or_name: UUID hex string or server name.

Returns:
The server UUID string, or None if not found.
"""
# Try direct PK lookup first (fast path)
server = db.get(DbServer, server_id_or_name)
if server:
return server.id

# Fallback: lookup by name
server = db.query(DbServer).filter(DbServer.name == server_id_or_name).first()
return server.id if server else None

async def initialize(self) -> None:
"""Initialize the server service."""
logger.info("Initializing server service")
Expand Down Expand Up @@ -1901,7 +1923,11 @@ def get_oauth_protected_resource_metadata(self, db: Session, server_id: str, res
>>> callable(service.get_oauth_protected_resource_metadata)
True
"""
server = db.get(DbServer, server_id)
# Resolve server by UUID or name
resolved_id = self.resolve_server_id(db, server_id)
if not resolved_id:
raise ServerNotFoundError(f"Server not found: {server_id}")
server = db.get(DbServer, resolved_id)

# Return not found for non-existent, disabled, or non-public servers
# (avoids leaking information about private/team servers)
Expand Down
12 changes: 8 additions & 4 deletions mcpgateway/transports/streamablehttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,12 +951,15 @@ async def _check_server_oauth_enforcement(server_id: str, user_context: Optional

try:
async with get_db() as db:
# Support both UUID and name resolution
server = db.execute(select(DbServer).where(DbServer.id == server_id)).scalar_one_or_none()
if not server:
server = db.execute(select(DbServer).where(DbServer.name == server_id)).scalar_one_or_none()
if server and server.oauth_enabled:
logger.warning("OAuth required for server %s but caller is unauthenticated", server_id)
raise OAuthRequiredError(
"This server requires OAuth authentication. Please provide a valid access token.",
server_id=server_id,
server_id=server.id,
)
_oauth_checked_var.set(True)
except SQLAlchemyError as exc:
Expand Down Expand Up @@ -2835,18 +2838,19 @@ async def _validate_server_id(match: "re.Match[str] | None", path: str, scope: S
server_id = match.group("server_id")
# SECURITY: Validate that the server_id exists in the database
# to prevent unauthorized access via invalid server IDs.
# Uses the shared BaseService.entity_exists() for a lightweight
# EXISTS check — no row data is loaded.
# Supports both UUID and server name resolution.
try:
# First-Party
from mcpgateway.services.server_service import server_service as _server_svc # pylint: disable=import-outside-toplevel,no-name-in-module

async with get_db() as db:
if not await _server_svc.entity_exists(db, server_id):
resolved_id = _server_svc.resolve_server_id(db, server_id)
if not resolved_id:
logger.warning("Invalid server ID in MCP request path: %s", server_id)
response = ORJSONResponse({"detail": "Server not found"}, status_code=404)
await response(scope, receive, send)
return _REJECT
server_id = resolved_id
except Exception as e:
logger.error("Failed to validate server ID %s: %s", server_id, e)
response = ORJSONResponse({"detail": "Service unavailable — unable to verify server"}, status_code=503)
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/mcpgateway/routers/test_well_known_rfc9728.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,23 @@ def override_get_db():
app.dependency_overrides.pop(get_db, None)

def test_rfc9728_endpoint_invalid_uuid(self, app):
"""Test RFC 9728 endpoint rejects non-UUID server IDs."""
"""Test RFC 9728 endpoint rejects invalid identifiers and resolves names."""
mock_db = MagicMock()
# Name lookups return no results (server not found)
mock_db.query.return_value.filter.return_value.first.return_value = None

def override_get_db():
yield mock_db

app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)

# Not a valid UUID
# Valid name format but server does not exist => 404 Server not found
response = client.get("/.well-known/oauth-protected-resource/servers/not-a-uuid/mcp")
assert response.status_code == 404
assert "Invalid server_id format" in response.json()["detail"]
assert "Server not found" in response.json()["detail"]

# Path traversal attempt
# Path traversal attempt (invalid name format)
response = client.get("/.well-known/oauth-protected-resource/servers/../admin/mcp")
assert response.status_code == 404

Expand Down