Skip to content
This repository was archived by the owner on Jun 27, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 30 additions & 47 deletions swarms/tools/mcp_client_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ async def wrapper(*args, **kwargs):
f"Failed after {retries} retries: {str(e)}"
)
raise
sleep_time = (
backoff_in_seconds * 2**x
+ random.uniform(0, 1)
sleep_time = backoff_in_seconds * 2**x + random.uniform(
0, 1
)
logger.warning(
f"Attempt {x + 1} failed, retrying in {sleep_time:.2f}s"
Expand All @@ -193,20 +192,30 @@ async def wrapper(*args, **kwargs):

@contextlib.contextmanager
def get_or_create_event_loop():
"""Context manager to handle event loop creation and cleanup."""
"""Context manager that safely provides an event loop.

If a loop is already running in the current thread, a new loop is created
to avoid "event loop is already running" errors. The created loop is
automatically closed when the context exits.
"""

try:
loop = asyncio.get_event_loop()
# Detect if there's a running loop in this thread. If so, create a new
# one instead of using it as ``run_until_complete`` would fail.
asyncio.get_running_loop()
loop = asyncio.new_event_loop()
created = True
except RuntimeError:
# No running loop, so create one and set it for this thread.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
created = True

try:
yield loop
finally:
# Only close the loop if we created it and it's not the main event loop
if loop != asyncio.get_event_loop() and not loop.is_running():
if not loop.is_closed():
loop.close()
if created and not loop.is_closed():
loop.close()


def connect_to_mcp_server(connection: MCPConnection = None):
Expand All @@ -227,9 +236,7 @@ def connect_to_mcp_server(connection: MCPConnection = None):
# Direct attribute access is faster than property access
headers = dict(connection.headers or {})
if connection.authorization_token:
headers["Authorization"] = (
f"Bearer {connection.authorization_token}"
)
headers["Authorization"] = f"Bearer {connection.authorization_token}"

return (
headers,
Expand Down Expand Up @@ -261,9 +268,7 @@ async def aget_mcp_tools(
MCPConnectionError: If connection to server fails
"""
if exists(connection):
headers, timeout, transport, url = connect_to_mcp_server(
connection
)
headers, timeout, transport, url = connect_to_mcp_server(connection)
else:
headers, timeout, _transport, _url = (
None,
Expand All @@ -287,18 +292,12 @@ async def aget_mcp_tools(
):
async with ClientSession(read, write) as session:
await session.initialize()
tools = await load_mcp_tools(
session=session, format=format
)
logger.info(
f"Successfully fetched {len(tools)} tools"
)
tools = await load_mcp_tools(session=session, format=format)
logger.info(f"Successfully fetched {len(tools)} tools")
return tools
except Exception as e:
logger.error(f"Error fetching MCP tools: {str(e)}")
raise MCPConnectionError(
f"Failed to connect to MCP server: {str(e)}"
)
raise MCPConnectionError(f"Failed to connect to MCP server: {str(e)}")


def get_mcp_tools_sync(
Expand Down Expand Up @@ -373,11 +372,7 @@ def get_tools_for_multiple_mcp_servers(
List[Dict[str, Any]]: Combined list of tools from all servers
"""
tools = []
(
min(32, os.cpu_count() + 4)
if max_workers is None
else max_workers
)
(min(32, os.cpu_count() + 4) if max_workers is None else max_workers)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
if exists(connections):
# Create future tasks for each URL-connection pair
Expand All @@ -403,9 +398,7 @@ def get_tools_for_multiple_mcp_servers(
server_tools = future.result()
tools.extend(server_tools)
except Exception as e:
logger.error(
f"Error fetching tools from {url}: {str(e)}"
)
logger.error(f"Error fetching tools from {url}: {str(e)}")
raise MCPExecutionError(
f"Failed to fetch tools from {url}: {str(e)}"
)
Expand All @@ -423,9 +416,7 @@ async def _execute_tool_call_simple(
):
"""Execute a tool call using the MCP client."""
if exists(connection):
headers, timeout, transport, url = connect_to_mcp_server(
connection
)
headers, timeout, transport, url = connect_to_mcp_server(connection)
else:
headers, timeout, _transport, url = (
None,
Expand Down Expand Up @@ -462,28 +453,20 @@ async def _execute_tool_call_simple(
for item in value:
if isinstance(item, dict):
for k, v in item.items():
formatted_lines.append(
f"{k}: {v}"
)
formatted_lines.append(f"{k}: {v}")
else:
formatted_lines.append(
f"{key}: {value}"
)
formatted_lines.append(f"{key}: {value}")
out = "\n".join(formatted_lines)

return out

except Exception as e:
logger.error(f"Error in tool execution: {str(e)}")
raise MCPExecutionError(
f"Tool execution failed: {str(e)}"
)
raise MCPExecutionError(f"Tool execution failed: {str(e)}")

except Exception as e:
logger.error(f"Error in SSE client connection: {str(e)}")
raise MCPConnectionError(
f"Failed to connect to MCP server: {str(e)}"
)
raise MCPConnectionError(f"Failed to connect to MCP server: {str(e)}")


async def execute_tool_call_simple(
Expand Down
15 changes: 15 additions & 0 deletions tests/tools/test_mcp_client_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import asyncio
import pytest

from swarms.tools import mcp_client_call


@pytest.mark.asyncio
async def test_get_mcp_tools_sync_with_running_loop(monkeypatch):
async def fake_get(*args, **kwargs):
return [{"name": "mock"}]

monkeypatch.setattr(mcp_client_call, "aget_mcp_tools", fake_get)

result = mcp_client_call.get_mcp_tools_sync(server_path="dummy")
assert result == [{"name": "mock"}]
Loading