Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions python/semantic_kernel/connectors/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,15 @@ async def __aexit__(

async def connect(self) -> None:
"""Connect to the MCP server."""
ready_event = asyncio.Event()
loop = asyncio.get_running_loop()
ready_future: asyncio.Future[None] = loop.create_future()

try:
self._current_task = asyncio.create_task(self._inner_connect(ready_event))
await ready_event.wait()
self._current_task = asyncio.create_task(self._inner_connect(ready_future))
await ready_future
except KernelPluginInvalidConfigurationError:
ready_event.clear()
raise
except Exception as ex:
ready_event.clear()
await self.close()
raise FunctionExecutionException("Failed to enter context manager.") from ex

Expand All @@ -261,16 +261,20 @@ async def close(self) -> None:
self._current_task = None
self.session = None

async def _inner_connect(self, ready_event: asyncio.Event) -> None:
async def _inner_connect(self, ready_future: asyncio.Future) -> None:
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
ready_event.set()
raise KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
)
exc.__cause__ = ex
ready_future.set_exception(exc)
return

try:
session = await self._exit_stack.enter_async_context(
ClientSession(
Expand All @@ -284,16 +288,24 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None:
)
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
)
exc.__cause__ = ex
ready_future.set_exception(exc)
return
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
)
exc.__cause__ = ex
ready_future.set_exception(exc)
return
self.session = session
elif self.session._request_id == 0:
# If the session is not initialized, we need to reinitialize it
Expand All @@ -312,7 +324,8 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None:
except Exception:
logger.warning("Failed to set log level to %s", logger.level)
# Setting up is complete, will now signal the main loop that we are ready
ready_event.set()
if not ready_future.done():
ready_future.set_result(None)
# Create a stop event to signal the exit stack to close
self._stop_event = asyncio.Event()
await self._stop_event.wait()
Expand Down Expand Up @@ -434,11 +447,8 @@ async def message_handler(

async def load_prompts(self):
"""Load prompts from the MCP server."""
try:
prompt_list = await self.session.list_prompts()
except Exception:
prompt_list = None
for prompt in prompt_list.prompts if prompt_list else []:
prompt_list = await self.session.list_prompts()
for prompt in prompt_list.prompts:
local_name = _normalize_mcp_name(prompt.name)
func = kernel_function(name=local_name, description=prompt.description)(
partial(self.get_prompt, prompt.name)
Expand All @@ -448,12 +458,8 @@ async def load_prompts(self):

async def load_tools(self):
"""Load tools from the MCP server."""
try:
tool_list = await self.session.list_tools()
except Exception:
tool_list = None
# Create methods with the kernel_function decorator for each tool
for tool in tool_list.tools if tool_list else []:
tool_list = await self.session.list_tools()
for tool in tool_list.tools:
local_name = _normalize_mcp_name(tool.name)
func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name))
func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool)
Expand Down
Loading