Skip to content

Commit 9ebbc43

Browse files
authored
Merge pull request #110 from Serverless-Devs/codex/mcp-malformed-error-timeout
[codex] Prevent MCP tool metadata hangs on malformed responses
2 parents 288406f + 1dca6f2 commit 9ebbc43

3 files changed

Lines changed: 243 additions & 15 deletions

File tree

agentrun/tool/api/mcp.py

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010

1111
import httpx
1212

13-
from agentrun.tool.model import ToolInfo, ToolSchema
13+
from agentrun.tool.model import ToolInfo
1414
from agentrun.utils.config import Config
1515
from agentrun.utils.log import logger
16+
from agentrun.utils.ram_signature import get_agentrun_signed_headers
17+
18+
_MCP_METADATA_TIMEOUT_SECONDS = 30.0
1619

1720

1821
def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
@@ -30,9 +33,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
3033
return loop
3134

3235

33-
from agentrun.utils.ram_signature import get_agentrun_signed_headers
34-
35-
3636
class _AgentrunRamAuth(httpx.Auth):
3737
"""httpx Auth handler:为每次请求动态生成 RAM 签名。
3838
@@ -144,6 +144,54 @@ def is_streamable(self) -> bool:
144144
"""是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport"""
145145
return self.session_affinity == "MCP_STREAMABLE"
146146

147+
def _metadata_timeout_seconds(self) -> float:
    """Return the timeout, in seconds, applied to MCP metadata requests.

    A positive timeout from ``self.config.get_timeout()`` is honoured but
    never allowed to exceed ``_MCP_METADATA_TIMEOUT_SECONDS``. A missing,
    zero, or non-positive timeout falls back to that module-level cap.
    """
    configured = self.config.get_timeout()
    # Guard clause: anything falsy (None, 0) or not strictly positive means
    # "no usable configured timeout".
    if not configured or not (configured > 0):
        return _MCP_METADATA_TIMEOUT_SECONDS
    return min(float(configured), _MCP_METADATA_TIMEOUT_SECONDS)
152+
153+
def _invoke_timeout_seconds(self) -> float:
154+
timeout = self.config.get_timeout()
155+
if timeout and timeout > 0:
156+
return float(timeout)
157+
return 600.0
158+
159+
async def _wait_for_mcp_request(
160+
self,
161+
awaitable: Any,
162+
operation: str,
163+
timeout: float,
164+
) -> Any:
165+
try:
166+
return await asyncio.wait_for(awaitable, timeout=timeout)
167+
except asyncio.TimeoutError as exc:
168+
raise TimeoutError(
169+
f"MCP {operation} timed out after {timeout:g}s for endpoint"
170+
f" {self.endpoint}"
171+
) from exc
172+
173+
def _find_mcp_timeout_error(
174+
self, exc: BaseException
175+
) -> Optional[TimeoutError]:
176+
if isinstance(exc, TimeoutError) and str(exc).startswith("MCP "):
177+
return exc
178+
179+
nested_exceptions = getattr(exc, "exceptions", None)
180+
if not nested_exceptions:
181+
return None
182+
183+
for nested_exc in nested_exceptions:
184+
timeout_error = self._find_mcp_timeout_error(nested_exc)
185+
if timeout_error is not None:
186+
return timeout_error
187+
188+
return None
189+
190+
def _raise_mcp_timeout_if_present(self, exc: BaseException) -> None:
191+
timeout_error = self._find_mcp_timeout_error(exc)
192+
if timeout_error is not None:
193+
raise timeout_error
194+
147195
def _build_ram_auth(self, url: str) -> tuple:
148196
"""当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。
149197
@@ -199,8 +247,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
199247
async with ClientSession(
200248
read_stream, write_stream
201249
) as session:
202-
await session.initialize()
203-
result = await session.list_tools()
250+
metadata_timeout = self._metadata_timeout_seconds()
251+
await self._wait_for_mcp_request(
252+
session.initialize(),
253+
"initialize",
254+
metadata_timeout,
255+
)
256+
result = await self._wait_for_mcp_request(
257+
session.list_tools(),
258+
"list_tools",
259+
metadata_timeout,
260+
)
204261
return [
205262
ToolInfo.from_mcp_tool(tool)
206263
for tool in result.tools
@@ -215,8 +272,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
215272
async with ClientSession(
216273
read_stream, write_stream
217274
) as session:
218-
await session.initialize()
219-
result = await session.list_tools()
275+
metadata_timeout = self._metadata_timeout_seconds()
276+
await self._wait_for_mcp_request(
277+
session.initialize(),
278+
"initialize",
279+
metadata_timeout,
280+
)
281+
result = await self._wait_for_mcp_request(
282+
session.list_tools(),
283+
"list_tools",
284+
metadata_timeout,
285+
)
220286
return [
221287
ToolInfo.from_mcp_tool(tool)
222288
for tool in result.tools
@@ -226,6 +292,9 @@ async def list_tools_async(self) -> List[ToolInfo]:
226292
"mcp package is not installed. Install it with: pip install mcp"
227293
)
228294
return []
295+
except Exception as exc:
296+
self._raise_mcp_timeout_if_present(exc)
297+
raise
229298

230299
def list_tools(self) -> List[ToolInfo]:
231300
"""同步获取工具列表 / Get tool list synchronously
@@ -266,9 +335,15 @@ async def call_tool_async(
266335
async with ClientSession(
267336
read_stream, write_stream
268337
) as session:
269-
await session.initialize()
270-
result = await session.call_tool(
271-
name, arguments=arguments or {}
338+
await self._wait_for_mcp_request(
339+
session.initialize(),
340+
"initialize",
341+
self._metadata_timeout_seconds(),
342+
)
343+
result = await self._wait_for_mcp_request(
344+
session.call_tool(name, arguments=arguments or {}),
345+
f"call_tool {name}",
346+
self._invoke_timeout_seconds(),
272347
)
273348
return result
274349
else:
@@ -281,16 +356,25 @@ async def call_tool_async(
281356
async with ClientSession(
282357
read_stream, write_stream
283358
) as session:
284-
await session.initialize()
285-
result = await session.call_tool(
286-
name, arguments=arguments or {}
359+
await self._wait_for_mcp_request(
360+
session.initialize(),
361+
"initialize",
362+
self._metadata_timeout_seconds(),
363+
)
364+
result = await self._wait_for_mcp_request(
365+
session.call_tool(name, arguments=arguments or {}),
366+
f"call_tool {name}",
367+
self._invoke_timeout_seconds(),
287368
)
288369
return result
289370
except ImportError:
290371
raise ImportError(
291372
"mcp package is required for MCP tool calls. "
292373
"Install it with: pip install mcp"
293374
)
375+
except Exception as exc:
376+
self._raise_mcp_timeout_if_present(exc)
377+
raise
294378

295379
def call_tool(
296380
self,
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""E2E regression tests for malformed MCP streamable-http responses."""
2+
3+
import asyncio
4+
import socket
5+
import threading
6+
import time
7+
8+
from fastapi import FastAPI, Request
9+
from fastapi.responses import JSONResponse
10+
import httpx
11+
import pytest
12+
import uvicorn
13+
14+
from agentrun.tool.api.mcp import ToolMCPSession
15+
from agentrun.utils.config import Config
16+
17+
18+
def _find_free_port() -> int:
19+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
20+
sock.bind(("127.0.0.1", 0))
21+
return sock.getsockname()[1]
22+
23+
24+
def _build_malformed_mcp_app() -> FastAPI:
    """Build a FastAPI app whose ``/mcp`` endpoint returns malformed errors.

    Routes:
        GET /health: returns ``{"ok": True}``.
        POST /mcp: echoes the request ``id`` in a JSON-RPC error object
            whose ``message`` is ``None``.
    """
    app = FastAPI()

    @app.get("/health")
    async def health():
        return {"ok": True}

    @app.post("/mcp")
    async def mcp_endpoint(request: Request):
        request_body = await request.json()
        # The error "message" is deliberately None rather than a string.
        error_reply = {
            "jsonrpc": "2.0",
            "id": request_body.get("id"),
            "error": {
                "code": -32000,
                "message": None,
            },
        }
        return JSONResponse(error_reply)

    return app
46+
47+
48+
@pytest.fixture
def malformed_mcp_server():
    """Run the malformed MCP app on a local uvicorn server for one test.

    Starts uvicorn in a daemon thread on a free 127.0.0.1 port, polls the
    ``/health`` route until it answers (up to 50 probes), and yields the URL
    of the ``/mcp`` endpoint.

    Yields:
        str: ``http://127.0.0.1:<port>/mcp``.

    Raises:
        RuntimeError: If the server does not answer any health probe.
    """
    app = _build_malformed_mcp_app()
    port = _find_free_port()
    config = uvicorn.Config(
        app, host="127.0.0.1", port=port, log_level="warning"
    )
    server = uvicorn.Server(config)

    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()

    base_url = f"http://127.0.0.1:{port}"
    # try/finally ensures the server is asked to stop even when the startup
    # probe gives up, not only after a test has consumed the fixture.
    try:
        for _ in range(50):
            try:
                httpx.get(f"{base_url}/health", timeout=0.2)
                break
            except httpx.HTTPError:
                # Connection refused / timeout while uvicorn is still
                # booting: wait briefly and probe again.
                time.sleep(0.1)
        else:
            raise RuntimeError("malformed MCP server did not start")

        yield f"{base_url}/mcp"
    finally:
        server.should_exit = True
        thread.join(timeout=5)
74+
75+
76+
@pytest.mark.asyncio
async def test_streamable_mcp_malformed_initialize_error_fails_fast(
    malformed_mcp_server,
):
    """A malformed initialize error must surface as an MCP timeout quickly.

    The session uses a 0.05 s configured timeout and the whole call is
    additionally bounded at 1 s by ``asyncio.wait_for``.
    """
    mcp_session = ToolMCPSession(
        endpoint=malformed_mcp_server,
        session_affinity="MCP_STREAMABLE",
        config=Config(timeout=0.05),
    )
    bounded_listing = asyncio.wait_for(
        mcp_session.list_tools_async(), timeout=1
    )

    with pytest.raises(TimeoutError, match="MCP initialize timed out"):
        await bounded_listing

tests/unittests/tool/test_mcp.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
Tests MCP protocol interaction functionality of ToolMCPSession.
55
"""
66

7+
import asyncio
78
import sys
8-
from unittest.mock import AsyncMock, MagicMock, Mock, patch
9+
from unittest.mock import AsyncMock, MagicMock, patch
910

1011
import pytest
1112

1213
from agentrun.tool.api.mcp import ToolMCPSession
1314
from agentrun.tool.model import ToolInfo
15+
from agentrun.utils.config import Config
1416

1517

1618
class TestToolMCPSessionInit:
@@ -186,6 +188,36 @@ def mock_import(name, *args, **kwargs):
186188
sys.modules.update(saved_modules)
187189
assert tools == []
188190

191+
@pytest.mark.asyncio
async def test_list_tools_async_initialize_timeout(self):
    """Verify that an unresponsive initialize does not wait forever."""

    async def hang_forever():
        # An Event that is never set keeps the coroutine pending.
        await asyncio.Event().wait()

    mock_session = AsyncMock()
    mock_session.initialize = hang_forever
    mock_session.list_tools = AsyncMock()

    mock_modules = _setup_mock_mcp_modules(mock_session)

    # Shrink the metadata timeout so the test finishes almost immediately.
    with patch.dict(sys.modules, mock_modules), patch(
        "agentrun.tool.api.mcp._MCP_METADATA_TIMEOUT_SECONDS",
        0.01,
    ):
        session = ToolMCPSession(
            endpoint="http://example.com/mcp",
            session_affinity="MCP_STREAMABLE",
        )

        with pytest.raises(TimeoutError, match="MCP initialize timed out"):
            await session.list_tools_async()

    # initialize timed out, so list_tools must not have been called.
    mock_session.list_tools.assert_not_called()
220+
189221

190222
class TestToolMCPSessionListTools:
191223
"""测试 list_tools 同步方法"""
@@ -258,6 +290,31 @@ async def test_call_tool_async_sse_mode(self):
258290

259291
assert result == mock_call_result
260292

293+
@pytest.mark.asyncio
async def test_call_tool_async_timeout(self):
    """Verify that an unresponsive tool call exits per Config.timeout."""

    async def hang_forever(*args, **kwargs):
        # An Event that is never set keeps the coroutine pending.
        await asyncio.Event().wait()

    mock_session = AsyncMock()
    mock_session.initialize = AsyncMock()
    mock_session.call_tool = hang_forever

    mock_modules = _setup_mock_mcp_modules(mock_session)

    with patch.dict(sys.modules, mock_modules):
        session = ToolMCPSession(
            endpoint="http://example.com/mcp",
            session_affinity="MCP_STREAMABLE",
            config=Config(timeout=0.01),
        )
        tool_arguments = {"key": "val"}

        with pytest.raises(
            TimeoutError, match="MCP call_tool test_tool timed out"
        ):
            await session.call_tool_async("test_tool", tool_arguments)
317+
261318
@pytest.mark.asyncio
262319
async def test_call_tool_async_import_error(self):
263320
"""测试 mcp 未安装时抛出 ImportError"""

0 commit comments

Comments
 (0)