Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 73 additions & 54 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,35 +666,69 @@ async def _start_mcp_server(
if shutdown_event is None:
shutdown_event = asyncio.Event()

mcp_client: MCPClient | None = None
try:
mcp_client = await asyncio.wait_for(
self._init_mcp_client(name, cfg),
timeout=timeout,
mcp_client = MCPClient()
mcp_client.name = name

connect_done = asyncio.Event()
Comment thread
GlowingBrick marked this conversation as resolved.
connect_error: BaseException | None = None

async def connect_and_lifecycle() -> None:
# Single task that handles connect, lifecycle, and cleanup.

nonlocal connect_error
try:
await mcp_client.connect_to_server(cfg, name)
await mcp_client.list_tools_and_save()
except asyncio.CancelledError:
# cleanup on cancellation
try:
Comment on lines +675 to +684

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Tool registration and logging errors won’t propagate via `connect_error`, leaving the caller to see only a timeout.

Because the `try` ends at `await mcp_client.list_tools_and_save()`, any exception during tool registration or logging won’t set `connect_error` or `connect_done`. The caller will then hit the timeout and raise `MCPInitTimeoutError` instead of the real failure. Please extend the `try` to include the registration/logging block, or add a dedicated `try`/`except` there that sets `connect_error` and signals `connect_done` on error.

await mcp_client.cleanup()
except BaseException:
pass
raise
except Exception as e:
connect_error = e
try:
await mcp_client.cleanup()
except Exception:
pass
connect_done.set()
return

# Register tools
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
for tool in mcp_client.tools:
func_tool = MCPTool(
mcp_tool=tool,
mcp_client=mcp_client,
mcp_server_name=name,
)
self.func_list.append(func_tool)

logger.info(
f"Connected to MCP server {name}, "
f"Tools: {[t.name for t in mcp_client.tools]}"
)
except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
) from exc
except Exception:
logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
raise
finally:
if mcp_client is None:
async with self._runtime_lock:
self._mcp_starting.discard(name)

async def lifecycle() -> None:
connect_done.set()

try:
await shutdown_event.wait()
logger.info(f"Received shutdown signal for MCP client {name}")
except asyncio.CancelledError:
logger.debug(f"MCP client {name} task was cancelled")
raise
finally:
await self._terminate_mcp_client(name)
# Cleanup in the same task that entered the anyio contexts
await asyncio.shield(self._terminate_mcp_client(name))

lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
lifecycle_task = asyncio.create_task(
connect_and_lifecycle(), name=f"mcp-client:{name}"
)
async with self._runtime_lock:
self._mcp_server_runtime[name] = _MCPServerRuntime(
name=name,
Expand All @@ -704,6 +738,26 @@ async def lifecycle() -> None:
)
self._mcp_starting.discard(name)

try:
await asyncio.wait_for(connect_done.wait(), timeout=timeout)
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
lifecycle_task.cancel()
await asyncio.gather(lifecycle_task, return_exceptions=True)
async with self._runtime_lock:
self._mcp_starting.discard(name)
self._mcp_server_runtime.pop(name, None)
if isinstance(e, asyncio.TimeoutError):
raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
) from e
raise

if connect_error is not None:
async with self._runtime_lock:
self._mcp_starting.discard(name)
self._mcp_server_runtime.pop(name, None)
raise connect_error

async def _shutdown_runtimes(
self,
runtimes: list[_MCPServerRuntime],
Expand Down Expand Up @@ -768,41 +822,6 @@ async def _cleanup_mcp_client_safely(
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
)

async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
    """Initialize a single MCP client and register its tools.

    Creates an ``MCPClient``, connects it to the server described by
    ``config``, fetches its tool list, and then replaces every ``MCPTool``
    in ``self.func_list`` that belongs to server ``name`` with freshly
    built entries for the tools reported by the client.

    Args:
        name: Name of the MCP server; stored on the client and on each tool.
        config: Server configuration passed to ``connect_to_server``.

    Returns:
        The connected ``MCPClient``.

    Raises:
        asyncio.CancelledError: Re-raised after the client has been cleaned
            up if the connect/list step is cancelled.
        Exception: Any error from the connect/list step, re-raised after
            the client has been cleaned up.
    """
    client = MCPClient()
    client.name = name

    try:
        await client.connect_to_server(config, name)
        list_response = await client.list_tools_and_save()
    except (asyncio.CancelledError, Exception):
        # Cancellation and ordinary failures get the same treatment:
        # release whatever the client acquired, then propagate.
        await self._cleanup_mcp_client_safely(client, name)
        raise

    logger.debug(f"MCP server {name} list tools response: {list_response}")
    tool_names = [entry.name for entry in list_response.tools]

    def _owned_by_this_server(candidate) -> bool:
        # True for MCP tools previously registered under this server name.
        return isinstance(candidate, MCPTool) and candidate.mcp_server_name == name

    # Drop any tools left over from an earlier registration of this server.
    self.func_list = [
        existing
        for existing in self.func_list
        if not _owned_by_this_server(existing)
    ]

    # Wrap each MCP tool as a FuncTool-compatible MCPTool and register it.
    for remote_tool in client.tools:
        self.func_list.append(
            MCPTool(
                mcp_tool=remote_tool,
                mcp_client=client,
                mcp_server_name=name,
            )
        )

    logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
    return client

async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
async with self._runtime_lock:
Expand Down
Loading