Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions mcpgateway/cache/tool_call_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/cache/tool_call_registry.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

Tool call to session mapping registry.

Tracks which client session initiated each tool call for proper
elicitation routing in multi-user deployments. This enables the gateway
to route elicitation requests from upstream MCP servers back to the
specific client that initiated the tool call.

Per MCP specification 2025-11-25, elicitation requests must be routed
to the correct client session to maintain security boundaries and prevent
cross-user information leakage.
"""

# Standard
import asyncio
import time
from typing import Dict, Optional, Tuple

# First-Party
from mcpgateway.services.logging_service import LoggingService

# Initialize the logging service at import time so the module-level `logger`
# is available to ToolCallRegistry and the singleton helpers below.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


class ToolCallRegistry:
    """Registry mapping tool call IDs to originating client sessions.

    This enables proper elicitation routing: when an upstream MCP server
    sends an elicitation request during tool execution, we can route it
    back to the specific client that initiated the tool call.

    The registry maintains mappings with timestamps for automatic cleanup
    of stale entries, preventing memory leaks in long-running deployments.

    Attributes:
        _mappings: Dictionary mapping tool_call_id to (session_id, timestamp)
        _cleanup_interval: How often to run the cleanup task (seconds)
        _stale_threshold: Age beyond which a mapping is considered abandoned (seconds)
        _cleanup_task: Background task for cleaning up stale mappings
    """

    def __init__(self, cleanup_interval: int = 300, stale_threshold: int = 3600):
        """Initialize the registry.

        Args:
            cleanup_interval: How often to clean up stale mappings (seconds)
            stale_threshold: Age in seconds after which a mapping is treated as
                abandoned and removed by the cleanup task. Tool calls should
                complete within minutes, so the 1-hour default is a safe bound.
        """
        self._mappings: Dict[str, Tuple[str, float]] = {}
        self._cleanup_interval = cleanup_interval
        self._stale_threshold = stale_threshold
        self._cleanup_task: Optional[asyncio.Task[None]] = None
        # Lazy %-style args avoid string formatting when the level is disabled.
        logger.info("ToolCallRegistry initialized: cleanup_interval=%ss", cleanup_interval)

    async def start(self):
        """Start the background cleanup task (idempotent).

        A finished or never-started task is replaced, so the registry can be
        restarted after shutdown().
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("ToolCallRegistry cleanup task started")

    async def shutdown(self):
        """Shutdown the registry: cancel the cleanup task and drop all mappings."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Reset so a later start() creates a fresh task cleanly.
            self._cleanup_task = None

        mapping_count = len(self._mappings)
        self._mappings.clear()
        logger.info("ToolCallRegistry shutdown complete (cleared %s mappings)", mapping_count)

    def register_tool_call(self, tool_call_id: str, session_id: str):
        """Register a tool call with its originating session.

        This should be called before invoking a tool to establish the
        mapping for potential elicitation routing.

        Args:
            tool_call_id: Unique identifier for the tool call
            session_id: Client session that initiated the call
        """
        self._mappings[tool_call_id] = (session_id, time.time())
        logger.debug("Registered tool call %s -> session %s", tool_call_id, session_id)

    def get_session_for_tool_call(self, tool_call_id: str) -> Optional[str]:
        """Get the session ID that initiated a tool call.

        Args:
            tool_call_id: The tool call identifier

        Returns:
            Session ID if found, None otherwise
        """
        mapping = self._mappings.get(tool_call_id)
        return mapping[0] if mapping else None

    def unregister_tool_call(self, tool_call_id: str):
        """Remove a tool call mapping after completion.

        This should be called after tool execution completes (success or failure)
        to prevent memory leaks.

        Args:
            tool_call_id: The tool call identifier to remove
        """
        # Explicit None check: the stored value is a tuple, but comparing
        # against None states the intent (pop miss) unambiguously.
        if self._mappings.pop(tool_call_id, None) is not None:
            logger.debug("Unregistered tool call %s", tool_call_id)

    def get_mapping_count(self) -> int:
        """Get count of active tool call mappings.

        Returns:
            Number of currently tracked tool calls
        """
        return len(self._mappings)

    async def _cleanup_loop(self):
        """Background task to periodically clean up stale mappings.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_stale()
            except asyncio.CancelledError:
                logger.info("ToolCallRegistry cleanup loop cancelled")
                raise
            except Exception as e:  # pylint: disable=broad-except
                # Keep the loop alive: a single cleanup failure must not kill
                # the background task for the process lifetime.
                logger.error("Error in tool call registry cleanup loop: %s", e, exc_info=True)

    async def _cleanup_stale(self):
        """Remove mappings older than the configured stale threshold.

        Tool calls should complete within minutes, so the default 1-hour
        threshold safely detects abandoned mappings.
        """
        now = time.time()
        # Collect first, then delete, to avoid mutating the dict mid-iteration.
        stale_ids = [tool_call_id for tool_call_id, (_, timestamp) in self._mappings.items() if now - timestamp > self._stale_threshold]

        for tool_call_id in stale_ids:
            self._mappings.pop(tool_call_id, None)

        if stale_ids:
            logger.info("Cleaned up %s stale tool call mappings", len(stale_ids))


# Global singleton instance: created lazily by get_tool_call_registry() and
# replaceable via set_tool_call_registry() (primarily for tests).
_tool_call_registry: Optional[ToolCallRegistry] = None


def get_tool_call_registry() -> ToolCallRegistry:
    """Return the process-wide ToolCallRegistry singleton.

    The instance is created lazily on first access and reused thereafter.

    Returns:
        The global ToolCallRegistry instance
    """
    global _tool_call_registry  # pylint: disable=global-statement
    registry = _tool_call_registry
    if registry is None:
        registry = ToolCallRegistry()
        _tool_call_registry = registry
    return registry


def set_tool_call_registry(registry: ToolCallRegistry):
    """Replace the global ToolCallRegistry instance.

    Primarily intended for tests that need to inject a mock or
    pre-configured registry in place of the lazily created default.

    Args:
        registry: The ToolCallRegistry instance to use globally
    """
    global _tool_call_registry  # pylint: disable=global-statement
    _tool_call_registry = registry
17 changes: 17 additions & 0 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,15 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
await elicitation_service.start()
logger.info("Elicitation service initialized")

# Initialize tool call registry for elicitation routing
if settings.mcpgateway_elicitation_enabled:
# First-Party
from mcpgateway.cache.tool_call_registry import get_tool_call_registry # pylint: disable=import-outside-toplevel

tool_call_registry = get_tool_call_registry()
await tool_call_registry.start()
logger.info("Tool call registry initialized")

# Initialize metrics buffer service for batching metric writes
if settings.metrics_buffer_enabled:
# First-Party
Expand Down Expand Up @@ -1997,6 +2006,14 @@ async def run_log_aggregation_loop() -> None:
elicitation_service = get_elicitation_service()
services_to_shutdown.insert(5, elicitation_service)

# Add tool call registry if elicitation is enabled
if settings.mcpgateway_elicitation_enabled:
# First-Party
from mcpgateway.cache.tool_call_registry import get_tool_call_registry # pylint: disable=import-outside-toplevel

tool_call_registry = get_tool_call_registry()
services_to_shutdown.insert(6, tool_call_registry)

# Add metrics buffer service if enabled (flush remaining metrics before shutdown)
if settings.metrics_buffer_enabled:
# First-Party
Expand Down
89 changes: 89 additions & 0 deletions mcpgateway/services/elicitation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@

# First-Party
from mcpgateway.common.models import ElicitResult
from mcpgateway.services.metrics import (
elicitation_completed_total,
elicitation_duration_seconds,
elicitation_requests_total,
elicitation_timeout_total,
)
from mcpgateway.services.structured_logger import get_structured_logger

logger = logging.getLogger(__name__)
structured_logger = get_structured_logger("elicitation_service")


@dataclass
Expand Down Expand Up @@ -154,14 +162,95 @@ async def create_elicitation(self, upstream_session_id: str, downstream_session_
self._pending[request_id] = elicitation
logger.info(f"Created elicitation request {request_id}: upstream={upstream_session_id}, downstream={downstream_session_id}, timeout={timeout_val}s")

# Increment metrics
elicitation_requests_total.inc()

# Structured logging: elicitation created
structured_logger.log(
level="INFO",
message=f"Elicitation created: {request_id}",
component="elicitation_service",
metadata={
"event": "elicitation.created",
"request_id": request_id,
"upstream_session": upstream_session_id,
"downstream_session": downstream_session_id,
"message": message,
"timeout": timeout_val,
},
)

# Structured logging: elicitation delivered
# This event is emitted immediately after creation, indicating the request
# has been successfully forwarded to the downstream client session.
# The actual delivery happens through the callback mechanism in tool_service.py
structured_logger.log(
level="INFO",
message=f"Elicitation delivered: {request_id}",
component="elicitation_service",
metadata={
"event": "elicitation.delivered",
"request_id": request_id,
"downstream_session": downstream_session_id,
"delivered_at": time.time(),
},
)

try:
# Wait for response with timeout
result = await asyncio.wait_for(future, timeout=timeout_val)
duration = time.time() - elicitation.created_at

# Record metrics
elicitation_duration_seconds.observe(duration)
elicitation_completed_total.labels(action=result.action).inc()

# Structured logging: elicitation completed
structured_logger.log(
level="INFO",
message=f"Elicitation completed: {request_id}",
component="elicitation_service",
metadata={
"event": "elicitation.completed",
"request_id": request_id,
"action": result.action,
"duration_ms": duration * 1000,
},
)

logger.info(f"Elicitation {request_id} completed: action={result.action}")
return result
except asyncio.TimeoutError:
# Record timeout metric
elicitation_timeout_total.inc()

# Structured logging: elicitation timeout
structured_logger.log(
level="WARNING",
message=f"Elicitation timeout: {request_id}",
component="elicitation_service",
metadata={
"event": "elicitation.timeout",
"request_id": request_id,
"timeout_seconds": timeout_val,
},
)

logger.warning(f"Elicitation {request_id} timed out after {timeout_val}s")
raise
except Exception as e:
# Structured logging: elicitation error
structured_logger.log(
level="ERROR",
message=f"Elicitation error: {request_id}",
component="elicitation_service",
metadata={
"event": "elicitation.error",
"request_id": request_id,
"error": str(e),
},
)
raise
finally:
# Cleanup
self._pending.pop(request_id, None)
Expand Down
27 changes: 26 additions & 1 deletion mcpgateway/services/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

# Third-Party
from fastapi import Depends, Request, Response, status
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, generate_latest, REGISTRY
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, generate_latest, Histogram, REGISTRY
from prometheus_fastapi_instrumentator import Instrumentator

# First-Party
Expand Down Expand Up @@ -118,6 +118,31 @@ def _get_registry_collector(metric_name: str):
["outcome"],
)

# Elicitation Metrics
# Each collector is guarded with _get_registry_collector so a repeated import
# of this module (e.g. under test re-imports) reuses the collector already in
# the default Prometheus REGISTRY instead of raising a duplicate-registration
# error. The original guarded only the histogram; the counters need the same
# treatment for consistency.
elicitation_requests_total = _get_registry_collector("elicitation_requests_total")
if elicitation_requests_total is None:
    elicitation_requests_total = Counter(
        "elicitation_requests_total",
        "Total number of elicitation requests created",
    )

elicitation_completed_total = _get_registry_collector("elicitation_completed_total")
if elicitation_completed_total is None:
    elicitation_completed_total = Counter(
        "elicitation_completed_total",
        "Total number of completed elicitations by action",
        ["action"],  # accept, decline, cancel
    )

elicitation_timeout_total = _get_registry_collector("elicitation_timeout_total")
if elicitation_timeout_total is None:
    elicitation_timeout_total = Counter(
        "elicitation_timeout_total",
        "Total number of elicitation timeouts",
    )

elicitation_duration_seconds = _get_registry_collector("elicitation_duration_seconds")
if elicitation_duration_seconds is None:
    elicitation_duration_seconds = Histogram(
        "elicitation_duration_seconds",
        "Duration of elicitation requests in seconds",
        buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
    )

# OAuth / JWKS access-token verification on oauth_enabled virtual servers.
# Outcome labels:
# success β€” IdP-issued token verified and user context populated
Expand Down
Loading
Loading