diff --git a/.env.example b/.env.example index 2fcf23a5..c35280d3 100644 --- a/.env.example +++ b/.env.example @@ -218,3 +218,15 @@ INGESTION__CHUNK_SIZE=5000 # Files ignored by the indexer. # INGESTION__GITHUB_IGNORED_FILES={'uv.lock', 'poetry.lock', 'package-lock.json', 'Pipfile.lock', 'yarn.lock'} + +# ============================================================================== +# MCP (MODEL CONTEXT PROTOCOL) SERVERS +# ============================================================================== +# MCP (Model Context Protocol) servers (JSON string) +# Example (stdio): +# MCP__SERVERS='{"filesystem":{"command":"npx","args":["-y","@modelcontextprotocol/server-filesystem","/path"]}}' + +# Example (remote SSE): +# MCP__SERVERS='{"api":{"url":"http://localhost:3000/sse","transport":"sse"}}' + +# Optional per server: headers, timeout, tool_filter, enabled, env diff --git a/README.md b/README.md index b7ce6adf..90e44f36 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ SDK for building **verifiable AI Agents** on Flare using Confidential Space Trus - **Verifiable execution**: Run logic inside Intel TDX TEEs via [GCP Confidential Space](https://cloud.google.com/confidential-computing/confidential-space/docs/confidential-space-overview). - **Multi-agent consensus**: Majority/Tournament/[Consensus Learning](https://arxiv.org/abs/2402.16157) via [Google Agent2Agent](https://github.com/a2aproject/A2A) protocol. -- **Agent framework**: Built on [Google ADK](https://google.github.io/adk-docs/) with tool-calling, orchestration and evaluation. +- **Agent framework**: Built on [Google ADK](https://google.github.io/adk-docs/) with tool-calling, orchestration and evaluation. Supports [MCP](https://modelcontextprotocol.io/) for custom tool integration. - **Flare integration**: [FTSO](https://dev.flare.network/ftso/overview), [FDC](https://dev.flare.network/fdc/overview), [FAssets](https://dev.flare.network/fassets/overview) + ecosystem dApps ([Sceptre](https://sceptre.fi), [SparkDEX](https://sparkdex.ai), ...). - **Social connectors**: X, Telegram, Farcaster. @@ -29,6 +29,7 @@ flowchart TD subgraph AgentFramework["Agent Framework"] B["Google ADK"] B --o LLM["Gemini
GPT
Grok
+200 models"] + B --o MCP["MCP Servers
(Custom Tools)"] end %% VectorRAG Engine subgraph @@ -130,7 +131,7 @@ docker run --rm -it \ fai-script-pdf ``` -Available `EXTRAS`: `pdf`, `rag`, `a2a`, `ftso`, `da`, `fassets`, `social`, `tee`, `wallet`, `ingestion` +Available `EXTRAS`: `pdf`, `rag`, `a2a`, `ftso`, `da`, `fassets`, `social`, `tee`, `wallet`, `ingestion`, `mcp` See [Docker Scripts Guide](docs/docker_scripts_guide.md) for detailed usage instructions. diff --git a/docs/mcp_readme.md b/docs/mcp_readme.md new file mode 100644 index 00000000..efad5c43 --- /dev/null +++ b/docs/mcp_readme.md @@ -0,0 +1,258 @@ +# MCP (Model Context Protocol) Server Integration + +This documentation provides a comprehensive guide to the MCP server integration for Flare AI Kit, enabling AI agents to connect to external tools and services via the Model Context Protocol. + +## Quick Start + +### 1. Installation + +The MCP integration requires additional dependencies: + +```bash +# Install with MCP support +uv sync --extra mcp + +# Or add the mcp extra to your dependencies +uv add "flare-ai-kit[mcp]" +``` + +### 2. Configuration + +Configure MCP servers via the `MCP__SERVERS` environment variable: + +```bash +# .env file +MCP__SERVERS='{"filesystem": {"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"]}}' +``` + +## Server Configuration + +### Stdio Servers (Local Processes) + +Stdio servers run as local processes and communicate via stdin/stdout + +**Environment Variable Format:** + +```bash +MCP__SERVERS='{"filesystem": {"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"]}}' +``` + +### SSE Servers (Remote - Server-Sent Events) + +SSE servers connect to remote endpoints using Server-Sent Events + +**Environment Variable Format:** + +```bash +MCP__SERVERS='{"api": {"url": "http://localhost:3000/sse", "transport": "sse", "headers": {"Authorization": "Bearer token"}}}' +``` + +### HTTP Servers (Remote - Streamable HTTP) + +HTTP servers connect using the streamable HTTP transport + +**Environment Variable Format:** + +```bash +MCP__SERVERS='{"remote": {"url": "https://api.example.com/mcp", "transport": "http"}}' +``` + +### Mixed Configuration Example + +Configure multiple servers of different types: + +```bash +MCP__SERVERS='{ + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/workspace"] + }, + "git": { + "command": "python", + "args": ["-m", "mcp_git_server"], + "env": {"GIT_AUTHOR_NAME": "AI Agent"} + }, + "external-api": { + "url": "https://api.example.com/mcp", + "transport": "http", + "headers": {"X-API-Key": "secret"} + }, + "streaming": { + "url": "http://localhost:3000/sse", + "transport": "sse" + } +}' +``` + +## Configuration Options + +### MCPServerConfig Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `command` | `str` | `None` | Command to run (stdio servers) | +| `args` | `list[str]` | `[]` | Arguments for the command | +| `env` | `dict[str, str]` | `{}` | Environment variables for the process | +| `url` | `str` | `None` | URL for remote servers (SSE/HTTP) | +| `transport` | `"stdio"\|"sse"\|"http"` | `"stdio"` | Transport type | +| `headers` | `dict[str, str]` | `{}` | HTTP headers for authentication | +| `timeout` | `float` | `30.0` | Connection timeout in seconds | +| `tool_filter` | `list[str]` | `None` | Tools to expose (`None` = all) | +| `enabled` | `bool` | `True` | Whether the server is enabled | + +### Tool Filtering + +Limit which tools are exposed from a server: + +```python +config = MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem", "/path"], + tool_filter=["read_file", "list_directory"], # Only expose these tools +) +``` + +```bash +MCP__SERVERS='{"fs": {"command": "npx", "args": ["-y", "server"], "tool_filter": ["read_file", "write_file"]}}' +``` + +### Disabling Servers + +Temporarily disable a server without removing its configuration: + +```python +config = MCPServerConfig( + command="npx", + args=["-y", "some-server"], + enabled=False, # Server won't be initialized +) +``` + +## Using with ADK Agents + +The MCP toolsets integrate seamlessly with Google ADK agents: + +```python +from flare_ai_kit import FlareAIKit +from google.adk import Agent +from google.genai import types + +async def create_agent_with_mcp(): + kit = FlareAIKit() + + # Get MCP toolsets + mcp_toolsets = kit.mcp_tools + + # Combine with other tools + all_tools = [ + *mcp_toolsets, # MCP tools + # ... your other ADK tools + ] + + # Create ADK agent with MCP tools + agent = Agent( + model="gemini-2.5-flash", + name="mcp-enabled-agent", + instruction="You have access to filesystem and other MCP tools.", + tools=all_tools, + ) + + return agent +``` + +## MCPManager API + +### Getting Toolsets + +```python +from flare_ai_kit.mcp.manager import MCPManager +from flare_ai_kit.mcp.settings import MCPSettings + +# Create manager +settings = MCPSettings() # Loads from environment +manager = MCPManager(settings) + +# Get all toolsets (sync) +toolsets = manager.get_toolsets_sync() + +# Get all toolsets (async) +toolsets = await manager.get_toolsets() + +# Get specific toolset by name +fs_toolset = manager.get_toolset("filesystem") +``` + +### Checking Configuration + +```python +# Check if any servers are configured +if manager.has_servers: + print("MCP servers configured") + +# Get list of server names +print(f"Servers: {manager.server_names}") +``` + +### Error Handling + +```python +# Get toolsets (errors are captured, not raised) +toolsets = manager.get_toolsets_sync() + +# Check for initialization errors +errors = manager.get_errors() +for server_name, error in errors.items(): + print(f"Server {server_name} failed: {error}") +``` + +### Cleanup + +```python +# Close all MCP connections +await manager.close() +``` + +## Error Handling + +### ImportError for Missing Dependencies + +```python +try: + toolsets = manager.get_toolsets_sync() +except ImportError as e: + print("MCP dependencies not installed. Run: pip install flare-ai-kit[mcp]") +``` + +### Server Configuration Errors + +```python +from flare_ai_kit.mcp.settings import MCPServerConfig + +# Invalid: both command and url specified +try: + config = MCPServerConfig(command="echo", url="http://localhost:3000") +except ValueError as e: + print(f"Configuration error: {e}") + +# Invalid: neither command nor url specified +try: + config = MCPServerConfig() +except ValueError as e: + print(f"Configuration error: {e}") +``` + +### Connection Errors + +```python +manager = MCPManager(settings) +toolsets = manager.get_toolsets_sync() + +# Check which servers failed to initialize +errors = manager.get_errors() +if errors: + for name, error in errors.items(): + print(f"Server '{name}' failed: {error}") + + # Successfully initialized servers still work + print(f"Working servers: {len(toolsets)}") +``` diff --git a/pyproject.toml b/pyproject.toml index e37f7418..5a25a105 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,11 @@ wallet = [ "eth-account>=0.13.7", "pyjwt>=2.10.1", ] +mcp = [ + # MCP SDK for Model Context Protocol integration + # Note: google-adk provides McpToolset, but mcp package needed for StdioServerParameters + "mcp>=1.0.0", +] [build-system] requires = ["hatchling"] diff --git a/src/flare_ai_kit/config.py b/src/flare_ai_kit/config.py index b6c47841..c219df88 100644 --- a/src/flare_ai_kit/config.py +++ b/src/flare_ai_kit/config.py @@ -9,6 +9,7 @@ from flare_ai_kit.agent.settings import AgentSettings from flare_ai_kit.ecosystem.settings import EcosystemSettings from flare_ai_kit.ingestion.settings import IngestionSettings +from flare_ai_kit.mcp.settings import MCPSettings from flare_ai_kit.rag.graph.settings import GraphDbSettings from flare_ai_kit.rag.vector.settings import VectorDbSettings from flare_ai_kit.social.settings import SocialSettings @@ -38,3 +39,4 @@ class AppSettings(BaseSettings): tee: TeeSettings = Field(default_factory=TeeSettings) # pyright: ignore[reportArgumentType,reportUnknownVariableType] ingestion: IngestionSettings = Field(default_factory=IngestionSettings) # pyright: ignore[reportArgumentType,reportUnknownVariableType] a2a: A2ASettings = Field(default_factory=A2ASettings) # pyright: ignore[reportArgumentType,reportUnknownVariableType] + mcp: MCPSettings = Field(default_factory=MCPSettings) # pyright: ignore[reportArgumentType,reportUnknownVariableType] diff --git a/src/flare_ai_kit/main.py b/src/flare_ai_kit/main.py index 4f0d35a2..ffdc7952 100644 --- a/src/flare_ai_kit/main.py +++ b/src/flare_ai_kit/main.py @@ -10,10 +10,13 @@ from .config import AppSettings if TYPE_CHECKING: + from google.adk.tools.mcp_tool import McpToolset + from .a2a import A2AClient from .ecosystem.api import BlockExplorer, FAssets, Flare, FtsoV2 from .ingestion.api import GithubIngestor from .ingestion.pdf_processor import PDFProcessor + from .mcp.manager import MCPManager from .rag.vector.api import VectorRAGPipeline from .social.api import TelegramClient, XClient @@ -50,6 +53,7 @@ def __init__(self, config: AppSettings | None) -> None: self._x_client: XClient | None = None self._pdf_processor: PDFProcessor | None = None self._a2a_client: A2AClient | None = None + self._mcp_manager: MCPManager | None = None # Ecosystem Interaction Methods @property @@ -160,6 +164,32 @@ def a2a_client(self) -> A2AClient: self._a2a_client = A2AClient(settings=self.settings.a2a) return self._a2a_client + # MCP Methods + @property + def mcp_manager(self) -> MCPManager: + """Access the MCP manager (configured via `MCP__SERVERS`).""" + from .mcp.manager import MCPManager # noqa: PLC0415 + + if self._mcp_manager is None: + self._mcp_manager = MCPManager(self.settings.mcp) + return self._mcp_manager + + @property + def mcp_tools(self) -> list[McpToolset]: + """Get MCP toolsets for use with ADK agents (empty if none configured).""" + return self.mcp_manager.get_toolsets_sync() + + @property + def has_mcp_tools(self) -> bool: + """Check if MCP tools are configured and available.""" + return self.mcp_manager.has_servers + + async def close_mcp(self) -> None: + """Close all MCP connections. Call during cleanup.""" + if self._mcp_manager is not None: + await self._mcp_manager.close() + self._mcp_manager = None + async def core() -> None: """Core function to run the Flare AI Kit SDK.""" diff --git a/src/flare_ai_kit/mcp/__init__.py b/src/flare_ai_kit/mcp/__init__.py new file mode 100644 index 00000000..97b61bf8 --- /dev/null +++ b/src/flare_ai_kit/mcp/__init__.py @@ -0,0 +1,6 @@ +"""MCP (Model Context Protocol) integration for Flare AI Kit.""" + +from .manager import MCPManager +from .settings import MCPServerConfig, MCPSettings + +__all__ = ["MCPManager", "MCPServerConfig", "MCPSettings"] diff --git a/src/flare_ai_kit/mcp/manager.py b/src/flare_ai_kit/mcp/manager.py new file mode 100644 index 00000000..64ca29b9 --- /dev/null +++ b/src/flare_ai_kit/mcp/manager.py @@ -0,0 +1,187 @@ +"""MCP manager for creating and managing MCP toolsets.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import structlog + +from .settings import MCPServerConfig, MCPSettings # noqa: TC001 + +if TYPE_CHECKING: + from google.adk.tools.mcp_tool import McpToolset + +logger = structlog.get_logger(__name__) + + +class MCPManager: + """Creates MCP toolsets from configuration and tracks initialization errors.""" + + def __init__(self, settings: MCPSettings) -> None: + """Initialize the manager with MCP settings.""" + self._settings = settings + self._toolsets: dict[str, McpToolset] = {} + self._initialized = False + self._initialization_errors: dict[str, Exception] = {} + + @property + def has_servers(self) -> bool: + """Check if any MCP servers are configured.""" + return self._settings.has_servers + + @property + def server_names(self) -> list[str]: + """Get list of configured server names.""" + return list(self._settings.get_enabled_servers().keys()) + + def _create_toolset( + self, + name: str, + config: MCPServerConfig, + ) -> McpToolset: + """ + Create a McpToolset from server configuration. + + Args: + name: Server name (for logging). + config: Server configuration. + + Returns: + McpToolset instance. + + Raises: + ImportError: If MCP dependencies are not installed. + + """ + # Check for required MCP dependencies + try: + from mcp import StdioServerParameters # noqa: PLC0415 + except ImportError as e: + msg = ( + "MCP dependencies not installed. " + "Install with: pip install flare-ai-kit[mcp]" + ) + raise ImportError(msg) from e + + # Import ADK MCP tools - these are available in google-adk>=1.19.0 + from google.adk.tools.mcp_tool.mcp_session_manager import ( # noqa: PLC0415 + SseConnectionParams, + StdioConnectionParams, + StreamableHTTPConnectionParams, + ) + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset # noqa: PLC0415 + + connection_params: ( + StdioConnectionParams | SseConnectionParams | StreamableHTTPConnectionParams + ) + + if config.is_stdio: + # Stdio-based server (local process) + if config.command is None: + msg = f"Server {name}: stdio config requires 'command'" + raise ValueError(msg) + connection_params = StdioConnectionParams( + server_params=StdioServerParameters( + command=config.command, + args=config.args, + env=config.env or {}, + ), + timeout=config.timeout, + ) + logger.debug( + "mcp_creating_stdio_toolset", + server=name, + command=config.command, + args=config.args, + ) + elif config.transport == "sse": + if config.url is None: + msg = f"Server {name}: SSE config requires 'url'" + raise ValueError(msg) + connection_params = SseConnectionParams( + url=str(config.url), + headers=config.headers if config.headers else None, + timeout=config.timeout, + ) + logger.debug( + "mcp_creating_sse_toolset", + server=name, + url=str(config.url), + ) + else: + if config.url is None: + msg = f"Server {name}: HTTP config requires 'url'" + raise ValueError(msg) + connection_params = StreamableHTTPConnectionParams( + url=str(config.url), + headers=config.headers if config.headers else None, + timeout=config.timeout, + ) + logger.debug( + "mcp_creating_http_toolset", + server=name, + url=str(config.url), + ) + + return McpToolset( + connection_params=connection_params, + tool_filter=config.tool_filter, + ) + + def get_toolsets_sync(self) -> list[McpToolset]: + """Return configured toolsets; connections are initialized lazily by ADK.""" + if self._initialized: + return list(self._toolsets.values()) + + enabled_servers = self._settings.get_enabled_servers() + + for name, config in enabled_servers.items(): + try: + toolset = self._create_toolset(name, config) + self._toolsets[name] = toolset + logger.info("mcp_toolset_created", server=name) + except ImportError: + raise + except Exception as e: # noqa: BLE001 + self._initialization_errors[name] = e + logger.warning( + "mcp_toolset_creation_failed", + server=name, + error=str(e), + ) + + self._initialized = True + return list(self._toolsets.values()) + + async def get_toolsets(self) -> list[McpToolset]: + """Async wrapper around `get_toolsets_sync()`.""" + return self.get_toolsets_sync() + + def get_toolset(self, name: str) -> McpToolset | None: + """Get a toolset by server name, initializing toolsets on first access.""" + if not self._initialized: + self.get_toolsets_sync() + return self._toolsets.get(name) + + def get_errors(self) -> dict[str, Exception]: + """Get any errors that occurred during initialization.""" + return dict(self._initialization_errors) + + async def close(self) -> None: + """Close all MCP connections.""" + for name, toolset in self._toolsets.items(): + try: + if hasattr(toolset, "close"): + close_result = toolset.close() + if hasattr(close_result, "__await__"): + await close_result + logger.debug("mcp_toolset_closed", server=name) + except Exception as e: # noqa: BLE001 + logger.warning( + "mcp_toolset_close_failed", + server=name, + error=str(e), + ) + + self._toolsets.clear() + self._initialized = False diff --git a/src/flare_ai_kit/mcp/settings.py b/src/flare_ai_kit/mcp/settings.py new file mode 100644 index 00000000..6e3de00c --- /dev/null +++ b/src/flare_ai_kit/mcp/settings.py @@ -0,0 +1,139 @@ +"""MCP server configuration settings.""" + +from __future__ import annotations + +import json +from typing import Annotated, Literal + +import structlog +from pydantic import BaseModel, BeforeValidator, Field, HttpUrl, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +logger = structlog.get_logger(__name__) + + +def _parse_json_servers( + value: dict[str, object] | str | None, +) -> dict[str, object]: + """Parse servers from JSON string or pass through dict.""" + if value is None: + return {} + if isinstance(value, str): + try: + parsed: dict[str, object] = json.loads(value) + except json.JSONDecodeError as e: + msg = f"Invalid JSON in MCP__SERVERS: {e}" + raise ValueError(msg) from e + else: + return parsed + return value + + +class MCPServerConfig(BaseModel): + """Configuration for a single MCP server (stdio or remote).""" + + # Stdio config fields (for local process-based servers) + command: str | None = Field( + default=None, + description="Command to run the MCP server process", + ) + args: list[str] = Field( + default_factory=list, + description="Arguments for the command", + ) + env: dict[str, str] = Field( + default_factory=dict, + description="Environment variables for the server process", + ) + + # Remote config fields (for external servers) + url: HttpUrl | None = Field( + default=None, + description="URL of the remote MCP server", + ) + transport: Literal["stdio", "sse", "http"] = Field( + default="stdio", + description="Transport type: 'stdio' for local, 'sse' or 'http' for remote", + ) + headers: dict[str, str] = Field( + default_factory=dict, + description="HTTP headers for authentication with remote servers", + ) + + # Common fields + timeout: float = Field( + default=30.0, + description="Connection timeout in seconds", + ) + tool_filter: list[str] | None = Field( + default=None, + description="List of tool names to expose (None = all tools)", + ) + enabled: bool = Field( + default=True, + description="Whether this server is enabled", + ) + + @model_validator(mode="after") + def validate_config_type(self) -> MCPServerConfig: + """Ensure either command (stdio) or url (remote) is provided, not both.""" + has_command = self.command is not None + has_url = self.url is not None + + if has_command and has_url: + msg = "Cannot specify both 'command' (stdio) and 'url' (remote)" + raise ValueError(msg) + + if not has_command and not has_url: + msg = "Must specify either 'command' (for stdio) or 'url' (for remote)" + raise ValueError(msg) + + # Infer transport type from config + if has_command: + object.__setattr__(self, "transport", "stdio") + elif self.transport == "stdio": + # Remote server with default transport, switch to http + object.__setattr__(self, "transport", "http") + + return self + + @property + def is_stdio(self) -> bool: + """Check if this is a stdio-based (local process) server.""" + return self.command is not None + + @property + def is_remote(self) -> bool: + """Check if this is a remote server.""" + return self.url is not None + + +# Type alias for parsed servers with BeforeValidator +ParsedServers = Annotated[ + dict[str, MCPServerConfig], + BeforeValidator(_parse_json_servers), +] + + +class MCPSettings(BaseSettings): + """Settings for MCP (Model Context Protocol) server integration.""" + + model_config = SettingsConfigDict( + env_prefix="MCP__", + env_file=".env", + extra="ignore", + ) + + servers: ParsedServers = Field( + default_factory=dict, + description="Map of server names to their configurations (JSON string)", + ) + + def get_enabled_servers(self) -> dict[str, MCPServerConfig]: + """Return only enabled server configurations.""" + return {name: config for name, config in self.servers.items() if config.enabled} + + @property + def has_servers(self) -> bool: + """Check if any MCP servers are configured.""" + return len(self.get_enabled_servers()) > 0 diff --git a/tests/unit/mcp/__init__.py b/tests/unit/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/mcp/test_manager.py b/tests/unit/mcp/test_manager.py new file mode 100644 index 00000000..c2756634 --- /dev/null +++ b/tests/unit/mcp/test_manager.py @@ -0,0 +1,383 @@ +"""Tests for MCP Manager.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from flare_ai_kit.mcp.manager import MCPManager +from flare_ai_kit.mcp.settings import MCPServerConfig, MCPSettings + + +class TestMCPManager: + """Tests for MCPManager class.""" + + def test_init_with_empty_settings(self): + """Test initialization with empty settings.""" + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + + assert not manager.has_servers + assert manager.server_names == [] + assert manager.get_errors() == {} + + def test_init_with_servers(self): + """Test initialization with configured servers.""" + settings = MCPSettings( + servers={ + "server1": MCPServerConfig(command="echo", args=[]), + "server2": MCPServerConfig(command="cat", args=[]), + } + ) + manager = MCPManager(settings) + + assert manager.has_servers + assert set(manager.server_names) == {"server1", "server2"} + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_get_toolsets_sync_creates_toolsets(self, mock_toolset_class): + """Test that toolsets are created for each server.""" + mock_toolset_class.return_value = MagicMock() + + settings = MCPSettings( + servers={ + "server1": MCPServerConfig(command="npx", args=["-y", "server1"]), + "server2": MCPServerConfig(command="npx", args=["-y", "server2"]), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert mock_toolset_class.call_count == 2 + assert len(toolsets) == 2 + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_get_toolsets_sync_caches_result(self, mock_toolset_class): + """Test that toolsets are only created once.""" + mock_toolset_class.return_value = MagicMock() + + settings = MCPSettings( + servers={"server1": MCPServerConfig(command="echo", args=[])} + ) + manager = MCPManager(settings) + + toolsets1 = manager.get_toolsets_sync() + toolsets2 = manager.get_toolsets_sync() + + assert mock_toolset_class.call_count == 1 + assert toolsets1 == toolsets2 + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_get_toolsets_handles_errors(self, mock_toolset_class): + """Test graceful error handling during toolset creation.""" + mock_toolset_class.side_effect = [ + Exception("Connection failed"), + MagicMock(), + ] + + settings = MCPSettings( + servers={ + "failing": MCPServerConfig(command="fail", args=[]), + "working": MCPServerConfig(command="work", args=[]), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert len(toolsets) == 1 + errors = manager.get_errors() + assert "failing" in errors + assert "working" not in errors + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_get_toolset_by_name(self, mock_toolset_class): + """Test getting specific toolset by name.""" + mock_instance = MagicMock() + mock_toolset_class.return_value = mock_instance + + settings = MCPSettings( + servers={"my-server": MCPServerConfig(command="echo", args=[])} + ) + manager = MCPManager(settings) + + toolset = manager.get_toolset("my-server") + + assert toolset is mock_instance + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_get_toolset_not_found(self, mock_toolset_class): + """Test getting non-existent toolset returns None.""" + mock_toolset_class.return_value = MagicMock() + + settings = MCPSettings( + servers={"my-server": MCPServerConfig(command="echo", args=[])} + ) + manager = MCPManager(settings) + + toolset = manager.get_toolset("nonexistent") + + assert toolset is None + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_disabled_servers_not_created(self, mock_toolset_class): + """Test that disabled servers don't create toolsets.""" + mock_toolset_class.return_value = MagicMock() + + settings = MCPSettings( + servers={ + "enabled": MCPServerConfig(command="echo", args=[], enabled=True), + "disabled": MCPServerConfig(command="echo", args=[], enabled=False), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert mock_toolset_class.call_count == 1 + assert len(toolsets) == 1 + + @pytest.mark.asyncio + async def test_get_toolsets_async(self): + """Test async toolsets method.""" + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + + toolsets = await manager.get_toolsets() + + assert toolsets == [] + + @pytest.mark.asyncio + async def test_close(self): + """Test close cleans up all toolsets.""" + mock_toolset = MagicMock() + mock_toolset.close = MagicMock(return_value=None) + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + manager._toolsets = {"test": mock_toolset} + manager._initialized = True + + await manager.close() + + mock_toolset.close.assert_called_once() + assert len(manager._toolsets) == 0 + assert manager._initialized is False + + @pytest.mark.asyncio + async def test_close_handles_errors(self): + """Test close continues on errors.""" + mock_toolset1 = MagicMock() + mock_toolset1.close = MagicMock(side_effect=Exception("Error")) + + mock_toolset2 = MagicMock() + mock_toolset2.close = MagicMock(return_value=None) + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + manager._toolsets = {"failing": mock_toolset1, "working": mock_toolset2} + manager._initialized = True + + await manager.close() + + mock_toolset1.close.assert_called_once() + mock_toolset2.close.assert_called_once() + assert len(manager._toolsets) == 0 + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_create_toolset_stdio_missing_command(self, mock_toolset_class): + """Test that _create_toolset raises ValueError when stdio config has no command.""" + # Create a mock config that bypasses Pydantic validation + mock_config = MagicMock() + mock_config.is_stdio = True + mock_config.command = None # Missing command + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + + with pytest.raises(ValueError, match="stdio config requires 'command'"): + manager._create_toolset("test-server", mock_config) + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_create_toolset_sse_missing_url(self, mock_toolset_class): + """Test that _create_toolset raises ValueError when SSE config has no url.""" + mock_config = MagicMock() + mock_config.is_stdio = False + mock_config.transport = "sse" + mock_config.url = None # Missing URL + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + + with pytest.raises(ValueError, match="SSE config requires 'url'"): + manager._create_toolset("test-server", mock_config) + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_create_toolset_http_missing_url(self, mock_toolset_class): + """Test that _create_toolset raises ValueError when HTTP config has no url.""" + mock_config = MagicMock() + mock_config.is_stdio = False + mock_config.transport = "http" + mock_config.url = None # Missing URL + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + + with pytest.raises(ValueError, match="HTTP config requires 'url'"): + manager._create_toolset("test-server", mock_config) + + def test_import_error_propagates(self): + """Test that ImportError is re-raised, not silently caught.""" + import sys + + settings = MCPSettings( + servers={"server1": MCPServerConfig(command="echo", args=[])} + ) + manager = MCPManager(settings) + + original_mcp = sys.modules.get("mcp") + sys.modules["mcp"] = None # type: ignore[assignment] + + try: + with pytest.raises(ImportError, match="MCP dependencies not installed"): + manager.get_toolsets_sync() + finally: + if original_mcp is not None: + sys.modules["mcp"] = original_mcp + elif "mcp" in sys.modules: + del sys.modules["mcp"] + + @pytest.mark.asyncio + async def test_close_awaits_async_close(self): + """Test that close() properly awaits an async toolset.close() method.""" + close_awaited = False + + async def async_close(): + nonlocal close_awaited + close_awaited = True + + mock_toolset = MagicMock() + mock_toolset.close = MagicMock(return_value=async_close()) + + settings = MCPSettings(servers={}) + manager = MCPManager(settings) + manager._toolsets = {"async-server": mock_toolset} + manager._initialized = True + + await manager.close() + + assert close_awaited, "Async close() should have been awaited" + assert len(manager._toolsets) == 0 + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + def test_non_import_errors_stored_in_get_errors(self, mock_toolset_class): + """Test that non-ImportError exceptions are stored and remaining toolsets returned.""" + working_toolset = MagicMock() + mock_toolset_class.side_effect = [ + ValueError("Invalid config"), + working_toolset, + ] + + settings = MCPSettings( + servers={ + "failing": MCPServerConfig(command="fail", args=[]), + "working": MCPServerConfig(command="work", args=[]), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert len(toolsets) == 1 + assert toolsets[0] is working_toolset + + errors = manager.get_errors() + assert "failing" in errors + assert isinstance(errors["failing"], ValueError) + assert "working" not in errors + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.SseConnectionParams") + def test_create_toolset_sse_server(self, mock_sse_params, mock_toolset_class): + """Test that SSE remote server is created with correct params.""" + mock_toolset_instance = MagicMock() + mock_toolset_class.return_value = mock_toolset_instance + + settings = MCPSettings( + servers={ + "sse-server": MCPServerConfig( + url="http://localhost:3000/sse", + transport="sse", + headers={"Authorization": "Bearer token"}, + timeout=60.0, + tool_filter=["tool1", "tool2"], + ), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert len(toolsets) == 1 + mock_sse_params.assert_called_once() + call_kwargs = mock_sse_params.call_args.kwargs + assert call_kwargs["url"] == "http://localhost:3000/sse" + assert call_kwargs["headers"] == {"Authorization": "Bearer token"} + assert call_kwargs["timeout"] == 60.0 + mock_toolset_class.assert_called_once() + toolset_kwargs = mock_toolset_class.call_args.kwargs + assert toolset_kwargs["tool_filter"] == ["tool1", "tool2"] + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + @patch( + "google.adk.tools.mcp_tool.mcp_session_manager.StreamableHTTPConnectionParams" + ) + def test_create_toolset_http_server(self, mock_http_params, mock_toolset_class): + """Test that HTTP remote server is created with correct params.""" + mock_toolset_instance = MagicMock() + mock_toolset_class.return_value = mock_toolset_instance + + settings = MCPSettings( + servers={ + "http-server": MCPServerConfig( + url="https://api.example.com/mcp", + transport="http", + headers={"X-API-Key": "secret"}, + timeout=45.0, + ), + } + ) + manager = MCPManager(settings) + + toolsets = manager.get_toolsets_sync() + + assert len(toolsets) == 1 + mock_http_params.assert_called_once() + call_kwargs = mock_http_params.call_args.kwargs + assert call_kwargs["url"] == "https://api.example.com/mcp" + assert call_kwargs["headers"] == {"X-API-Key": "secret"} + assert call_kwargs["timeout"] == 45.0 + + @patch("google.adk.tools.mcp_tool.mcp_toolset.McpToolset") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.StdioConnectionParams") + @patch("mcp.StdioServerParameters") + def test_stdio_server_uses_empty_env_not_none( + self, mock_server_params, mock_stdio_params, mock_toolset_class + ): + """Test that stdio servers use env={}, not env=None.""" + mock_toolset_class.return_value = MagicMock() + + settings = MCPSettings( + servers={ + "test-server": MCPServerConfig(command="echo", args=["hello"]), + } + ) + manager = MCPManager(settings) + + manager.get_toolsets_sync() + + mock_server_params.assert_called_once() + call_kwargs = mock_server_params.call_args.kwargs + assert call_kwargs["env"] == {}, ( + "env should be empty dict, not None (prevents parent env inheritance)" + ) diff --git a/tests/unit/mcp/test_settings.py b/tests/unit/mcp/test_settings.py new file mode 100644 index 00000000..7445a110 --- /dev/null +++ b/tests/unit/mcp/test_settings.py @@ -0,0 +1,211 @@ +"""Tests for MCP settings parsing.""" + +import pytest + +from flare_ai_kit.mcp.settings import MCPServerConfig, MCPSettings + + +class TestMCPServerConfig: + """Tests for MCPServerConfig model.""" + + def test_stdio_config(self): + """Test valid stdio server configuration.""" + config = MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem", "/path"], + env={"API_KEY": "secret"}, + ) + assert config.is_stdio + assert not config.is_remote + assert config.transport == "stdio" + assert config.command == "npx" + assert config.args == ["-y", "@modelcontextprotocol/server-filesystem", "/path"] + assert config.env == {"API_KEY": "secret"} + + def test_stdio_config_minimal(self): + """Test minimal stdio configuration.""" + config = MCPServerConfig(command="echo", args=["hello"]) + assert config.is_stdio + assert config.transport == "stdio" + assert config.enabled is True + assert config.timeout == 30.0 + + def test_remote_sse_config(self): + """Test valid SSE remote server configuration.""" + config = MCPServerConfig( + url="http://localhost:3000/sse", + transport="sse", + ) + assert config.is_remote + assert not config.is_stdio + assert config.transport == "sse" + assert str(config.url) == "http://localhost:3000/sse" + + def test_remote_http_config(self): + """Test valid HTTP remote server configuration.""" + config = MCPServerConfig( + url="http://localhost:3000/mcp", + transport="http", + headers={"Authorization": "Bearer token"}, + ) + assert config.is_remote + assert config.transport == "http" + assert config.headers == {"Authorization": "Bearer token"} + + def test_remote_default_transport(self): + """Test that remote server defaults to http transport.""" + config = MCPServerConfig(url="http://localhost:3000/mcp") + assert config.is_remote + assert config.transport == "http" + + def test_invalid_both_command_and_url(self): + """Test that having both command and url raises error.""" + with pytest.raises(ValueError, match="Cannot specify both"): + MCPServerConfig( + command="npx", + url="http://localhost:3000", + ) + + def test_invalid_neither_command_nor_url(self): + """Test that having neither command nor url raises error.""" + with pytest.raises(ValueError, match="Must specify either"): + MCPServerConfig() + + def test_tool_filter(self): + """Test tool filtering configuration.""" + config = MCPServerConfig( + command="npx", + args=["-y", "some-server"], + tool_filter=["read_file", "write_file"], + ) + assert config.tool_filter == ["read_file", "write_file"] + + def test_disabled_server(self): + """Test disabled server configuration.""" + config = MCPServerConfig( + command="npx", + args=["-y", "some-server"], + enabled=False, + ) + assert config.enabled is False + + def test_custom_timeout(self): + """Test custom timeout configuration.""" + config = MCPServerConfig( + command="npx", + args=["-y", "some-server"], + timeout=60.0, + ) + assert config.timeout == 60.0 + + +class TestMCPSettings: + """Tests for MCPSettings model.""" + + def test_empty_settings(self): + """Test empty MCP settings.""" + settings = MCPSettings() + assert len(settings.servers) == 0 + assert settings.get_enabled_servers() == {} + assert settings.has_servers is False + + def test_settings_with_servers(self): + """Test settings with configured servers.""" + settings = MCPSettings( + servers={ + "filesystem": MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem", "/path"], + ), + "remote": MCPServerConfig( + url="http://localhost:3000/sse", + transport="sse", + ), + } + ) + assert len(settings.servers) == 2 + assert "filesystem" in settings.servers + assert "remote" in settings.servers + assert settings.servers["filesystem"].is_stdio + assert settings.servers["remote"].is_remote + assert settings.has_servers is True + + def test_get_enabled_servers(self): + """Test filtering enabled servers.""" + settings = MCPSettings( + servers={ + "enabled": MCPServerConfig(command="echo", args=[], enabled=True), + "disabled": MCPServerConfig(command="echo", args=[], enabled=False), + } + ) + enabled = settings.get_enabled_servers() + assert "enabled" in enabled + assert "disabled" not in enabled + + def test_parse_json_from_env_var(self, monkeypatch): + """Test parsing MCP__SERVERS from environment variable.""" + json_config = ( + '{"filesystem": {"command": "npx", "args": ["-y", "server"]}, ' + '"api": {"url": "http://localhost:3000/mcp", "transport": "http"}}' + ) + monkeypatch.setenv("MCP__SERVERS", json_config) + + settings = MCPSettings() + + assert len(settings.servers) == 2 + assert "filesystem" in settings.servers + assert "api" in settings.servers + assert settings.servers["filesystem"].command == "npx" + assert settings.servers["api"].transport == "http" + + def test_parse_tool_filter_from_env_var(self, monkeypatch): + """Test that tool_filter is properly parsed from JSON env var.""" + json_config = ( + '{"fs": {"command": "npx", "args": ["-y", "server"], ' + '"tool_filter": ["read_file", "list_directory"]}}' + ) + monkeypatch.setenv("MCP__SERVERS", json_config) + + settings = MCPSettings() + + assert settings.servers["fs"].tool_filter == ["read_file", "list_directory"] + + def test_parse_invalid_json_from_env_var(self, monkeypatch): + """Test error on invalid JSON in env var.""" + from pydantic_settings.exceptions import SettingsError + + monkeypatch.setenv("MCP__SERVERS", "not valid json{") + + with pytest.raises(SettingsError, match="error parsing value"): + MCPSettings() + + def test_multiple_servers_mixed_types(self): + """Test settings with mixed server types.""" + settings = MCPSettings( + servers={ + "local-fs": MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + ), + "local-git": MCPServerConfig( + command="python", + args=["-m", "mcp_git_server"], + ), + "remote-api": MCPServerConfig( + url="https://api.example.com/mcp", + transport="http", + headers={"X-API-Key": "secret"}, + ), + "remote-stream": MCPServerConfig( + url="https://stream.example.com/sse", + transport="sse", + ), + } + ) + + assert len(settings.servers) == 4 + assert settings.servers["local-fs"].is_stdio + assert settings.servers["local-git"].is_stdio + assert settings.servers["remote-api"].is_remote + assert settings.servers["remote-api"].transport == "http" + assert settings.servers["remote-stream"].transport == "sse" diff --git a/uv.lock b/uv.lock index c07f3248..68ae230a 100644 --- a/uv.lock +++ b/uv.lock @@ -1032,6 +1032,9 @@ dependencies = [ a2a = [ { name = "fastapi", extra = ["standard"] }, ] +mcp = [ + { name = "mcp" }, +] pdf = [ { name = "dulwich" }, { name = "pillow" }, @@ -1085,6 +1088,7 @@ requires-dist = [ { name = "google-adk", specifier = ">=1.19.0" }, { name = "google-genai", specifier = ">=1.51.0" }, { name = "httpx", specifier = ">=0.28.1" }, + { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.0.0" }, { name = "pillow", marker = "extra == 'pdf'", specifier = ">=11.3.0" }, { name = "pydantic", specifier = ">=2.12.4" }, { name = "pyjwt", marker = "extra == 'tee'", specifier = ">=2.10.1" }, @@ -1101,7 +1105,7 @@ requires-dist = [ { name = "tweepy", marker = "extra == 'social'", specifier = ">=4.16.0" }, { name = "web3", specifier = ">=7.14.0" }, ] -provides-extras = ["ftso", "da", "fassets", "pdf", "rag", "a2a", "social", "tee", "wallet"] +provides-extras = ["ftso", "da", "fassets", "pdf", "rag", "a2a", "social", "tee", "wallet", "mcp"] [package.metadata.requires-dev] dev = [