Skip to content

Commit 2040889

Browse files
added sampling support
- add sampling - vendor ClientSession from mcp - improve stdio client - add vscode launch profile
1 parent 7d85cfe commit 2040889

File tree

12 files changed

+473
-13
lines changed

12 files changed

+473
-13
lines changed

.vscode/launch.json

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Python Debugger: Module",
9+
"type": "debugpy",
10+
"request": "launch",
11+
"django": true,
12+
"module": "mcp_bridge.main",
13+
}
14+
]
15+
}

README.md

+18
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ working features:
2323
- non streaming completions without MCP
2424

2525
- MCP tools
26+
- MCP sampling
2627

2728
- SSE Bridge for external clients
2829

@@ -129,6 +130,23 @@ an example config.json file with most of the options explicitly set:
129130
"base_url": "http://localhost:8000/v1",
130131
"api_key": "None"
131132
},
133+
"sampling": {
134+
"timeout": 10,
135+
"models": [
136+
{
137+
"model": "gpt-4o",
138+
"intelligence": 0.8,
139+
"cost": 0.9,
140+
"speed": 0.3
141+
},
142+
{
143+
"model": "gpt-4o-mini",
144+
"intelligence": 0.4,
145+
"cost": 0.1,
146+
"speed": 0.7
147+
}
148+
]
149+
},
132150
"mcp_servers": {
133151
"fetch": {
134152
"command": "uvx",

mcp_bridge/config/final.py

+17
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ class Logging(BaseModel):
2020
log_server_pings: bool = Field(False, description="log server pings")
2121

2222

23+
class SamplingModel(BaseModel):
24+
model: Annotated[str, Field(description="Name of the sampling model")]
25+
26+
intelligence: Annotated[float, Field(description="Intelligence of the sampling model")] = 0.5
27+
cost: Annotated[float, Field(description="Cost of the sampling model")] = 0.5
28+
speed: Annotated[float, Field(description="Speed of the sampling model")] = 0.5
29+
30+
31+
class Sampling(BaseModel):
32+
timeout: Annotated[int, Field(description="Timeout for sampling requests")] = 10
33+
models: Annotated[list[SamplingModel], Field(description="List of sampling models")] = []
34+
2335
class SSEMCPServer(BaseModel):
2436
# TODO: expand this once I find a good definition for this
2537
url: str = Field(description="URL of the MCP server")
@@ -52,6 +64,11 @@ class Settings(BaseSettings):
5264
default_factory=dict, description="MCP servers configuration"
5365
)
5466

67+
sampling: Sampling = Field(
68+
default_factory=lambda: Sampling.model_construct(),
69+
description="sampling config",
70+
)
71+
5572
logging: Logging = Field(
5673
default_factory=lambda: Logging.model_construct(),
5774
description="logging config",

mcp_bridge/mcp_clients/AbstractClient.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import Any, Optional
44
from fastapi import HTTPException
5-
from mcp import ClientSession, McpError
5+
from mcp import McpError
66
from mcp.types import (
77
CallToolResult,
88
ListToolsResult,
@@ -15,14 +15,15 @@
1515
)
1616
from loguru import logger
1717
from pydantic import AnyUrl
18+
from mcp_bridge.mcp_clients.session import McpClientSession
1819
from mcp_bridge.models.mcpServerStatus import McpServerStatus
1920

2021

2122
class GenericMcpClient(ABC):
2223
name: str
2324
config: Any
2425
client: Any
25-
session: ClientSession | None = None
26+
session: McpClientSession | None = None
2627

2728
def __init__(self, name: str) -> None:
2829
super().__init__()
@@ -39,8 +40,10 @@ async def _session_maintainer(self):
3940
while True:
4041
try:
4142
await self._maintain_session()
43+
except FileNotFoundError as e:
44+
logger.error(f"failed to maintain session for {self.name}: file {e.filename} not found.")
4245
except Exception as e:
43-
logger.trace(f"failed to maintain session for {self.name}: {e}")
46+
logger.error(f"failed to maintain session for {self.name}: {type(e)} {e.args}")
4447

4548
logger.debug(f"restarting session for {self.name}")
4649
await asyncio.sleep(0.5)

mcp_bridge/mcp_clients/DockerClient.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
from mcp import ClientSession
3+
4+
from mcp_bridge.mcp_clients.session import McpClientSession
35
from .transports.docker import docker_client
46
from mcp_bridge.config import config
57
from mcp_bridge.config.final import DockerMCPServer
@@ -9,7 +11,6 @@
911

1012
class DockerClient(GenericMcpClient):
1113
config: DockerMCPServer
12-
session: ClientSession | None = None
1314

1415
def __init__(self, name: str, config: DockerMCPServer) -> None:
1516
super().__init__(name=name)
@@ -19,7 +20,7 @@ def __init__(self, name: str, config: DockerMCPServer) -> None:
1920
async def _maintain_session(self):
2021
async with docker_client(self.config) as client:
2122
logger.debug(f"made instance of docker client for {self.name}")
22-
async with ClientSession(*client) as session:
23+
async with McpClientSession(*client) as session:
2324
await session.initialize()
2425
logger.debug(f"finished initialise session for {self.name}")
2526
self.session = session

mcp_bridge/mcp_clients/SseClient.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
2-
from mcp import ClientSession
32
from mcp.client.sse import sse_client
43
from mcp_bridge.config import config
54
from mcp_bridge.config.final import SSEMCPServer
5+
from mcp_bridge.mcp_clients.session import McpClientSession
66
from .AbstractClient import GenericMcpClient
77
from loguru import logger
88

@@ -17,7 +17,7 @@ def __init__(self, name: str, config: SSEMCPServer) -> None:
1717

1818
async def _maintain_session(self):
1919
async with sse_client(self.config.url) as client:
20-
async with ClientSession(*client) as session:
20+
async with McpClientSession(*client) as session:
2121
await session.initialize()
2222
logger.debug(f"finished initialise session for {self.name}")
2323
self.session = session

mcp_bridge/mcp_clients/StdioClient.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2-
from mcp import ClientSession, StdioServerParameters, stdio_client
2+
from mcp import StdioServerParameters, stdio_client
33

44
from mcp_bridge.config import config
5+
from mcp_bridge.mcp_clients.session import McpClientSession
56
from .AbstractClient import GenericMcpClient
67
from loguru import logger
78
import shutil
@@ -17,24 +18,28 @@ class StdioClient(GenericMcpClient):
1718
def __init__(self, name: str, config: StdioServerParameters) -> None:
1819
super().__init__(name=name)
1920

21+
# logger.debug(f"initializing settings for {name}: {config.command} {" ".join(config.args)}")
22+
23+
own_config = config.model_copy(deep=True)
24+
2025
env = dict(os.environ.copy())
2126

2227
env = {
2328
key: value for key, value in env.items()
2429
if not any(key.startswith(keyword) for keyword in venv_keywords)
2530
}
2631

27-
# logger.debug(f"env: {env}")
28-
2932
if config.env is not None:
3033
env.update(config.env)
3134

35+
own_config.env = env
36+
3237
command = shutil.which(config.command)
3338
if command is None:
3439
logger.error(f"could not find command {config.command}")
3540
exit(1)
3641

37-
own_config = config.model_copy(deep=True)
42+
own_config.command = command
3843

3944
# this changes the default to ignore
4045
if "encoding_error_handler" not in config.model_fields_set:
@@ -48,7 +53,7 @@ async def _maintain_session(self):
4853
logger.debug(f"entered stdio_client context manager for {self.name}")
4954
assert client[0] is not None, f"missing read stream for {self.name}"
5055
assert client[1] is not None, f"missing write stream for {self.name}"
51-
async with ClientSession(*client) as session:
56+
async with McpClientSession(*client) as session:
5257
logger.debug(f"entered client session context manager for {self.name}")
5358
await session.initialize()
5459
logger.debug(f"finished initialise session for {self.name}")

0 commit comments

Comments
 (0)