diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 9becb8075843..f15f1d4c361f 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -9,26 +9,31 @@ import asyncio import json from typing import Any, Dict, List, Optional +import uuid from mcp import ClientSession from mcp.client.sse import sse_client from mcp.types import Tool as MCPTool +from prisma.models import LiteLLM_MCPServerTable from litellm._logging import verbose_logger -from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer - +# from litellm.proxy._experimental.mcp_server.db import fetch_all_mcp_servers +from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer class MCPServerManager: def __init__(self): - self.mcp_servers: List[MCPSSEServer] = [] + self.registry: Dict[str, MCPServer] = {} """ eg. [ - { + "server-1": { "name": "zapier_mcp_server", "url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse" + "transport": "sse", + auth_type: "api_key", + "spec_version": "2025-03-26" }, - { + "uuid-2": { "name": "google_drive_mcp_server", "url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse" } @@ -41,28 +46,64 @@ def __init__(self): "gmail_send_email": "zapier_mcp_server", } """ + def get_registry(self) -> Dict[str, MCPServer]: + """ + Get the registered MCP Servers + """ + return self.registry def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]): """ Load the MCP Servers from the config """ + verbose_logger.debug("Loading MCP Servers from config-----") for server_name, server_config in mcp_servers_config.items(): _mcp_info: dict = server_config.get("mcp_info", None) or {} mcp_info = MCPInfo(**_mcp_info) mcp_info["server_name"] = server_name - self.mcp_servers.append( - MCPSSEServer( - name=server_name, - url=server_config["url"], - mcp_info=mcp_info, - ) + mcp_info["description"] = server_config.get("description", None) + new_server = MCPServer( + name=server_name, + url=server_config["url"], + # TODO: utility fn the default values + transport=server_config.get("transport", "sse"), + spec_version=server_config.get("spec_version", "2025-03-26"), + auth_type=server_config.get("auth_type", None), + mcp_info=mcp_info, ) + server_id = str(uuid.uuid4()) + self.registry[server_id] = new_server verbose_logger.debug( - f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}" + f"Loaded MCP Servers: {json.dumps(self.registry, indent=4, default=str)}" ) self.initialize_tool_name_to_mcp_server_name_mapping() + async def load_from_db(self) -> List[MCPServer]: + """ + Load the MCP Servers from the database + """ + # TODO: start reading from db to import + # from litellm.proxy.proxy_server import prisma_client + + # if prisma_client is not None: + # mcp_servers = await fetch_all_mcp_servers(prisma_client) + # for db_server in mcp_servers: + # new_server = MCPServer( + # name=db_server.alias or "unknown", + # url=db_server.url, + # transport=db_server.transport, + # spec_version=db_server.spec_version, + # auth_type=db_server.auth_type, + # mcp_info=MCPInfo( + # server_name=db_server.alias or db_server.server_id, + # description=db_server.description, + # logo_url=None, + # ), + # ) + # self.registry[db_server.server_id] = new_server + return list(self.registry.values()) + async def list_tools(self) -> List[MCPTool]: """ List all tools available across all MCP Servers. @@ -73,13 +114,13 @@ async def list_tools(self) -> List[MCPTool]: list_tools_result: List[MCPTool] = [] verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS") - for server in self.mcp_servers: + for _, server in self.registry.items(): tools = await self._get_tools_from_server(server) list_tools_result.extend(tools) return list_tools_result - async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]: + async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]: """ Helper method to get tools from a single MCP server. @@ -91,19 +132,28 @@ async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]: """ verbose_logger.debug(f"Connecting to url: {server.url}") - async with sse_client(url=server.url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - tools_result = await session.list_tools() - verbose_logger.debug(f"Tools from {server.name}: {tools_result}") - - # Update tool to server mapping - for tool in tools_result.tools: - self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name - - return tools_result.tools - + verbose_logger.info("_get_tools_from_server...") + # send transport to connect to the server + if server.transport is None or server.transport == "sse": + async with sse_client(url=server.url) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + tools_result = await session.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools_result}") + + # Update tool to server mapping + for tool in tools_result.tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + + return tools_result.tools + elif server.transport == "http": + # TODO: implement http transport + return [] + else: + # TODO: throw error on transport found or skip + return [] + def initialize_tool_name_to_mcp_server_name_mapping(self): """ On startup, initialize the tool name to MCP server name mapping @@ -122,7 +172,7 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self): """ Call list_tools for each server and update the tool name to MCP server name mapping """ - for server in self.mcp_servers: + for _, server in self.registry.items(): tools = await self._get_tools_from_server(server) for tool in tools: self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name @@ -139,12 +189,12 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]): await session.initialize() return await session.call_tool(name, arguments) - def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]: + def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]: """ Get the MCP Server from the tool name """ if tool_name in self.tool_name_to_mcp_server_name_mapping: - for server in self.mcp_servers: + for _, server in self.registry.items(): if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]: return server return None diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index aecd11aa1aeb..4aacf60431e6 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -6,11 +6,16 @@ class MCPInfo(TypedDict, total=False): server_name: str + description: Optional[str] logo_url: Optional[str] -class MCPSSEServer(BaseModel): +class MCPServer(BaseModel): name: str url: str + # TODO: alter the types to be the Literal explicit + transport: str + spec_version: str + auth_type: Optional[str] = None mcp_info: Optional[MCPInfo] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 2cf919387153..9659ab329776 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -9,7 +9,7 @@ from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, - MCPSSEServer, + MCPServer, )