Skip to content

feat: MCP Servers with CRUD operations #10699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-litellm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ jobs:

- name: Run tests
run: |
poetry run prisma generate
poetry run pytest tests/litellm -x -vv -n 4
13 changes: 11 additions & 2 deletions docs/my-website/docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ This allows you to define tools that can be called by any MCP compatible client.

#### How it works

1. Allow proxy admin users to perform create, update, and delete operations on MCP servers stored in the db.
2. Allows users to view and call tools to the MCP servers they have access to.

LiteLLM exposes the following MCP endpoints:

- `/mcp/tools/list` - List all available tools
- `/mcp/tools/call` - Call a specific tool with the provided arguments
- GET `/mcp/enabled` - Returns if MCP is enabled (python>=3.10 requirements are met)
- GET `/mcp/tools/list` - List all available tools
- POST `/mcp/tools/call` - Call a specific tool with the provided arguments
- GET `/v1/mcp/server` - Returns all of the configured mcp servers in the db filtered by requestor's access
- GET `/v1/mcp/server/{server_id}` - Returns the the specific mcp server in the db given `server_id` filtered by requestor's access
- GET `/v1/mcp/server/{server_id}/tools` - Get all the tools from the mcp server specified by the `server_id`
- POST `/v1/mcp/server` - Add a new external mcp server.
- DELETE `/v1/mcp/server/{server_id}` - Deletes the mcp server given `server_id`.

When MCP clients connect to LiteLLM they can follow this workflow:

Expand Down
186 changes: 186 additions & 0 deletions litellm/proxy/_experimental/mcp_server/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from typing import Dict, Iterable, List, Literal, Optional, Set, Union, cast

import uuid

from prisma.models import LiteLLM_MCPServerTable, LiteLLM_ObjectPermissionTable, LiteLLM_TeamTable
from litellm.proxy._types import NewMCPServerRequest, SpecialMCPServerName, UserAPIKeyAuth
from litellm.proxy.utils import PrismaClient


async def get_all_mcp_servers(prisma_client: PrismaClient) -> List[LiteLLM_MCPServerTable]:
"""
Returns all of the mcp servers from the db
"""
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()

return mcp_servers


async def get_mcp_server(prisma_client: PrismaClient, server_id: str) -> Optional[LiteLLM_MCPServerTable]:
"""
Returns the matching mcp server from the db iff exists
"""
mcp_server: Optional[LiteLLM_MCPServerTable] = await prisma_client.db.litellm_mcpservertable.find_unique(
where={
"server_id": server_id,
}
)
return mcp_server


async def get_mcp_servers(prisma_client: PrismaClient, server_ids: Iterable[str]) -> List[LiteLLM_MCPServerTable]:
"""
Returns the matching mcp servers from the db with the server_ids
"""
mcp_servers: List[LiteLLM_MCPServerTable] = await prisma_client.db.litellm_mcpservertable.find_many(
where={
"server_id": {"in": server_ids},
}
)
return mcp_servers


async def get_mcp_servers_by_verificationtoken(prisma_client: PrismaClient, token: str) -> List[str]:
"""
Returns the mcp servers from the db for the verification token
"""
verification_token_record: LiteLLM_TeamTable = await prisma_client.db.litellm_verificationtoken.find_unique(
where={
"token": token,
},
include={
"object_permission": True,
},
)

mcp_servers = []
if verification_token_record is not None and verification_token_record.object_permission is not None:
mcp_servers = verification_token_record.object_permission.mcp_servers
return mcp_servers


async def get_mcp_servers_by_team(prisma_client: PrismaClient, team_id: str) -> List[str]:
"""
Returns the mcp servers from the db for the team id
"""
team_record: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.find_unique(
where={
"team_id": team_id,
},
include={
"object_permission": True,
},
)

mcp_servers = []
if team_record is not None and team_record.object_permission is not None:
mcp_servers = team_record.object_permission.mcp_servers
return mcp_servers


async def get_all_mcp_servers_for_user(
prisma_client: PrismaClient,
user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
"""
Get all the mcp servers filtered by the given user has access to.

Following Least-Privilege Principle - the requestor should only be able to see the mcp servers that they have access to.
"""

mcp_server_ids: Set[str] = set()
mcp_servers = []

# Get the mcp servers for the key
if user.api_key:
token_mcp_servers = await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key)
mcp_server_ids.update(token_mcp_servers)

# check for special team membership
if SpecialMCPServerName.all_team_servers in mcp_server_ids and user.team_id is not None:
team_mcp_servers = await get_mcp_servers_by_team(prisma_client, user.team_id)
mcp_server_ids.update(team_mcp_servers)

if len(mcp_server_ids) > 0:
mcp_servers = await get_mcp_servers(prisma_client, mcp_server_ids)

return mcp_servers


async def get_objectpermissions_for_mcp_server(
prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
"""
Get all the object permissions records and the associated team and verficiationtoken records that have access to the mcp server
"""
object_permission_records = await prisma_client.db.litellm_objectpermissiontable.find_many(
where={
"mcp_servers": {"has": mcp_server_id},
},
include={
"teams": True,
"verification_tokens": True,
},
)

return object_permission_records


async def get_virtualkeys_for_mcp_server(prisma_client: PrismaClient, server_id: str) -> List:
"""
Get all the virtual keys that have access to the mcp server
"""
virtual_keys = await prisma_client.db.litellm_verificationtoken.find_many(
where={
"mcp_servers": {"has": server_id},
},
)

if virtual_keys is None:
return []
return virtual_keys


async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
"""
Remove the mcp server from the team
"""
pass


async def delete_mcp_server_from_virtualkey():
"""
Remove the mcp server from the virtual key
"""
pass


async def delete_mcp_server(prisma_client: PrismaClient, server_id: str) -> Optional[LiteLLM_MCPServerTable]:
"""
Delete the mcp server from the db by server_id

Returns the deleted mcp server record if it exists, otherwise None
"""
deleted_server = await prisma_client.db.litellm_mcpservertable.delete(
where={
"server_id": server_id,
},
)
return deleted_server


async def create_mcp_server(prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str) -> LiteLLM_MCPServerTable:
if data.server_id is None:
data.server_id = str(uuid.uuid4())

"""
Create a new mcp server record in the db
"""
mcp_server_record = await prisma_client.db.litellm_mcpservertable.create(
data={
**data.model_dump(),
'created_by': touched_by,
'updated_by': touched_by,
}
)
return mcp_server_record
62 changes: 31 additions & 31 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from typing import Any, Dict, List, Optional, Union

from anyio import BrokenResourceError
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from pydantic import ConfigDict, ValidationError

from litellm._version import version
from litellm._logging import verbose_logger
from litellm.constants import MCP_TOOL_NAME_PREFIX
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
Expand All @@ -19,20 +20,33 @@
from litellm.types.utils import StandardLoggingMCPToolCall
from litellm.utils import client

router = APIRouter(
prefix="/mcp",
tags=["mcp"],
)

# Check if MCP is available
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
# We're making this conditional import to avoid breaking users who use python 3.8.
# TODO: Make this a util function for litellm client usage
MCP_AVAILABLE: bool = True
try:
from mcp.server import Server

MCP_AVAILABLE = True
except ImportError as e:
verbose_logger.debug(f"MCP module not found: {e}")
MCP_AVAILABLE = False
router = APIRouter(
prefix="/mcp",
tags=["mcp"],
)


# Routes
@router.get(
"/enabled",
description="Returns if the MCP server is enabled",
)
def get_mcp_server_enabled() -> Dict[str, bool]:
"""
Returns if the MCP server is enabled
"""
return {"enabled": MCP_AVAILABLE}


if MCP_AVAILABLE:
Expand Down Expand Up @@ -63,10 +77,6 @@ class ListMCPToolsRestAPIResponseObject(MCPTool):
########################################################
############ Initialize the MCP Server #################
########################################################
router = APIRouter(
prefix="/mcp",
tags=["mcp"],
)
server: Server = Server("litellm-mcp-server")
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")

Expand All @@ -93,9 +103,7 @@ async def _list_mcp_tools() -> List[MCPTool]:
inputSchema=tool.input_schema,
)
)
verbose_logger.debug(
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
)
verbose_logger.debug("GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools())
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
if sse_tools is not None:
Expand Down Expand Up @@ -134,28 +142,20 @@ async def call_mcp_tool(
Call a specific tool with the provided arguments
"""
if arguments is None:
raise HTTPException(
status_code=400, detail="Request arguments are required"
)
raise HTTPException(status_code=400, detail="Request arguments are required")

standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
_get_standard_logging_mcp_tool_call(
name=name,
arguments=arguments,
)
)
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
"litellm_logging_obj", None
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = _get_standard_logging_mcp_tool_call(
name=name,
arguments=arguments,
)
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj", None)
if litellm_logging_obj:
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
standard_logging_mcp_tool_call
)
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = standard_logging_mcp_tool_call
litellm_logging_obj.model_call_details["model"] = (
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
)
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
standard_logging_mcp_tool_call.get("mcp_server_name")
litellm_logging_obj.model_call_details["custom_llm_provider"] = standard_logging_mcp_tool_call.get(
"mcp_server_name"
)

# Try managed server tool first
Expand Down Expand Up @@ -306,4 +306,4 @@ async def call_tool_rest_api(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
)
)
39 changes: 39 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,45 @@ def check_potential_json_str(cls, values):
pass
return values

# MCP Types
class SpecialMCPServerName(str, enum.Enum):
all_team_servers = "all-team-mcpservers"
all_proxy_servers = "all-proxy-mcpservers"


class MCPTransport(str, enum.Enum):
sse = "sse"
http = "http"


class MCPSpecVersion(str, enum.Enum):
nov_2024 = "2024-11-05"
mar_2025 = "2025-03-26"


class MCPAuth(str, enum.Enum):
none = "none"
api_key = "api_key"
bearer_token = "bearer_token"
basic = "basic"


# MCP Literals
MCPTransportType = Literal[MCPTransport.sse, MCPTransport.http]
MCPSpecVersionType = Literal[MCPSpecVersion.nov_2024, MCPSpecVersion.mar_2025]
MCPAuthType = Literal[MCPAuth.none, MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic]


# MCP Proxy Request Types
class NewMCPServerRequest(LiteLLMPydanticObjectBase):
server_id: Optional[str] = None
alias: Optional[str] = None
description: Optional[str] = None
transport: MCPTransportType = MCPTransport.sse
spec_version: MCPSpecVersionType = MCPSpecVersion.mar_2025
auth_type: Optional[MCPAuthType] = None
url: str


class NewUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None
Expand Down
2 changes: 2 additions & 0 deletions litellm/proxy/auth/route_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def non_proxy_admin_allowed_routes_check(
route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
): # routes that manage their own allowed/disallowed logic
pass
elif route.startswith("/v1/mcp/"):
pass # authN/authZ handled by api itself
else:
user_role = "unknown"
user_id = "unknown"
Expand Down
Loading
Loading