forked from MervinPraison/PraisonAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcp_sse.py
More file actions
258 lines (211 loc) · 9.28 KB
/
mcp_sse.py
File metadata and controls
258 lines (211 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""
SSE (Server-Sent Events) client implementation for MCP (Model Context Protocol).
This module provides the necessary classes and functions to connect to an MCP server
over SSE transport.
"""
import asyncio
import concurrent.futures
import inspect
import json
import logging
import threading
from typing import List, Dict, Any, Optional, Callable, Iterable
try:
    from mcp import ClientSession
    from mcp.client.sse import sse_client
except ImportError:
    # The MCP SDK is an optional extra: keep this module importable without it
    # and let SSEMCPClient raise a helpful ImportError when it is actually used.
    ClientSession = None
    sse_client = None
    MCP_AVAILABLE = False
else:
    MCP_AVAILABLE = True

logger = logging.getLogger("mcp-sse")

# Shared event loop for all async MCP operations; created lazily by
# get_event_loop() and replaced there once it has been closed.
_event_loop = None
def get_event_loop():
    """Return the shared event loop used for MCP SSE operations.

    A new loop is created on first use and whenever the previous one has been
    closed.

    The loop is deliberately *not* installed as the calling thread's current
    event loop. SSEMCPClient drives it with ``run_forever()`` on a background
    thread, which installs it there. Installing it in the caller's thread as
    well made ``asyncio.get_event_loop().run_until_complete(...)`` in that
    thread pick up a loop already running elsewhere and fail.

    Returns:
        asyncio.AbstractEventLoop: the shared loop (not necessarily running yet).
    """
    global _event_loop
    if _event_loop is None or _event_loop.is_closed():
        _event_loop = asyncio.new_event_loop()
    return _event_loop
class SSEMCPTool:
    """A wrapper for an MCP tool that can be used with praisonaiagents.

    Instances behave like plain synchronous functions. They are callable and
    expose ``__name__``, ``__qualname__``, ``__doc__`` and ``__signature__``
    (derived from the tool's JSON input schema), so the Agent can introspect
    them. Calls are forwarded to the MCP session on the shared event loop.
    """

    # JSON-schema "type" -> Python annotation used in __signature__.
    # Missing or unknown types fall back to str.
    _JSON_TYPE_ANNOTATIONS = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    def __init__(self, name: str, description: str, session: "ClientSession", input_schema: Optional[Dict[str, Any]] = None, timeout: int = 60):
        """
        Args:
            name: Tool name as reported by the MCP server.
            description: Tool description; also exposed as ``__doc__``.
            session: MCP client session used to invoke the tool.
            input_schema: JSON schema describing the tool's arguments, if any.
            timeout: Seconds to wait for a call to complete.
        """
        self.name = name
        self.__name__ = name  # Required for Agent to recognize it as a tool
        self.__qualname__ = name  # Required for Agent to recognize it as a tool
        self.__doc__ = description  # Required for Agent to recognize it as a tool
        self.description = description
        self.session = session
        self.input_schema = input_schema or {}
        self.timeout = timeout
        self.__signature__ = self._build_signature(input_schema)

    @classmethod
    def _annotation_for(cls, prop_schema):
        """Map one property schema to a Python type annotation (default: str)."""
        prop_type = prop_schema.get("type", "string") if isinstance(prop_schema, dict) else "string"
        if isinstance(prop_type, list):
            # Union types such as ["array", "null"]: use the first non-null entry.
            prop_type = next((t for t in prop_type if t != "null"), "string")
        if not isinstance(prop_type, str):
            return str
        return cls._JSON_TYPE_ANNOTATIONS.get(prop_type, str)

    @classmethod
    def _build_signature(cls, input_schema):
        """Build an ``inspect.Signature`` mirroring the schema's top-level properties.

        Required properties come first (no default), followed by optional ones
        (default ``None``). Python forbids a parameter without a default after
        one with a default, so keeping schema order raised ValueError whenever
        an optional property was listed before a required one.
        """
        required_params = []
        optional_params = []
        properties = input_schema.get("properties") if isinstance(input_schema, dict) else None
        if isinstance(properties, dict):
            required = set(input_schema.get("required") or [])
            for param_name, prop_schema in properties.items():
                is_required = param_name in required
                try:
                    param = inspect.Parameter(
                        name=param_name,
                        kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                        default=inspect.Parameter.empty if is_required else None,
                        annotation=cls._annotation_for(prop_schema),
                    )
                except ValueError:
                    # Not a valid Python parameter name (e.g. "file-path"). It
                    # cannot appear in a signature, but it is still sent to the
                    # model via to_openai_tool() and accepted by __call__(**kwargs).
                    continue
                (required_params if is_required else optional_params).append(param)
        return inspect.Signature(required_params + optional_params)

    def _bind_positional(self, args, kwargs):
        """Merge positional ``args`` into ``kwargs`` following ``__signature__`` order.

        Raises:
            TypeError: on too many positional arguments, or when an argument is
                given both positionally and by keyword (mirroring Python's own
                call errors).
        """
        names = list(self.__signature__.parameters)
        if len(args) > len(names):
            raise TypeError(
                f"{self.name}() takes {len(names)} positional arguments but {len(args)} were given"
            )
        merged = dict(zip(names, args))
        duplicates = [key for key in merged if key in kwargs]
        if duplicates:
            raise TypeError(f"{self.name}() got multiple values for argument '{duplicates[0]}'")
        merged.update(kwargs)
        return merged

    def __call__(self, *args, **kwargs):
        """Synchronously invoke the tool and return its text result.

        Positional arguments are accepted in ``__signature__`` order. The call
        is scheduled on the shared event loop and awaited for at most
        ``self.timeout`` seconds. Failures are logged and returned as an
        ``"Error: ..."`` string instead of being raised.
        """
        if args:
            kwargs = self._bind_positional(args, kwargs)
        logger.debug(f"Tool {self.name} called with args: {kwargs}")
        loop = get_event_loop()
        future = asyncio.run_coroutine_threadsafe(self._async_call(**kwargs), loop)
        try:
            return future.result(timeout=self.timeout)
        except concurrent.futures.TimeoutError:
            # Stop the abandoned call instead of leaving it running on the loop.
            # The TimeoutError itself has an empty str(), so build a message.
            future.cancel()
            message = f"Tool {self.name} timed out after {self.timeout} seconds"
            logger.error(f"Error calling tool {self.name}: {message}")
            return f"Error: {message}"
        except Exception as e:
            logger.error(f"Error calling tool {self.name}: {e}")
            return f"Error: {str(e)}"

    @staticmethod
    def _result_to_text(result):
        """Flatten a tool result into a string.

        Every content item contributes its ``text`` (or ``str()`` when it has
        none), joined with newlines, so multi-part results are not truncated to
        their first item. Results without content fall back to ``str(result)``.
        """
        content = getattr(result, "content", None)
        if not content:
            return str(result)
        return "\n".join(item.text if hasattr(item, "text") else str(item) for item in content)

    async def _async_call(self, **kwargs):
        """Invoke the tool on the MCP session and return its result as text."""
        logger.debug(f"Async calling tool {self.name} with args: {kwargs}")
        try:
            result = await self.session.call_tool(self.name, kwargs)
            return self._result_to_text(result)
        except Exception as e:
            logger.error(f"Error in _async_call for {self.name}: {e}")
            raise

    def _fix_array_schemas(self, schema):
        """
        Fix array schemas by adding missing 'items' attribute required by OpenAI.

        Arrays are detected both as ``"type": "array"`` and as a type list
        containing ``"array"`` (e.g. ``["array", "null"]``). Sub-schemas under
        ``properties``, ``$defs``/``definitions``, ``anyOf``/``oneOf``/``allOf``,
        ``items`` and ``additionalProperties`` are fixed recursively. The input
        is never mutated; copies are built instead.

        Args:
            schema: The schema to fix; non-dict values are returned unchanged.

        Returns:
            The fixed schema.
        """
        if not isinstance(schema, dict):
            return schema
        fixed_schema = schema.copy()
        schema_type = fixed_schema.get("type")
        is_array = schema_type == "array" or (isinstance(schema_type, list) and "array" in schema_type)
        if is_array and "items" not in fixed_schema:
            fixed_schema["items"] = {"type": "string"}
        # Mappings of name -> sub-schema.
        for key in ("properties", "$defs", "definitions"):
            sub_schemas = fixed_schema.get(key)
            if isinstance(sub_schemas, dict):
                fixed_schema[key] = {name: self._fix_array_schemas(sub) for name, sub in sub_schemas.items()}
        # Lists of alternative / combined sub-schemas.
        for key in ("anyOf", "oneOf", "allOf"):
            sub_schemas = fixed_schema.get(key)
            if isinstance(sub_schemas, list):
                fixed_schema[key] = [self._fix_array_schemas(sub) for sub in sub_schemas]
        # Single sub-schemas (non-dict values such as True are left as-is).
        for key in ("items", "additionalProperties"):
            if key in fixed_schema:
                fixed_schema[key] = self._fix_array_schemas(fixed_schema[key])
        return fixed_schema

    def to_openai_tool(self):
        """Convert the tool to OpenAI function-calling format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self._fix_array_schemas(self.input_schema),
            },
        }
class SSEMCPClient:
    """A client for connecting to an MCP server over SSE.

    On construction it connects to the server, lists its tools and wraps each
    one in an SSEMCPTool. Iterating the client yields those tools.
    """

    # Thread driving the shared loop returned by get_event_loop(). Every client
    # uses that same loop, so a thread is started only when the loop is not
    # already running: a second run_forever() on a running loop raises
    # RuntimeError("This event loop is already running").
    _shared_loop_thread: Optional[threading.Thread] = None

    def __init__(self, server_url: str, debug: bool = False, timeout: int = 60):
        """
        Initialize an SSE MCP client.

        Args:
            server_url: The URL of the SSE MCP server
            debug: Whether to enable debug logging
            timeout: Timeout in seconds for operations (default: 60)

        Raises:
            ImportError: if the optional MCP package is not installed.
            concurrent.futures.TimeoutError: if connecting takes longer than ``timeout``.
        """
        if not MCP_AVAILABLE:
            raise ImportError(
                "MCP (Model Context Protocol) package is not installed. "
                "Install it with: pip install praisonaiagents[mcp]"
            )
        self.server_url = server_url
        self.debug = debug
        self.timeout = timeout
        self.session = None
        self.tools = []
        # The module logger is shared, so this affects every client.
        if debug:
            logger.setLevel(logging.DEBUG)
        else:
            # Set to WARNING by default to hide INFO messages
            logger.setLevel(logging.WARNING)
        self._initialize()

    def _initialize(self):
        """Ensure the shared loop is running, then connect and load the tools.

        Raises:
            concurrent.futures.TimeoutError: if connecting and listing tools takes
                longer than ``self.timeout`` seconds (the pending work is cancelled).
        """
        loop = get_event_loop()
        if not loop.is_running():
            def run_event_loop():
                asyncio.set_event_loop(loop)
                loop.run_forever()

            thread = threading.Thread(target=run_event_loop, daemon=True)
            thread.start()
            SSEMCPClient._shared_loop_thread = thread
        self.loop_thread = SSEMCPClient._shared_loop_thread

        future = asyncio.run_coroutine_threadsafe(self._async_initialize(), loop)
        try:
            self.tools = future.result(timeout=self.timeout)
        except concurrent.futures.TimeoutError:
            # Cancel the handshake so it does not keep running in the background.
            future.cancel()
            raise concurrent.futures.TimeoutError(
                f"Timed out after {self.timeout} seconds connecting to MCP server at {self.server_url}"
            ) from None

    async def _async_initialize(self):
        """Open the SSE stream and MCP session, then wrap the server's tools.

        If any step fails or is cancelled, the contexts entered so far are exited
        innermost-first with that exception, as ``async with`` would, so a failed
        handshake does not leave the SSE connection open.
        """
        logger.debug(f"Connecting to MCP server at {self.server_url}")
        entered = []
        try:
            self._streams_context = sse_client(url=self.server_url)
            streams = await self._streams_context.__aenter__()
            entered.append(self._streams_context)
            self._session_context = ClientSession(*streams)
            self.session = await self._session_context.__aenter__()
            entered.append(self._session_context)
            await self.session.initialize()
            logger.debug("Listing tools...")
            response = await self.session.list_tools()
        except BaseException as exc:
            for context in reversed(entered):
                try:
                    await context.__aexit__(type(exc), exc, exc.__traceback__)
                except Exception as cleanup_error:
                    logger.debug(f"Error while closing MCP connection: {cleanup_error}")
            raise
        tools_data = response.tools
        logger.debug(f"Found {len(tools_data)} tools: {[tool.name for tool in tools_data]}")
        return [self._wrap_tool(tool) for tool in tools_data]

    def _wrap_tool(self, tool):
        """Wrap one tool entry from ``list_tools()`` in an SSEMCPTool bound to this session."""
        return SSEMCPTool(
            name=tool.name,
            # The description may be absent, None or empty; always hand a string on.
            description=getattr(tool, "description", None) or f"Call the {tool.name} tool",
            session=self.session,
            input_schema=getattr(tool, "inputSchema", None),
            timeout=self.timeout,
        )

    def __iter__(self):
        """Return an iterator over the tools."""
        return iter(self.tools)

    def to_openai_tools(self):
        """Convert all tools to OpenAI format."""
        return [tool.to_openai_tool() for tool in self.tools]