Skip to content

Commit cce430d

Browse files
wukathcopybara-github
authored andcommitted
feat: start and close ClientSession in a single task in McpSessionManager
Merge #4025 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: - #3950 - #3731 - #3708 **2. Or, if no issue exists, describe the change:** **Problem:** - `ClientSession` of https://github.com/modelcontextprotocol/python-sdk uses AnyIO for async task management. - AnyIO TaskGroup requires its start and close must happen in a same task. - Since `McpSessionManager` does not create task per client, the client might be closed by different task, cause the error: `Attempted to exit cancel scope in a different task than it was entered in`. **Solution:** I Suggest 2 changes: Handling the `ClientSession` in a single task - To start and close `ClientSession` by the same task, we need to wrap the whole lifecycle of `ClientSession` to a single task. - `SessionContext` wraps the initialization and disposal of `ClientSession` to a single task, ensures that the `ClientSession` will be handled only in a dedicated task. Add timeout for `ClientSession` - Since now we are using task per `ClientSession`, task should never be leaked. - But `McpSessionManager` does not deliver timeout directly to `ClientSession` when the type is not STDIO. - There is only timeout for `httpx` client when MCP type is SSE or StreamableHTTP. - But the timeout applys only to `httpx` client, so if there is an issue in MCP client itself(e.g. modelcontextprotocol/python-sdk#262), a tool call waits the result **FOREVER**! - To overcome this issue, I propagated the `sse_read_timeout` to `ClientSession`. - `timeout` is too short for timeout for tool call, since its default value is only 5s. - `sse_read_timeout` is originally made for read timeout of SSE(default value of 5m or 300s), but actually most of SSE implementations from server (e.g. FastAPI, etc.) sends ping periodically(about 15s I assume), so in a normal circumstances this timeout is quite useless. - If the server does not send ping, the timeout is equal to tool call timeout. Therefore, it would be appropriate to use `sse_read_timeout` as tool call timeout. - Most of tool calls should finish within 5 minutes, and sse timeout is adjustable if not. - If this change is not acceptable, we could make a dedicate parameter for tool call timeout(e.g. `tool_call_timeout`). ### Testing Plan - Although this does not change the interface itself, it changes its own session management logics, some existing tests are no longer valid. - I made changes to those tests, especially those of which validate session states(e.g. checking whether `initialize()` called). - Since now session is encapsulated with `SessionContext`, we cannot validate the initialized state of the session in `TestMcpSessionManager`, should validate it at `TestSessionContext`. - Added a simple test for reproducing the issue(`test_create_and_close_session_in_different_tasks`). - Also made a test for the new component: `SessionContext`. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. ```plaintext =================================================================================== 3689 passed, 1 skipped, 2205 warnings in 63.39s (0:01:03) =================================================================================== ``` **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [ ] ~~Any dependent changes have been merged and published in downstream modules.~~ `no deps has been changed` ### Additional context This PR is related to modelcontextprotocol/python-sdk#1817 since it also fixes endless tool call awaiting. Co-authored-by: Kathy Wu <wukathy@google.com> COPYBARA_INTEGRATE_REVIEW=#4025 from challenger71498:feat/task-based-mcp-session-manager f7f7cd0 PiperOrigin-RevId: 856438147
1 parent 1133ce2 commit cce430d

File tree

4 files changed

+848
-46
lines changed

4 files changed

+848
-46
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from pydantic import BaseModel
4242
from pydantic import ConfigDict
4343

44+
from .session_context import SessionContext
45+
4446
logger = logging.getLogger('google_adk.' + __name__)
4547

4648

@@ -385,29 +387,27 @@ async def create_session(
385387
if hasattr(self._connection_params, 'timeout')
386388
else None
387389
)
390+
sse_read_timeout_in_seconds = (
391+
self._connection_params.sse_read_timeout
392+
if hasattr(self._connection_params, 'sse_read_timeout')
393+
else None
394+
)
388395

389396
try:
390397
client = self._create_client(merged_headers)
391-
392-
transports = await asyncio.wait_for(
393-
exit_stack.enter_async_context(client),
398+
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
399+
400+
session = await asyncio.wait_for(
401+
exit_stack.enter_async_context(
402+
SessionContext(
403+
client=client,
404+
timeout=timeout_in_seconds,
405+
sse_read_timeout=sse_read_timeout_in_seconds,
406+
is_stdio=is_stdio,
407+
)
408+
),
394409
timeout=timeout_in_seconds,
395410
)
396-
# The streamable http client returns a GetSessionCallback in addition to the
397-
# read/write MemoryObjectStreams needed to build the ClientSession, we limit
398-
# then to the two first values to be compatible with all clients.
399-
if isinstance(self._connection_params, StdioConnectionParams):
400-
session = await exit_stack.enter_async_context(
401-
ClientSession(
402-
*transports[:2],
403-
read_timeout_seconds=timedelta(seconds=timeout_in_seconds),
404-
)
405-
)
406-
else:
407-
session = await exit_stack.enter_async_context(
408-
ClientSession(*transports[:2])
409-
)
410-
await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds)
411411

412412
# Store session and exit stack in the pool
413413
self._sessions[session_key] = (session, exit_stack)
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
from contextlib import AsyncExitStack
19+
from datetime import timedelta
20+
import logging
21+
from typing import AsyncContextManager
22+
from typing import Optional
23+
24+
from mcp import ClientSession
25+
26+
logger = logging.getLogger('google_adk.' + __name__)
27+
28+
29+
class SessionContext:
30+
"""Represents the context of a single MCP session within a dedicated task.
31+
32+
AnyIO's TaskGroup/CancelScope requires that the start and end of a scope
33+
occur within the same task. Since MCP clients use AnyIO internally, we need
34+
to ensure that the client's entire lifecycle (creation, usage, and cleanup)
35+
happens within a single dedicated task.
36+
37+
This class spawns a background task that:
38+
1. Enters the MCP client's async context and initializes the session
39+
2. Signals readiness via an asyncio.Event
40+
3. Waits for a close signal
41+
4. Cleans up the client within the same task
42+
43+
This ensures CancelScope constraints are satisfied regardless of which
44+
task calls start() or close().
45+
46+
Can be used in two ways:
47+
1. Direct method calls: start() and close()
48+
2. As an async context manager: async with lifecycle as session: ...
49+
"""
50+
51+
def __init__(
52+
self,
53+
client: AsyncContextManager,
54+
timeout: Optional[float],
55+
sse_read_timeout: Optional[float],
56+
is_stdio: bool = False,
57+
):
58+
"""
59+
Args:
60+
client: An MCP client context manager (e.g., from streamablehttp_client,
61+
sse_client, or stdio_client).
62+
timeout: Timeout in seconds for connection and initialization.
63+
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
64+
server.
65+
is_stdio: Whether this is a stdio connection (affects read timeout).
66+
"""
67+
self._client = client
68+
self._timeout = timeout
69+
self._sse_read_timeout = sse_read_timeout
70+
self._is_stdio = is_stdio
71+
self._session: Optional[ClientSession] = None
72+
self._ready_event = asyncio.Event()
73+
self._close_event = asyncio.Event()
74+
self._task: Optional[asyncio.Task] = None
75+
self._task_lock = asyncio.Lock()
76+
77+
@property
78+
def session(self) -> Optional[ClientSession]:
79+
"""Get the managed ClientSession, if available."""
80+
return self._session
81+
82+
async def start(self) -> ClientSession:
83+
"""Start the runner and wait for the session to be ready.
84+
85+
Returns:
86+
The initialized ClientSession.
87+
88+
Raises:
89+
ConnectionError: If session creation fails.
90+
"""
91+
async with self._task_lock:
92+
if self._session:
93+
logger.debug(
94+
'Session has already been created, returning existing session'
95+
)
96+
return self._session
97+
98+
if self._close_event.is_set():
99+
raise ConnectionError(
100+
'Failed to create MCP session: session already closed'
101+
)
102+
103+
if not self._task:
104+
self._task = asyncio.create_task(self._run())
105+
106+
await self._ready_event.wait()
107+
108+
if self._task.cancelled():
109+
raise ConnectionError('Failed to create MCP session: task cancelled')
110+
111+
if self._task.done() and self._task.exception():
112+
raise ConnectionError(
113+
f'Failed to create MCP session: {self._task.exception()}'
114+
) from self._task.exception()
115+
116+
return self._session
117+
118+
async def close(self):
119+
"""Signal the context task to close and wait for cleanup."""
120+
# Set the close event to signal the task to close.
121+
# Even if start has not been called, we need to set the close event
122+
# to signal the task to close right away.
123+
async with self._task_lock:
124+
self._close_event.set()
125+
126+
# If start has not been called, only set the close event and return
127+
if not self._task:
128+
return
129+
130+
if not self._ready_event.is_set():
131+
self._task.cancel()
132+
133+
try:
134+
await asyncio.wait_for(self._task, timeout=self._timeout)
135+
except asyncio.TimeoutError:
136+
logger.warning('Failed to close MCP session: task timed out')
137+
self._task.cancel()
138+
except asyncio.CancelledError:
139+
pass
140+
except Exception as e:
141+
logger.warning(f'Failed to close MCP session: {e}')
142+
143+
async def __aenter__(self) -> ClientSession:
144+
return await self.start()
145+
146+
async def __aexit__(self, exc_type, exc_val, exc_tb):
147+
await self.close()
148+
149+
async def _run(self):
150+
"""Run the complete session context within a single task."""
151+
try:
152+
async with AsyncExitStack() as exit_stack:
153+
transports = await asyncio.wait_for(
154+
exit_stack.enter_async_context(self._client),
155+
timeout=self._timeout,
156+
)
157+
# The streamable http client returns a GetSessionCallback in addition
158+
# to the read/write MemoryObjectStreams needed to build the
159+
# ClientSession. We limit to the first two values to be compatible
160+
# with all clients.
161+
if self._is_stdio:
162+
session = await exit_stack.enter_async_context(
163+
ClientSession(
164+
*transports[:2],
165+
read_timeout_seconds=timedelta(seconds=self._timeout)
166+
if self._timeout is not None
167+
else None,
168+
)
169+
)
170+
else:
171+
# For SSE and Streamable HTTP clients, use the sse_read_timeout
172+
# instead of the connection timeout as the read_timeout for the session.
173+
session = await exit_stack.enter_async_context(
174+
ClientSession(
175+
*transports[:2],
176+
read_timeout_seconds=timedelta(seconds=self._sse_read_timeout)
177+
if self._sse_read_timeout is not None
178+
else None,
179+
)
180+
)
181+
await asyncio.wait_for(session.initialize(), timeout=self._timeout)
182+
logger.debug('Session has been successfully initialized')
183+
184+
self._session = session
185+
self._ready_event.set()
186+
187+
# Wait for close signal - the session remains valid while we wait
188+
await self._close_event.wait()
189+
except BaseException as e:
190+
logger.warning(f'Error on session runner task: {e}')
191+
raise
192+
finally:
193+
self._ready_event.set()
194+
self._close_event.set()

0 commit comments

Comments
 (0)