diff --git a/.github/workflows/test-litellm.yml b/.github/workflows/test-litellm.yml index 6dbbe663abad..bc9c1f98cb17 100644 --- a/.github/workflows/test-litellm.yml +++ b/.github/workflows/test-litellm.yml @@ -37,4 +37,4 @@ jobs: cd .. - name: Run tests run: | - poetry run pytest tests/test_litellm -x -vv -n 4 \ No newline at end of file + poetry run pytest tests/test_litellm -x -vv -n 4 diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index f04324f965fd..ad16cd17d1ef 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -18,10 +18,19 @@ This allows you to define tools that can be called by any MCP compatible client. #### How it works +1. Allows proxy admin users to perform create, update, and delete operations on MCP servers stored in the db. +2. Allows users to view and call tools on the MCP servers they have access to. + LiteLLM exposes the following MCP endpoints: -- `/mcp/tools/list` - List all available tools -- `/mcp/tools/call` - Call a specific tool with the provided arguments +- GET `/mcp/enabled` - Returns whether MCP is enabled (python>=3.10 requirements are met) +- GET `/mcp/tools/list` - List all available tools +- POST `/mcp/tools/call` - Call a specific tool with the provided arguments +- GET `/v1/mcp/server` - Returns all of the configured mcp servers in the db filtered by requestor's access +- GET `/v1/mcp/server/{server_id}` - Returns the specific mcp server in the db given `server_id` filtered by requestor's access +- PUT `/v1/mcp/server` - Updates an existing external mcp server. +- POST `/v1/mcp/server` - Add a new external mcp server. +- DELETE `/v1/mcp/server/{server_id}` - Deletes the mcp server given `server_id`. 
When MCP clients connect to LiteLLM they can follow this workflow: diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py new file mode 100644 index 000000000000..cdeb97b363a0 --- /dev/null +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -0,0 +1,203 @@ +from typing import Iterable, List, Optional, Set + +import uuid + +from prisma.models import LiteLLM_MCPServerTable, LiteLLM_ObjectPermissionTable, LiteLLM_TeamTable +from litellm.proxy._types import NewMCPServerRequest, SpecialMCPServerName, UpdateMCPServerRequest, UserAPIKeyAuth +from litellm.proxy.utils import PrismaClient + + +async def get_all_mcp_servers(prisma_client: PrismaClient) -> List[LiteLLM_MCPServerTable]: + """ + Returns all of the mcp servers from the db + """ + mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many() + + return mcp_servers + + +async def get_mcp_server(prisma_client: PrismaClient, server_id: str) -> Optional[LiteLLM_MCPServerTable]: + """ + Returns the matching mcp server from the db iff exists + """ + mcp_server: Optional[LiteLLM_MCPServerTable] = await prisma_client.db.litellm_mcpservertable.find_unique( + where={ + "server_id": server_id, + } + ) + return mcp_server + + +async def get_mcp_servers(prisma_client: PrismaClient, server_ids: Iterable[str]) -> List[LiteLLM_MCPServerTable]: + """ + Returns the matching mcp servers from the db with the server_ids + """ + mcp_servers: List[LiteLLM_MCPServerTable] = await prisma_client.db.litellm_mcpservertable.find_many( + where={ + "server_id": {"in": server_ids}, + } + ) + return mcp_servers + + +async def get_mcp_servers_by_verificationtoken(prisma_client: PrismaClient, token: str) -> List[str]: + """ + Returns the mcp servers from the db for the verification token + """ + verification_token_record: LiteLLM_TeamTable = await prisma_client.db.litellm_verificationtoken.find_unique( + where={ + "token": token, + }, + include={ + "object_permission": True, + }, + ) + + 
mcp_servers = [] + if verification_token_record is not None and verification_token_record.object_permission is not None: + mcp_servers = verification_token_record.object_permission.mcp_servers + return mcp_servers + + +async def get_mcp_servers_by_team(prisma_client: PrismaClient, team_id: str) -> List[str]: + """ + Returns the mcp servers from the db for the team id + """ + team_record: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.find_unique( + where={ + "team_id": team_id, + }, + include={ + "object_permission": True, + }, + ) + + mcp_servers = [] + if team_record is not None and team_record.object_permission is not None: + mcp_servers = team_record.object_permission.mcp_servers + return mcp_servers + + +async def get_all_mcp_servers_for_user( + prisma_client: PrismaClient, + user: UserAPIKeyAuth, +) -> List[LiteLLM_MCPServerTable]: + """ + Get all the mcp servers filtered by the given user has access to. + + Following Least-Privilege Principle - the requestor should only be able to see the mcp servers that they have access to. 
+ """ + + mcp_server_ids: Set[str] = set() + mcp_servers = [] + + # Get the mcp servers for the key + if user.api_key: + token_mcp_servers = await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key) + mcp_server_ids.update(token_mcp_servers) + + # check for special team membership + if SpecialMCPServerName.all_team_servers in mcp_server_ids and user.team_id is not None: + team_mcp_servers = await get_mcp_servers_by_team(prisma_client, user.team_id) + mcp_server_ids.update(team_mcp_servers) + + if len(mcp_server_ids) > 0: + mcp_servers = await get_mcp_servers(prisma_client, mcp_server_ids) + + return mcp_servers + + +async def get_objectpermissions_for_mcp_server( + prisma_client: PrismaClient, mcp_server_id: str +) -> List[LiteLLM_ObjectPermissionTable]: + """ + Get all the object permissions records and the associated team and verficiationtoken records that have access to the mcp server + """ + object_permission_records = await prisma_client.db.litellm_objectpermissiontable.find_many( + where={ + "mcp_servers": {"has": mcp_server_id}, + }, + include={ + "teams": True, + "verification_tokens": True, + }, + ) + + return object_permission_records + + +async def get_virtualkeys_for_mcp_server(prisma_client: PrismaClient, server_id: str) -> List: + """ + Get all the virtual keys that have access to the mcp server + """ + virtual_keys = await prisma_client.db.litellm_verificationtoken.find_many( + where={ + "mcp_servers": {"has": server_id}, + }, + ) + + if virtual_keys is None: + return [] + return virtual_keys + + +async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str): + """ + Remove the mcp server from the team + """ + pass + + +async def delete_mcp_server_from_virtualkey(): + """ + Remove the mcp server from the virtual key + """ + pass + + +async def delete_mcp_server(prisma_client: PrismaClient, server_id: str) -> Optional[LiteLLM_MCPServerTable]: + """ + Delete the mcp server from the db by server_id + + Returns the 
deleted mcp server record if it exists, otherwise None + """ + deleted_server = await prisma_client.db.litellm_mcpservertable.delete( + where={ + "server_id": server_id, + }, + ) + return deleted_server + + +async def create_mcp_server(prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str) -> LiteLLM_MCPServerTable: + """ + Create a new mcp server record in the db + """ + if data.server_id is None: + data.server_id = str(uuid.uuid4()) + + mcp_server_record = await prisma_client.db.litellm_mcpservertable.create( + data={ + **data.model_dump(), + 'created_by': touched_by, + 'updated_by': touched_by, + } + ) + return mcp_server_record + + +async def update_mcp_server(prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str) -> LiteLLM_MCPServerTable: + """ + Update a new mcp server record in the db + """ + mcp_server_record = await prisma_client.db.litellm_mcpservertable.update( + where={ + 'server_id': data.server_id, + }, + data={ + **data.model_dump(), + 'created_by': touched_by, + 'updated_by': touched_by, + } + ) + return mcp_server_record \ No newline at end of file diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 9becb8075843..77cb889a1c0b 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -9,26 +9,32 @@ import asyncio import json from typing import Any, Dict, List, Optional +import uuid from mcp import ClientSession from mcp.client.sse import sse_client -from mcp.types import Tool as MCPTool +from mcp.types import CallToolResult, Tool as MCPTool +from prisma.models import LiteLLM_MCPServerTable from litellm._logging import verbose_logger -from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer - +from litellm.proxy._types import MCPSpecVersion, MCPTransport +from litellm.types.mcp_server.mcp_server_manager import 
MCPInfo, MCPServer class MCPServerManager: def __init__(self): - self.mcp_servers: List[MCPSSEServer] = [] + self.registry: Dict[str, MCPServer] = {} + self.config_mcp_servers: Dict[str, MCPServer] = {} """ eg. [ - { + "server-1": { "name": "zapier_mcp_server", "url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse" + "transport": "sse", + "auth_type": "api_key", + "spec_version": "2025-03-26" }, - { + "uuid-2": { "name": "google_drive_mcp_server", "url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse" } @@ -41,28 +47,69 @@ def __init__(self): "gmail_send_email": "zapier_mcp_server", } """ + def get_registry(self) -> Dict[str, MCPServer]: + """ + Get the registered MCP Servers from the registry and union with the config MCP Servers + """ + return self.config_mcp_servers | self.registry def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]): """ Load the MCP Servers from the config """ + verbose_logger.debug("Loading MCP Servers from config-----") for server_name, server_config in mcp_servers_config.items(): _mcp_info: dict = server_config.get("mcp_info", None) or {} mcp_info = MCPInfo(**_mcp_info) mcp_info["server_name"] = server_name - self.mcp_servers.append( - MCPSSEServer( - name=server_name, - url=server_config["url"], - mcp_info=mcp_info, - ) + mcp_info["description"] = server_config.get("description", None) + new_server = MCPServer( + name=server_name, + url=server_config["url"], + # TODO: utility fn the default values + transport=server_config.get("transport", MCPTransport.sse), + spec_version=server_config.get("spec_version", MCPSpecVersion.mar_2025), + auth_type=server_config.get("auth_type", None), + mcp_info=mcp_info, ) + server_id = str(uuid.uuid4()) + self.config_mcp_servers[server_id] = new_server verbose_logger.debug( - f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}" + f"Loaded MCP Servers: {json.dumps(self.config_mcp_servers, indent=4, default=str)}" ) 
self.initialize_tool_name_to_mcp_server_name_mapping() + def remove_server(self, mcp_server: LiteLLM_MCPServerTable): + """ + Remove a server from the registry + """ + if mcp_server.alias in self.get_registry(): + del self.registry[mcp_server.alias] + verbose_logger.debug(f"Removed MCP Server: {mcp_server.alias}") + elif mcp_server.server_id in self.get_registry(): + del self.registry[mcp_server.server_id] + verbose_logger.debug(f"Removed MCP Server: {mcp_server.server_id}") + else: + verbose_logger.warning(f"Server ID {mcp_server.server_id} not found in registry") + + def add_update_server(self, mcp_server: LiteLLM_MCPServerTable): + if mcp_server.server_id not in self.get_registry(): + new_server = MCPServer( + name=mcp_server.alias or mcp_server.server_id, + url=mcp_server.url, + transport=mcp_server.transport, + spec_version=mcp_server.spec_version, + auth_type=mcp_server.auth_type, + mcp_info=MCPInfo( + server_name=mcp_server.alias or mcp_server.server_id, + description=mcp_server.description, + ), + ) + self.registry[mcp_server.server_id] = new_server + verbose_logger.debug(f"Added MCP Server: {mcp_server.alias or mcp_server.server_id}") + + async def list_tools(self) -> List[MCPTool]: """ List all tools available across all MCP Servers. @@ -71,39 +118,48 @@ async def list_tools(self) -> List[MCPTool]: List[MCPTool]: Combined list of tools from all servers """ list_tools_result: List[MCPTool] = [] - verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS") + verbose_logger.debug("SERVER MANAGER LISTING TOOLS") - for server in self.mcp_servers: + for _, server in self.get_registry().items(): tools = await self._get_tools_from_server(server) list_tools_result.extend(tools) return list_tools_result - async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]: + async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]: """ Helper method to get tools from a single MCP server. 
Args: - server (MCPSSEServer): The server to query tools from + server (MCPServer): The server to query tools from Returns: List[MCPTool]: List of tools available on the server """ verbose_logger.debug(f"Connecting to url: {server.url}") - async with sse_client(url=server.url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - tools_result = await session.list_tools() - verbose_logger.debug(f"Tools from {server.name}: {tools_result}") - - # Update tool to server mapping - for tool in tools_result.tools: - self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name - - return tools_result.tools - + verbose_logger.info("_get_tools_from_server...") + # send transport to connect to the server + if server.transport is None or server.transport == MCPTransport.sse: + async with sse_client(url=server.url) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + tools_result = await session.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools_result}") + + # Update tool to server mapping + for tool in tools_result.tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + + return tools_result.tools + elif server.transport == MCPTransport.http: + # TODO: implement http transport + return [] + else: + # TODO: throw error on transport found or skip + return [] + def initialize_tool_name_to_mcp_server_name_mapping(self): """ On startup, initialize the tool name to MCP server name mapping @@ -122,7 +178,7 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self): """ Call list_tools for each server and update the tool name to MCP server name mapping """ - for server in self.mcp_servers: + for server in self.get_registry().values(): tools = await self._get_tools_from_server(server) for tool in tools: self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name @@ -134,17 +190,23 @@ async def call_tool(self, name: str, 
arguments: Dict[str, Any]): mcp_server = self._get_mcp_server_from_tool_name(name) if mcp_server is None: raise ValueError(f"Tool {name} not found") - async with sse_client(url=mcp_server.url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - return await session.call_tool(name, arguments) - - def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]: + elif mcp_server.transport is None or mcp_server.transport == MCPTransport.sse: + async with sse_client(url=mcp_server.url) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + return await session.call_tool(name, arguments) + elif mcp_server.transport == MCPTransport.http: + # TODO: implement http transport + raise NotImplementedError("HTTP transport is not implemented yet") + else: + return CallToolResult(content = [], isError=True) + + def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]: """ Get the MCP Server from the tool name """ if tool_name in self.tool_name_to_mcp_server_name_mapping: - for server in self.mcp_servers: + for server in self.get_registry().values(): if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]: return server return None diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index fe1eccb048f8..b6a9122bb066 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -8,31 +8,48 @@ from anyio import BrokenResourceError from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse +from prisma.models import LiteLLM_MCPServerTable from pydantic import ConfigDict, ValidationError from litellm._logging import verbose_logger from litellm.constants import MCP_TOOL_NAME_PREFIX from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from 
litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers, get_all_mcp_servers_for_user from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.proxy.management_endpoints.mcp_management_endpoints import get_prisma_client_or_throw from litellm.types.mcp_server.mcp_server_manager import MCPInfo from litellm.types.utils import StandardLoggingMCPToolCall from litellm.utils import client +router = APIRouter( + prefix="/mcp", + tags=["mcp"], +) + # Check if MCP is available # "mcp" requires python 3.10 or higher, but several litellm users use python 3.8 # We're making this conditional import to avoid breaking users who use python 3.8. +# TODO: Make this a util function for litellm client usage +MCP_AVAILABLE: bool = True try: from mcp.server import Server - - MCP_AVAILABLE = True except ImportError as e: verbose_logger.debug(f"MCP module not found: {e}") MCP_AVAILABLE = False - router = APIRouter( - prefix="/mcp", - tags=["mcp"], - ) + + +# Routes +@router.get( + "/enabled", + description="Returns if the MCP server is enabled", +) +def get_mcp_server_enabled() -> Dict[str, bool]: + """ + Returns if the MCP server is enabled + """ + return {"enabled": MCP_AVAILABLE} if MCP_AVAILABLE: @@ -63,10 +80,6 @@ class ListMCPToolsRestAPIResponseObject(MCPTool): ######################################################## ############ Initialize the MCP Server ################# ######################################################## - router = APIRouter( - prefix="/mcp", - tags=["mcp"], - ) server: Server = Server("litellm-mcp-server") sse: SseServerTransport = SseServerTransport("/mcp/sse/messages") @@ -93,9 +106,7 @@ async def _list_mcp_tools() -> List[MCPTool]: inputSchema=tool.input_schema, ) ) - verbose_logger.debug( - "GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools() - ) + verbose_logger.debug("GLOBAL MCP 
TOOLS: %s", global_mcp_tool_registry.list_tools()) sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools() verbose_logger.debug("SSE TOOLS: %s", sse_tools) if sse_tools is not None: @@ -134,28 +145,20 @@ async def call_mcp_tool( Call a specific tool with the provided arguments """ if arguments is None: - raise HTTPException( - status_code=400, detail="Request arguments are required" - ) + raise HTTPException(status_code=400, detail="Request arguments are required") - standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = ( - _get_standard_logging_mcp_tool_call( - name=name, - arguments=arguments, - ) - ) - litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get( - "litellm_logging_obj", None + standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = _get_standard_logging_mcp_tool_call( + name=name, + arguments=arguments, ) + litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj", None) if litellm_logging_obj: - litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( - standard_logging_mcp_tool_call - ) + litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = standard_logging_mcp_tool_call litellm_logging_obj.model_call_details["model"] = ( f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}" ) - litellm_logging_obj.model_call_details["custom_llm_provider"] = ( - standard_logging_mcp_tool_call.get("mcp_server_name") + litellm_logging_obj.model_call_details["custom_llm_provider"] = standard_logging_mcp_tool_call.get( + "mcp_server_name" ) # Try managed server tool first @@ -235,7 +238,9 @@ async def handle_messages(request: Request): ############ MCP Server REST API Routes ################# ######################################################## @router.get("/tools/list", dependencies=[Depends(user_api_key_auth)]) - async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]: + async def list_tool_rest_api( + user_api_key_dict: UserAPIKeyAuth 
= Depends(user_api_key_auth), + ) -> List[ListMCPToolsRestAPIResponseObject]: """ List all available tools with information about the server they belong to. @@ -262,8 +267,25 @@ async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]: } ] """ + # perform authz check to filter the mcp servers user has access to + prisma_client = get_prisma_client_or_throw("Database not connected. Connect a database to your proxy") + + db_mcp_servers: List[LiteLLM_MCPServerTable] = [] + + # Check the db for the mcp server list TODO: reuse same logic as in the mcp_endpoint + if _user_has_admin_view(user_api_key_dict): + db_mcp_servers = await get_all_mcp_servers(prisma_client) + else: + db_mcp_servers = await get_all_mcp_servers_for_user( + prisma_client, + user_api_key_dict, + ) + # ensure the global_mcp_server_manager is up to date with the db + for server in db_mcp_servers: + global_mcp_server_manager.add_update_server(server) + list_tools_result: List[ListMCPToolsRestAPIResponseObject] = [] - for server in global_mcp_server_manager.mcp_servers: + for server in global_mcp_server_manager.get_registry().values(): try: tools = await global_mcp_server_manager._get_tools_from_server(server) for tool in tools: @@ -306,4 +328,4 @@ async def call_tool_rest_api( notification_options=NotificationOptions(), experimental_capabilities={}, ), - ) + ) \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d66cdd4ad6b4..74d90cb979c2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -816,6 +816,54 @@ def check_potential_json_str(cls, values): pass return values +# MCP Types +class SpecialMCPServerName(str, enum.Enum): + all_team_servers = "all-team-mcpservers" + all_proxy_servers = "all-proxy-mcpservers" + + +class MCPTransport(str, enum.Enum): + sse = "sse" + http = "http" + + +class MCPSpecVersion(str, enum.Enum): + nov_2024 = "2024-11-05" + mar_2025 = "2025-03-26" + + +class MCPAuth(str, enum.Enum): + none = "none" + 
api_key = "api_key" + bearer_token = "bearer_token" + basic = "basic" + + +# MCP Literals +MCPTransportType = Literal[MCPTransport.sse, MCPTransport.http] +MCPSpecVersionType = Literal[MCPSpecVersion.nov_2024, MCPSpecVersion.mar_2025] +MCPAuthType = Optional[Literal[MCPAuth.none, MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic]] + + +# MCP Proxy Request Types +class NewMCPServerRequest(LiteLLMPydanticObjectBase): + server_id: Optional[str] = None + alias: Optional[str] = None + description: Optional[str] = None + transport: MCPTransportType = MCPTransport.sse + spec_version: MCPSpecVersionType = MCPSpecVersion.mar_2025 + auth_type: Optional[MCPAuthType] = None + url: str + +class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): + server_id: str + alias: Optional[str] = None + description: Optional[str] = None + transport: MCPTransportType = MCPTransport.sse + spec_version: MCPSpecVersionType = MCPSpecVersion.mar_2025 + auth_type: Optional[MCPAuthType] = None + url: str + class NewUserRequest(GenerateRequestBase): max_budget: Optional[float] = None diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 93bcb70a9005..5fae38cee65a 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -149,6 +149,8 @@ def non_proxy_admin_allowed_routes_check( route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value ): # routes that manage their own allowed/disallowed logic pass + elif route.startswith("/v1/mcp/"): + pass # authN/authZ handled by api itself else: user_role = "unknown" user_id = "unknown" diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py new file mode 100644 index 000000000000..7a55bedd94f3 --- /dev/null +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -0,0 +1,325 @@ +""" +1. Allow proxy admin to perform create, update, and delete operations on MCP servers in the db. 
+2. Allows users to view the mcp servers they have access to. + +Endpoints here: +- GET `/v1/mcp/server` - Returns all of the configured mcp servers in the db filtered by requestor's access +- GET `/v1/mcp/server/{server_id}` - Returns the the specific mcp server in the db given `server_id` filtered by requestor's access +- GET `/v1/mcp/server/{server_id}/tools` - Get all the tools from the mcp server specified by the `server_id` +- POST `/v1/mcp/server` - Add a new external mcp server. +- PUT `/v1/mcp/server` - Edits an existing mcp server. +- DELETE `/v1/mcp/server/{server_id}` - Deletes the mcp server given `server_id`. +""" + +from typing import Iterable, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Header, Response, status +from fastapi.responses import JSONResponse +from prisma.models import LiteLLM_MCPServerTable + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.constants import ( + LITELLM_PROXY_ADMIN_NAME, +) +from litellm.proxy._experimental.mcp_server.db import ( + create_mcp_server, + update_mcp_server, + delete_mcp_server, + get_all_mcp_servers, + get_all_mcp_servers_for_user, + get_mcp_server, +) +from litellm.proxy._types import ( + LitellmUserRoles, + NewMCPServerRequest, + UpdateMCPServerRequest, + SpecialMCPServerName, + UserAPIKeyAuth, +) +from litellm.proxy._experimental.mcp_server.mcp_server_manager import global_mcp_server_manager +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.proxy.management_helpers.utils import management_endpoint_wrapper + +router = APIRouter(prefix="/v1/mcp", tags=["mcp"]) + + +def get_prisma_client_or_throw(message: str): + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": message}, + ) + return prisma_client + + +def 
does_mcp_server_exist(mcp_server_records: Iterable[LiteLLM_MCPServerTable], mcp_server_id: str) -> bool: + """ + Check if the mcp server with the given id exists in the iterable of mcp servers + """ + for mcp_server_record in mcp_server_records: + if mcp_server_record.server_id == mcp_server_id: + return True + return False + + +## FastAPI Routes +@router.get( + "/server", + description="Returns the mcp server list", + dependencies=[Depends(user_api_key_auth)], + response_model=List[LiteLLM_MCPServerTable], +) +async def fetch_all_mcp_servers( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get all of the configured mcp servers for the user in the db + ``` + curl --location 'http://localhost:4000/v1/mcp/server' \ + --header 'Authorization: Bearer your_api_key_here' + ``` + """ + prisma_client = get_prisma_client_or_throw("Database not connected. Connect a database to your proxy") + + # perform authz check to filter the mcp servers user has access to + if _user_has_admin_view(user_api_key_dict): + return await get_all_mcp_servers(prisma_client) + + # Find all mcp servers the user has access to + return await get_all_mcp_servers_for_user(prisma_client, user_api_key_dict) + + +@router.get( + "/server/{server_id}", + description="Returns the mcp server info", + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_MCPServerTable, +) +async def fetch_mcp_server( + server_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get the info on the mcp server specified by the `server_id` + Parameters: + - server_id: str - Required. The unique identifier of the mcp server to get info on. + ``` + curl --location 'http://localhost:4000/v1/mcp/server/server_id' \ + --header 'Authorization: Bearer your_api_key_here' + ``` + """ + prisma_client = get_prisma_client_or_throw("Database not connected. 
Connect a database to your proxy") + + # check to see if server exists for all users + mcp_server = await get_mcp_server(prisma_client, server_id) + if mcp_server is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"MCP Server with id {server_id} not found"}, + ) + + # Implement authz restriction from requested user + if _user_has_admin_view(user_api_key_dict): + return mcp_server + + # Perform authz check to filter the mcp servers user has access to + mcp_server_records = await get_all_mcp_servers_for_user(prisma_client, user_api_key_dict) + exists = does_mcp_server_exist(mcp_server_records, server_id) + + if exists: + global_mcp_server_manager.add_update_server(mcp_server) + return mcp_server + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": f"User does not have permission to view mcp server with id {server_id}. You can only view mcp servers that you have access to." + }, + ) + +@router.post( + "/server", + description="Allows creation of mcp servers", + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_MCPServerTable, + status_code=status.HTTP_201_CREATED, +) +@management_endpoint_wrapper +async def add_mcp_server( + payload: NewMCPServerRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Allow users to add a new external mcp server. + """ + prisma_client = get_prisma_client_or_throw("Database not connected. 
Connect a database to your proxy") + + # AuthZ - restrict only proxy admins to create mcp servers + if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "User does not have permission to create mcp servers. You can only create mcp servers if you are a PROXY_ADMIN." + }, + ) + elif payload.server_id is not None: + # fail if the mcp server with id already exists + mcp_server = await get_mcp_server(prisma_client, payload.server_id) + if mcp_server is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"MCP Server with id {payload.server_id} already exists. Cannot create another."}, + ) + elif ( + SpecialMCPServerName.all_team_servers == payload.server_id + or SpecialMCPServerName.all_proxy_servers == payload.server_id + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"MCP Server with id {payload.server_id} is special and cannot be used."}, + ) + + # TODO: audit log for create + + # Attempt to create the mcp server + try: + new_mcp_server = await create_mcp_server( + prisma_client, payload, touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME + ) + global_mcp_server_manager.add_update_server(new_mcp_server) + except Exception as e: + verbose_proxy_logger.exception(f"Error creating mcp server: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": f"Error creating mcp server: {str(e)}"}, + ) + return new_mcp_server + + +@router.delete( + "/server/{server_id}", + description="Allows deleting mcp serves in the db", + dependencies=[Depends(user_api_key_auth)], + response_class=JSONResponse, + status_code=status.HTTP_202_ACCEPTED, +) +@management_endpoint_wrapper +async def remove_mcp_server( + server_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + 
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Delete MCP Server from db and associated MCP related server entities. + + Parameters: + - server_id: str - Required. The unique identifier of the mcp server to delete. + ``` + curl -X "DELETE" --location 'http://localhost:4000/v1/mcp/server/server_id' \ + --header 'Authorization: Bearer your_api_key_here' + ``` + """ + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + # Authz - restrict only admins to delete mcp servers + if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "Call not allowed to delete MCP server. User is not a proxy admin. route={}".format( + "DELETE /v1/mcp/server" + ) + }, + ) + + # try to delete the mcp server + mcp_server_record_deleted = await delete_mcp_server(prisma_client, server_id) + + if mcp_server_record_deleted is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"MCP Server not found, passed server_id={server_id}"}, + ) + global_mcp_server_manager.remove_server(mcp_server_record_deleted) + + # TODO: Enterprise: Finish audit log trail + if litellm.store_audit_logs: + pass + + # TODO: Delete from virtual keys + + # TODO: Delete from teams + + # Update from global mcp store + + return Response(status_code=status.HTTP_202_ACCEPTED) + +@router.put( + "/server", + description="Allows deleting mcp serves in the db", + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_MCPServerTable, + status_code=status.HTTP_202_ACCEPTED, +) +@management_endpoint_wrapper +async def edit_mcp_server( + payload: UpdateMCPServerRequest, + user_api_key_dict: UserAPIKeyAuth = 
Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Updates the MCP Server in the db. + + Parameters: + - payload: UpdateMCPServerRequest - Required. The updated mcp server data. + ``` + curl -X "PUT" --location 'http://localhost:4000/v1/mcp/server' \ + --header 'Authorization: Bearer your_api_key_here' + ``` + """ + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + # Authz - restrict only admins to delete mcp servers + if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "Call not allowed to update MCP server. User is not a proxy admin. route={}".format( + "PUT /v1/mcp/server" + ) + }, + ) + + # try to update the mcp server + mcp_server_record_updated = await update_mcp_server(prisma_client, payload, touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME) + + if mcp_server_record_updated is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"MCP Server not found, passed server_id={payload.server_id}"}, + ) + global_mcp_server_manager.add_update_server(mcp_server_record_updated) + + # TODO: Enterprise: Finish audit log trail + if litellm.store_audit_logs: + pass + + return mcp_server_record_updated \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6d689d258c7d..c1a1cee42491 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -248,6 +248,7 @@ def generate_feedback_box(): from litellm.proxy.management_endpoints.model_management_endpoints import ( router as model_management_router, ) 
+from litellm.proxy.management_endpoints.mcp_management_endpoints import router as mcp_management_router from litellm.proxy.management_endpoints.organization_endpoints import ( router as organization_router, ) @@ -8178,6 +8179,7 @@ async def get_routes(): app.include_router(credential_router) app.include_router(llm_passthrough_router) app.include_router(mcp_router) +app.include_router(mcp_management_router) app.include_router(anthropic_router) app.include_router(langfuse_router) app.include_router(pass_through_router) diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index aecd11aa1aeb..c35d4e533b38 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -3,14 +3,21 @@ from pydantic import BaseModel, ConfigDict from typing_extensions import TypedDict +from litellm.proxy._types import MCPAuthType, MCPTransportType, MCPSpecVersionType + class MCPInfo(TypedDict, total=False): server_name: str + description: Optional[str] logo_url: Optional[str] -class MCPSSEServer(BaseModel): +class MCPServer(BaseModel): name: str url: str + # TODO: alter the types to be the Literal explicit + transport: MCPTransportType + spec_version: MCPSpecVersionType + auth_type: Optional[MCPAuthType] = None mcp_info: Optional[MCPInfo] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/litellm/proxy/experimental/mcp_server/test_db.py b/tests/litellm/proxy/experimental/mcp_server/test_db.py new file mode 100644 index 000000000000..49fbc874e32a --- /dev/null +++ b/tests/litellm/proxy/experimental/mcp_server/test_db.py @@ -0,0 +1,14 @@ +import json +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.proxy._experimental.mcp_server.db import get_mcp_servers_by_team + +def test_fetch_mcp_servers_by_team(): + assert True == True \ No newline at end of 
file diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 2cf919387153..9659ab329776 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -9,7 +9,7 @@ from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, - MCPSSEServer, + MCPServer, ) diff --git a/tests/store_model_in_db_tests/test_mcp_servers.py b/tests/store_model_in_db_tests/test_mcp_servers.py new file mode 100644 index 000000000000..18d4d4d4ba03 --- /dev/null +++ b/tests/store_model_in_db_tests/test_mcp_servers.py @@ -0,0 +1,240 @@ +from datetime import datetime +from typing import List, Optional +from prisma.models import LiteLLM_MCPServerTable +import pytest +import uuid +from httpx import AsyncClient +import uuid +import os + +from starlette import status + +from litellm.constants import LITELLM_PROXY_ADMIN_NAME +from litellm.proxy._types import MCPAuth, MCPSpecVersion, MCPSpecVersionType, MCPTransportType, MCPTransport, NewMCPServerRequest +from litellm.proxy.management_endpoints.mcp_management_endpoints import does_mcp_server_exist + +TEST_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-1234") +PROXY_BASE_URL = os.getenv("PROXY_BASE_URL", "http://localhost:4000") + +def generate_mcpserver_record(url: Optional[str] = None, + transport: Optional[MCPTransportType] = None, + spec_version: Optional[MCPSpecVersionType] = None) -> LiteLLM_MCPServerTable: + """ + Generate a mock record for testing. 
+ """ + now = datetime.now() + + return LiteLLM_MCPServerTable( + server_id=str(uuid.uuid4()),alias="Test Server",url=url or "http://localhost.com:8080/mcp",transport=transport or MCPTransport.sse,spec_version=spec_version or MCPSpecVersion.mar_2025,created_at=now,updated_at=now, + ) + +# Cheers SO +def is_valid_uuid(val): + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + +def generate_mcpserver_create_request( + server_id: Optional[str] = None, + url: Optional[str] = None, + transport: Optional[MCPTransportType] = None, + spec_version: Optional[MCPSpecVersionType] = None) -> NewMCPServerRequest: + """ + Generate a mock create request for testing. + """ + now = datetime.now() + + return NewMCPServerRequest(server_id=server_id, + alias="Test Server",url=url or "http://localhost.com:8080/mcp",transport=transport or MCPTransport.sse,spec_version=spec_version or MCPSpecVersion.mar_2025, + ) + +def get_http_client(): + """ + Create an HTTP client for making requests to the proxy server. + """ + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + # headers = {"Authorization": f"x-litellm-api-key {TEST_MASTER_KEY}"} + return AsyncClient(base_url=PROXY_BASE_URL), headers + +def assert_mcp_server_record_same(mcp_server: NewMCPServerRequest, resp: LiteLLM_MCPServerTable): + """ + Assert that the mcp server record is created correctly. 
+ """ + if mcp_server.server_id is not None: + assert resp.server_id == mcp_server.server_id + else: + assert is_valid_uuid(resp.server_id) + assert resp.alias == mcp_server.alias + assert resp.url == mcp_server.url + assert resp.description == mcp_server.description + assert resp.transport == mcp_server.transport + assert resp.spec_version == mcp_server.spec_version + assert resp.auth_type == mcp_server.auth_type + assert resp.created_at is not None + assert resp.updated_at is not None + assert resp.created_by == LITELLM_PROXY_ADMIN_NAME + assert resp.updated_by == LITELLM_PROXY_ADMIN_NAME + + +def test_does_mcp_server_exist(): + """ + Unit Test if the MCP server exists in the list. + """ + mcp_server_records: List[LiteLLM_MCPServerTable] = [generate_mcpserver_record(), generate_mcpserver_record()] + # test all records are found + for record in mcp_server_records: + assert does_mcp_server_exist(mcp_server_records, record.server_id) + + # test record not found + not_found_record = str(uuid.uuid4()) + assert False == does_mcp_server_exist(mcp_server_records, not_found_record) + +@pytest.mark.asyncio +async def test_create_get_delete(): + """ + Integration Test mcp servers can be created and returned correctly. + 1. Create a new mcp server with server id + 2. Create another mcp server without server id + 2.1 Verify duplicate mcp server (server id) creation fails + 3. Verify first server has matching server id and second server has a new server id + 4. Verify both servers are in the full mcp server list + 5. Verify first server can be retrieved by server id + 6. Delete both mcp servers + 7. Verify both servers are no longer in the full mcp server list + 8. 
Verify both servers cannot be retrieved by server id + """ + # client, headers = AsyncClient(base_url=PROXY_BASE_URL), headers + client, headers = get_http_client() + + first_server_id = str(uuid.uuid4()) + first_server = generate_mcpserver_create_request(server_id=first_server_id) + + # Add new mcp server with server id + first_create_response = await client.post( + "/v1/mcp/server", + json=first_server.json(), + headers=headers, + ) + + # Validate that the response is as expected and the server is created + assert status.HTTP_201_CREATED == first_create_response.status_code + first_resp = LiteLLM_MCPServerTable(**first_create_response.json()) + assert_mcp_server_record_same(first_server, first_resp) + + # Create second mcp server without server id + second_server = generate_mcpserver_create_request() + second_create_response = await client.post( + "/v1/mcp/server", + json=second_server.json(), + headers=headers, + ) + assert status.HTTP_201_CREATED == second_create_response.status_code + second_resp = LiteLLM_MCPServerTable(**second_create_response.json()) + assert_mcp_server_record_same(second_server, second_resp) + + # Try to create a duplicate mcp server + duplicate_create_response = await client.post( + "/v1/mcp/server", + json=first_server.json(), + headers=headers, + ) + assert status.HTTP_400_BAD_REQUEST == duplicate_create_response.status_code + + # Validate that the servers are in the full mcp server list + get_all_mcp_servers_response = await client.get( + "/v1/mcp/server", + headers=headers, + ) + assert status.HTTP_200_OK == get_all_mcp_servers_response.status_code + mcp_servers = [ + LiteLLM_MCPServerTable(**record) for record in get_all_mcp_servers_response.json() + ] + assert len(mcp_servers) >= 2 + assert does_mcp_server_exist(mcp_servers, first_resp.server_id) + assert does_mcp_server_exist(mcp_servers, second_resp.server_id) + + # Validate that the first server can be retrieved by server id + get_mcp_server_response = await client.get( + 
f"/v1/mcp/server/{first_resp.server_id}", + headers=headers, + ) + assert status.HTTP_200_OK == get_mcp_server_response.status_code + resp = LiteLLM_MCPServerTable(**get_mcp_server_response.json()) + assert_mcp_server_record_same(first_server, resp) + + # Delete the mcp servers + delete_response = await client.delete( + f"/v1/mcp/server/{first_resp.server_id}", + headers=headers, + ) + assert status.HTTP_202_ACCEPTED == delete_response.status_code + delete_response = await client.delete( + f"/v1/mcp/server/{second_resp.server_id}", + headers=headers, + ) + assert status.HTTP_202_ACCEPTED == delete_response.status_code + + # Validate that the servers are no longer in the full list + get_all_mcp_servers_response = await client.get( + "/v1/mcp/server", + headers=headers, + ) + assert status.HTTP_200_OK == get_all_mcp_servers_response.status_code + mcp_servers = [ + LiteLLM_MCPServerTable(**record) for record in get_all_mcp_servers_response.json() + ] + assert not does_mcp_server_exist(mcp_servers, first_resp.server_id) + assert not does_mcp_server_exist(mcp_servers, second_resp.server_id) + + # Validate that both servers cannot be retrieved by server id + for server_id in [first_resp.server_id, second_resp.server_id]: + get_mcp_server_response = await client.get( + f"/v1/mcp/server/{server_id}", + headers=headers, + ) + assert status.HTTP_404_NOT_FOUND == get_mcp_server_response.status_code + +@pytest.mark.asyncio +async def test_edit(): + """ + Integration Test mcp servers can be created and edited correctly. + 1. Create a new mcp server + 2. Edit the server id + 3. 
Verify the mcp server's data is updated + """ + # client, headers = AsyncClient(base_url=PROXY_BASE_URL), headers + client, headers = get_http_client() + + mcp_server_request = generate_mcpserver_create_request() + + # Add new mcp server with server id + first_create_response = await client.post( + "/v1/mcp/server", + json=mcp_server_request.json(), + headers=headers, + ) + + # Validate that the response is as expected and the server is created + assert status.HTTP_201_CREATED == first_create_response.status_code + mcp_server_response = LiteLLM_MCPServerTable(**first_create_response.json()) + assert_mcp_server_record_same(mcp_server_request, mcp_server_response) + + # Update the mcp server + mcp_server_request.server_id = mcp_server_response.server_id + mcp_server_request.spec_version = MCPSpecVersion.nov_2024 + mcp_server_request.transport = MCPTransport.http + mcp_server_request.description = "Some updated description" + mcp_server_request.url = "http://localhost.com:4040/mcp" + mcp_server_request.auth_type = MCPAuth.basic + + # Try to edit the mcp server + updated_response = await client.put( + "/v1/mcp/server", + json=mcp_server_request.json(), + headers=headers, + ) + assert status.HTTP_202_ACCEPTED == updated_response.status_code + updated_server = LiteLLM_MCPServerTable(**updated_response.json()) + assert_mcp_server_record_same(mcp_server_request, updated_server) diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index eb8a7a8df692..a6123eaf7fcd 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -32,7 +32,7 @@ import GuardrailsPanel from "@/components/guardrails"; import TransformRequestPanel from "@/components/transform_request"; import { fetchUserModels } from "@/components/create_key_button"; import { fetchTeams } from "@/components/common_components/fetch_teams"; -import MCPToolsViewer from "@/components/mcp_tools"; +import { MCPToolsViewer, MCPServers } from 
"@/components/mcp_tools"; import TagManagement from "@/components/tag_management"; import VectorStoreManagement from "@/components/vector_store_management"; import { UiLoadingSpinner } from "@/components/ui/ui-loading-spinner"; @@ -362,7 +362,7 @@ export default function CreateKeyPage() { ): page == "transform-request" ? ( - ): page == "general-settings" ? ( + ) : page == "general-settings" ? ( - ) : page == "mcp-tools" ? ( - - ) : - ( + ) : ( = ({ { key: "9", page: "caching", label: "Caching", icon: , roles: all_admin_roles }, { key: "10", page: "budgets", label: "Budgets", icon: , roles: all_admin_roles }, { key: "20", page: "transform-request", label: "API Playground", icon: , roles: [...all_admin_roles, ...internalUserRoles] }, - { key: "18", page: "mcp-tools", label: "MCP Tools", icon: , roles: all_admin_roles }, + { key: "18", page: "mcp-servers", label: "MCP Servers", icon: , roles: all_admin_roles }, { key: "19", page: "tag-management", label: "Tag Management", icon: , roles: all_admin_roles }, { key: "21", page: "vector-stores", label: "Vector Stores", icon: , roles: all_admin_roles }, { key: "4", page: "usage", label: "Old Usage", icon: }, diff --git a/ui/litellm-dashboard/src/components/mcp_tools/index.tsx b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx index c403b4966801..cfd767ec1030 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/index.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx @@ -4,6 +4,7 @@ import { DataTable } from '../view_logs/table'; import { columns, ToolTestPanel } from './columns'; import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from './types'; import { listMCPTools, callMCPTool } from '../networking'; +import MCPServers from './mcp_servers'; // Wrapper to handle the type mismatch between MCPTool and DataTable's expected type function DataTableWrapper({ @@ -32,11 +33,11 @@ function DataTableWrapper({ ); } -export default function MCPToolsViewer({ +const MCPToolsViewer = ({ accessToken, 
userRole, userID, -}: MCPToolsViewerProps) { +}: MCPToolsViewerProps) => { const [searchTerm, setSearchTerm] = useState(''); const [selectedTool, setSelectedTool] = useState(null); const [toolResult, setToolResult] = useState(null); @@ -92,7 +93,7 @@ export default function MCPToolsViewer({ const searchLower = searchTerm.toLowerCase(); return ( tool.name.toLowerCase().includes(searchLower) || - tool.description.toLowerCase().includes(searchLower) || + (tool.description != null && tool.description.toLowerCase().includes(searchLower)) || tool.mcp_info.server_name.toLowerCase().includes(searchLower) ); }); @@ -171,4 +172,6 @@ export default function MCPToolsViewer({ )} ); -} \ No newline at end of file +} + +export { MCPToolsViewer, MCPServers }; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx new file mode 100644 index 000000000000..b6e5c89d4044 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx @@ -0,0 +1,107 @@ +import React from "react"; + +import { + Title, + Card, + Button, + Text, + Grid, + TabGroup, + TabList, + TabPanel, + TabPanels, + Tab, +} from "@tremor/react"; + +import { MCPServer, handleTransport, handleAuth } from "./types"; +// TODO: Move Tools viewer from index file +import { MCPToolsViewer } from "."; + +interface MCPServerViewProps { + mcpServer: MCPServer; + onBack: () => void; + isProxyAdmin: boolean; + isEditing: boolean; + accessToken: string | null; + userRole: string | null; + userID: string | null; +} + +export const MCPServerView: React.FC = ({ + mcpServer, + onBack, + isEditing, + isProxyAdmin, + accessToken, + userRole, + userID, +}) => { + return ( +
+
+
+ + {mcpServer.alias} + {mcpServer.server_id} +
+
+ + {/* TODO: magic number for index */} + + + {[ + Overview, + MCP Tools, + ...(isProxyAdmin ? [Settings] : []), + ]} + + + + {/* Overview Panel */} + + + + Transport +
+ {handleTransport(mcpServer.transport)} +
+
+ + + Auth Type +
+ {handleAuth(mcpServer.auth_type)} +
+
+ + + Host Url +
{mcpServer.url}
+
+
+
+ + {/* Tool Panel */} + + + + + {/* Settings Panel */} + + +
+ Editing MCP Servers coming soon! +
+
+
+
+
+
+ ); +}; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx new file mode 100644 index 000000000000..6b5a66c11692 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx @@ -0,0 +1,449 @@ +import React, { useState } from "react"; +import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; + +import { PencilAltIcon, TrashIcon } from "@heroicons/react/outline"; + +import { + Modal, + Tooltip, + Form, + Select, + message, + Button as AntdButton, +} from "antd"; +import { InfoCircleOutlined } from "@ant-design/icons"; + +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeaderCell, + TableRow, + Icon, + Button, + Grid, + Col, + Title, + TextInput, +} from "@tremor/react"; + +import { + createMCPServer, + deleteMCPServer, + fetchMCPServers, +} from "../networking"; +import { + MCPServer, + MCPServerProps, + handleAuth, + handleTransport, +} from "./types"; +import { isAdminRole } from "@/utils/roles"; +import { MCPServerView } from "./mcp_server_view"; + +const displayFriendlyId = (id: string) => { + return `${id.slice(0, 7)}...`; +}; + +interface CreateMCPServerProps { + userRole: string; + accessToken: string | null; + onCreateSuccess: (newMcpServer: MCPServer) => void; +} + +const CreateMCPServer: React.FC = ({ + userRole, + accessToken, + onCreateSuccess, +}) => { + const [form] = Form.useForm(); + + const handleCreate = async (formValues: Record) => { + try { + console.log(`formValues: ${JSON.stringify(formValues)}`); + + if (accessToken != null) { + const response: MCPServer = await createMCPServer( + accessToken, + formValues + ); + + message.success("MCP Server created successfully"); + form.resetFields(); + setModalVisible(false); + onCreateSuccess(response); + } + } catch (error) { + message.error("Error creating the team: " + error, 20); + } + }; + + // state + const [isModalVisible, setModalVisible] = 
useState(false); + + // rendering + if (!isAdminRole(userRole)) { + return null; + } + + return ( +
+ + + setModalVisible(false)} + okButtonProps={{ style: { display: "none" } }} + > +
+ <> + + + + + + + + + + + + + + + + + MCP Version{" "} + + + + + } + name="spec_version" + rules={[ + { required: true, message: "Please enter a spec version" }, + ]} + > + + + +
+ Create MCP Server +
+
+
+
+ ); +}; + +interface DeleteModalProps { + isModalOpen: boolean; + title: string; + confirmDelete: () => void; + cancelDelete: () => void; +} + +const DeleteModal: React.FC = ({ + isModalOpen, + title, + confirmDelete, + cancelDelete, +}) => { + if (!isModalOpen) return null; + + return ( + + + {title} + +

Are you sure you want to delete this MCP Server?

+ +
+
+ ); +}; + +const MCPServers: React.FC = ({ + accessToken, + userRole, + userID, +}) => { + // Query to fetch MCP tools + const { + data: mcpServers, + isLoading: isLoadingServers, + refetch, + } = useQuery({ + queryKey: ["mcpServers"], + queryFn: () => { + if (!accessToken) throw new Error("Access Token required"); + return fetchMCPServers(accessToken); + }, + enabled: !!accessToken, + }); + + const createMCPServer = (newMcpServer: MCPServer) => { + refetch(); + }; + + // state + const [serverIdToDelete, setServerToDelete] = useState(null); + const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); + const [selectedServerId, setSelectedServerId] = useState(null); + const [editServer, setEditServer] = useState(false); + + const handleDelete = (server_id: string) => { + // Set the team to delete and open the confirmation modal + setServerToDelete(server_id); + setIsDeleteModalOpen(true); + }; + + const confirmDelete = async () => { + if (serverIdToDelete == null || accessToken == null) { + return; + } + + try { + await deleteMCPServer(accessToken, serverIdToDelete); + // Successfully completed the deletion. Update the state to trigger a rerender. + message.success("Deleted MCP Server successfully"); + refetch(); + } catch (error) { + console.error("Error deleting the mcp server:", error); + // Handle any error situations, such as displaying an error message to the user. + } + + // Close the confirmation modal and reset the serverToDelete + setIsDeleteModalOpen(false); + setServerToDelete(null); + }; + + const cancelDelete = () => { + // Close the confirmation modal and reset the serverToDelete + setIsDeleteModalOpen(false); + setServerToDelete(null); + }; + + if (!accessToken || !userRole || !userID) { + return ( +
+ Missing required authentication parameters. +
+ ); + } + + return ( +
+ {selectedServerId ? ( + server.server_id === selectedServerId + ) || {} + } + onBack={() => setSelectedServerId(null)} + isProxyAdmin={isAdminRole(userRole)} + isEditing={editServer} + accessToken={accessToken} + userID={userID} + userRole={userRole} + /> + ) : ( +
+
+

MCP Servers

+
+ + + + Server ID + Server Name + Description + Transport + Auth Type + Url + Created + Info + + + + + {!mcpServers || mcpServers.length == 0 + ? [] + : mcpServers.map((mcpServer: MCPServer) => ( + + +
+ + + +
+
+ + {mcpServer.alias} + + + {mcpServer.description} + + + {handleTransport(mcpServer.transport)} + + + {handleAuth(mcpServer.auth_type)} + + +
+ + {mcpServer.url} + +
+
+ + {mcpServer.created_at + ? new Date(mcpServer.created_at).toLocaleDateString() + : "N/A"} + + + {isAdminRole(userRole) ? ( + <> + { + setSelectedServerId(mcpServer.server_id); + setEditServer(true); + }} + /> + handleDelete(mcpServer.server_id)} + icon={TrashIcon} + size="sm" + /> + + ) : null} + +
+ ))} +
+
+ + +
+ )} +
+ ); +}; + +export default MCPServers; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx new file mode 100644 index 000000000000..38c537ecbcb0 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx @@ -0,0 +1,4 @@ +import { MCPToolsViewer } from "./index"; + +// TODO: Move Tools viewer from index file to this file +export default MCPToolsViewer; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx index 7bbb76fa2372..d3ac89c826b9 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx @@ -1,3 +1,25 @@ +export const TRANSPORT = { + SSE: "sse", + HTTP: "http", +}; + +export const handleTransport = (transport?: string): string => { + console.log(transport) + if (transport === null || transport === undefined) { + return TRANSPORT.SSE; + } + + return transport; +}; + +export const handleAuth = (authType?: string): string => { + if (authType === null || authType === undefined) { + return "none"; + } + + return authType; +}; + // Define the structure for tool input schema properties export interface InputSchemaProperty { type: string; @@ -20,7 +42,7 @@ export interface InputSchemaProperty { // Define the structure for a single MCP tool export interface MCPTool { name: string; - description: string; + description?: string; inputSchema: InputSchema | string; // API returns string "tool_input_schema" or the actual schema mcp_info: MCPInfo; // Function to select a tool (added in the component) @@ -68,4 +90,24 @@ export interface InputSchemaProperty { accessToken: string | null; userRole: string | null; userID: string | null; - } \ No newline at end of file + } + +export interface MCPServer { + server_id: string; + alias?: string | null; + description?: string | null; + url: string; + transport?: string | null; + 
spec_version?: string | null; + auth_type?: string | null; + created_at: string; + created_by: string; + updated_at: string; + updated_by: string; +} + +export interface MCPServerProps { + accessToken: string | null; + userRole: string | null; + userID: string | null; +} \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index ba9d1cb34a50..5f8a0cdc6bc0 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -14,6 +14,13 @@ if (isLocal != true) { console.log = function() {}; } +const HTTP_REQUEST = { + GET: "GET", + POST: "POST", + PUT: "PUT", + DELETE: "DELETE", +}; + export const DEFAULT_ORGANIZATION = "default_organization"; export interface Model { @@ -4458,6 +4465,99 @@ export const updateInternalUserSettings = async (accessToken: string, settings: } }; +export const fetchMCPServers = async (accessToken: string) => { + try { + // Construct base URL + const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/mcp/server` : `/v1/mcp/server`; + + console.log("Fetching MCP servers from:", url); + + const response = await fetch(url, { + method: HTTP_REQUEST.GET, + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error("Network response was not ok"); + } + + const data = await response.json(); + console.log("Fetched MCP servers:", data); + return data; + } catch (error) { + console.error("Failed to fetch MCP servers:", error); + throw error; + } +}; + +export const createMCPServer = async ( + accessToken: string, + formValues: Record // Assuming formValues is an object +) => { + try { + console.log("Form Values in createMCPServer:", formValues); // Log the form values before making the API call + + const url = proxyBaseUrl ? 
`${proxyBaseUrl}/v1/mcp/server` : `/v1/mcp/server`; + + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + ...formValues, // Include formValues in the request body + }), + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + console.error("Error response from the server:", errorData); + throw new Error("Network response was not ok"); + } + + const data = await response.json(); + console.log("API Response:", data); + return data; + // Handle success - you might want to update some state or UI based on the created key + } catch (error) { + console.error("Failed to create key:", error); + throw error; + } +}; + +export const deleteMCPServer = async ( + accessToken: String, + serverId: String +) => { + try { + const url = + (proxyBaseUrl ? `${proxyBaseUrl}` : "") + `/v1/mcp/server/${serverId}`; + console.log("in deleteMCPServer:", serverId); + const response = await fetch(url, { + method: HTTP_REQUEST.DELETE, + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error("Network response was not ok"); + } + } catch (error) { + console.error("Failed to delete key:", error); + throw error; + } +}; export const listMCPTools = async (accessToken: string) => { try {