diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
index 9becb8075843..f15f1d4c361f 100644
--- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
+++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
@@ -9,26 +9,31 @@
import asyncio
import json
from typing import Any, Dict, List, Optional
+import uuid
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import Tool as MCPTool
+from prisma.models import LiteLLM_MCPServerTable
from litellm._logging import verbose_logger
-from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
-
+# from litellm.proxy._experimental.mcp_server.db import fetch_all_mcp_servers
+from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
class MCPServerManager:
def __init__(self):
- self.mcp_servers: List[MCPSSEServer] = []
+ self.registry: Dict[str, MCPServer] = {}
"""
eg.
[
- {
+ "server-1": {
"name": "zapier_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
+ "transport": "sse",
+ auth_type: "api_key",
+ "spec_version": "2025-03-26"
},
- {
+ "uuid-2": {
"name": "google_drive_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
}
@@ -41,28 +46,64 @@ def __init__(self):
"gmail_send_email": "zapier_mcp_server",
}
"""
+ def get_registry(self) -> Dict[str, MCPServer]:
+ """
+ Get the registered MCP Servers
+ """
+ return self.registry
def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
"""
Load the MCP Servers from the config
"""
+ verbose_logger.debug("Loading MCP Servers from config-----")
for server_name, server_config in mcp_servers_config.items():
_mcp_info: dict = server_config.get("mcp_info", None) or {}
mcp_info = MCPInfo(**_mcp_info)
mcp_info["server_name"] = server_name
- self.mcp_servers.append(
- MCPSSEServer(
- name=server_name,
- url=server_config["url"],
- mcp_info=mcp_info,
- )
+ mcp_info["description"] = server_config.get("description", None)
+ new_server = MCPServer(
+ name=server_name,
+ url=server_config["url"],
+ # TODO: utility fn the default values
+ transport=server_config.get("transport", "sse"),
+ spec_version=server_config.get("spec_version", "2025-03-26"),
+ auth_type=server_config.get("auth_type", None),
+ mcp_info=mcp_info,
)
+ server_id = str(uuid.uuid4())
+ self.registry[server_id] = new_server
verbose_logger.debug(
- f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
+ f"Loaded MCP Servers: {json.dumps(self.registry, indent=4, default=str)}"
)
self.initialize_tool_name_to_mcp_server_name_mapping()
+ async def load_from_db(self) -> List[MCPServer]:
+ """
+ Load the MCP Servers from the database
+ """
+ # TODO: start reading from db to import
+ # from litellm.proxy.proxy_server import prisma_client
+
+ # if prisma_client is not None:
+ # mcp_servers = await fetch_all_mcp_servers(prisma_client)
+ # for db_server in mcp_servers:
+ # new_server = MCPServer(
+ # name=db_server.alias or "unknown",
+ # url=db_server.url,
+ # transport=db_server.transport,
+ # spec_version=db_server.spec_version,
+ # auth_type=db_server.auth_type,
+ # mcp_info=MCPInfo(
+ # server_name=db_server.alias or db_server.server_id,
+ # description=db_server.description,
+ # logo_url=None,
+ # ),
+ # )
+ # self.registry[db_server.server_id] = new_server
+ return list(self.registry.values())
+
async def list_tools(self) -> List[MCPTool]:
"""
List all tools available across all MCP Servers.
@@ -73,13 +114,13 @@ async def list_tools(self) -> List[MCPTool]:
list_tools_result: List[MCPTool] = []
verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")
- for server in self.mcp_servers:
+ for _, server in self.registry.items():
tools = await self._get_tools_from_server(server)
list_tools_result.extend(tools)
return list_tools_result
- async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
+ async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]:
"""
Helper method to get tools from a single MCP server.
@@ -91,19 +132,28 @@ async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
"""
verbose_logger.debug(f"Connecting to url: {server.url}")
- async with sse_client(url=server.url) as (read, write):
- async with ClientSession(read, write) as session:
- await session.initialize()
-
- tools_result = await session.list_tools()
- verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
-
- # Update tool to server mapping
- for tool in tools_result.tools:
- self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
-
- return tools_result.tools
-
+ verbose_logger.info("_get_tools_from_server...")
+ # send transport to connect to the server
+ if server.transport is None or server.transport == "sse":
+ async with sse_client(url=server.url) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ tools_result = await session.list_tools()
+ verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
+
+ # Update tool to server mapping
+ for tool in tools_result.tools:
+ self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
+
+ return tools_result.tools
+ elif server.transport == "http":
+ # TODO: implement http transport
+ return []
+ else:
+ # TODO: throw error on transport found or skip
+ return []
+
def initialize_tool_name_to_mcp_server_name_mapping(self):
"""
On startup, initialize the tool name to MCP server name mapping
@@ -122,7 +172,7 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self):
"""
Call list_tools for each server and update the tool name to MCP server name mapping
"""
- for server in self.mcp_servers:
+ for _, server in self.registry.items():
tools = await self._get_tools_from_server(server)
for tool in tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
@@ -139,12 +189,12 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]):
await session.initialize()
return await session.call_tool(name, arguments)
- def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
+ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
"""
Get the MCP Server from the tool name
"""
if tool_name in self.tool_name_to_mcp_server_name_mapping:
- for server in self.mcp_servers:
+ for _, server in self.registry.items():
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
return server
return None
diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py
index aecd11aa1aeb..4aacf60431e6 100644
--- a/litellm/types/mcp_server/mcp_server_manager.py
+++ b/litellm/types/mcp_server/mcp_server_manager.py
@@ -6,11 +6,16 @@
class MCPInfo(TypedDict, total=False):
server_name: str
+ description: Optional[str]
logo_url: Optional[str]
-class MCPSSEServer(BaseModel):
+class MCPServer(BaseModel):
name: str
url: str
+ # TODO: alter the types to be the Literal explicit
+ transport: str
+ spec_version: str
+ auth_type: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py
index 2cf919387153..9659ab329776 100644
--- a/tests/mcp_tests/test_mcp_server.py
+++ b/tests/mcp_tests/test_mcp_server.py
@@ -9,7 +9,7 @@
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
- MCPSSEServer,
+ MCPServer,
)