Skip to content

feat: mcp server db and config connection #10707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 80 additions & 30 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
@@ -9,26 +9,31 @@
import asyncio
import json
from typing import Any, Dict, List, Optional
import uuid

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import Tool as MCPTool
from prisma.models import LiteLLM_MCPServerTable

from litellm._logging import verbose_logger
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer

# from litellm.proxy._experimental.mcp_server.db import fetch_all_mcp_servers
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer

class MCPServerManager:
def __init__(self):
self.mcp_servers: List[MCPSSEServer] = []
self.registry: Dict[str, MCPServer] = {}
"""
eg.
[
{
"server-1": {
"name": "zapier_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
"transport": "sse",
auth_type: "api_key",
"spec_version": "2025-03-26"
},
{
"uuid-2": {
"name": "google_drive_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
}
@@ -41,28 +46,64 @@ def __init__(self):
"gmail_send_email": "zapier_mcp_server",
}
"""
def get_registry(self) -> Dict[str, MCPServer]:
"""
Get the registered MCP Servers
"""
return self.registry

def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
"""
Load the MCP Servers from the config
"""
verbose_logger.debug("Loading MCP Servers from config-----")
for server_name, server_config in mcp_servers_config.items():
_mcp_info: dict = server_config.get("mcp_info", None) or {}
mcp_info = MCPInfo(**_mcp_info)
mcp_info["server_name"] = server_name
self.mcp_servers.append(
MCPSSEServer(
name=server_name,
url=server_config["url"],
mcp_info=mcp_info,
)
mcp_info["description"] = server_config.get("description", None)
new_server = MCPServer(
name=server_name,
url=server_config["url"],
# TODO: utility fn the default values
transport=server_config.get("transport", "sse"),
spec_version=server_config.get("spec_version", "2025-03-26"),
auth_type=server_config.get("auth_type", None),
mcp_info=mcp_info,
)
server_id = str(uuid.uuid4())
self.registry[server_id] = new_server
verbose_logger.debug(
f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
f"Loaded MCP Servers: {json.dumps(self.registry, indent=4, default=str)}"
)

self.initialize_tool_name_to_mcp_server_name_mapping()

async def load_from_db(self) -> List[MCPServer]:
"""
Load the MCP Servers from the database
"""
# TODO: start reading from db to import
# from litellm.proxy.proxy_server import prisma_client

# if prisma_client is not None:
# mcp_servers = await fetch_all_mcp_servers(prisma_client)
# for db_server in mcp_servers:
# new_server = MCPServer(
# name=db_server.alias or "unknown",
# url=db_server.url,
# transport=db_server.transport,
# spec_version=db_server.spec_version,
# auth_type=db_server.auth_type,
# mcp_info=MCPInfo(
# server_name=db_server.alias or db_server.server_id,
# description=db_server.description,
# logo_url=None,
# ),
# )
# self.registry[db_server.server_id] = new_server
return list(self.registry.values())

async def list_tools(self) -> List[MCPTool]:
"""
List all tools available across all MCP Servers.
@@ -73,13 +114,13 @@ async def list_tools(self) -> List[MCPTool]:
list_tools_result: List[MCPTool] = []
verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")

for server in self.mcp_servers:
for _, server in self.registry.items():
tools = await self._get_tools_from_server(server)
list_tools_result.extend(tools)

return list_tools_result

async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]:
"""
Helper method to get tools from a single MCP server.

@@ -91,19 +132,28 @@ async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
"""
verbose_logger.debug(f"Connecting to url: {server.url}")

async with sse_client(url=server.url) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()

tools_result = await session.list_tools()
verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

# Update tool to server mapping
for tool in tools_result.tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

return tools_result.tools

verbose_logger.info("_get_tools_from_server...")
# send transport to connect to the server
if server.transport is None or server.transport == "sse":
async with sse_client(url=server.url) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()

tools_result = await session.list_tools()
verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

# Update tool to server mapping
for tool in tools_result.tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

return tools_result.tools
elif server.transport == "http":
# TODO: implement http transport
return []
else:
# TODO: throw error on transport found or skip
return []

def initialize_tool_name_to_mcp_server_name_mapping(self):
"""
On startup, initialize the tool name to MCP server name mapping
@@ -122,7 +172,7 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self):
"""
Call list_tools for each server and update the tool name to MCP server name mapping
"""
for server in self.mcp_servers:
for _, server in self.registry.items():
tools = await self._get_tools_from_server(server)
for tool in tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
@@ -139,12 +189,12 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]):
await session.initialize()
return await session.call_tool(name, arguments)

def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
"""
Get the MCP Server from the tool name
"""
if tool_name in self.tool_name_to_mcp_server_name_mapping:
for server in self.mcp_servers:
for _, server in self.registry.items():
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
return server
return None
7 changes: 6 additions & 1 deletion litellm/types/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
@@ -6,11 +6,16 @@

class MCPInfo(TypedDict, total=False):
server_name: str
description: Optional[str]
logo_url: Optional[str]


class MCPSSEServer(BaseModel):
class MCPServer(BaseModel):
name: str
url: str
# TODO: alter the types to be the Literal explicit
transport: str
spec_version: str
auth_type: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
2 changes: 1 addition & 1 deletion tests/mcp_tests/test_mcp_server.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
MCPSSEServer,
MCPServer,
)