forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_client_disconnect_post.py
More file actions
128 lines (101 loc) · 4.78 KB
/
Copy pathtest_client_disconnect_post.py
File metadata and controls
128 lines (101 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Tests for ClientDisconnect handling in StreamableHTTPServerTransport._handle_post_request.
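Regression test for pattern 1: a ClientDisconnect raised during POST should be logged at
WARNING (not ERROR). The transport still emits a minimal 499 (Client Closed Request)
response so middleware chains don't raise 'No response returned'; the ASGI server
discards it when the socket is already closed.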
Inspired by upstream PRs:
- https://github.com/modelcontextprotocol/python-sdk/pull/1647 (scope: POST only)
- https://github.com/modelcontextprotocol/python-sdk/pull/1947 (semantics: notify writer, skip response)
"""
from __future__ import annotations as _annotations
import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from starlette.requests import ClientDisconnect, Request
from mcp.server.streamable_http import StreamableHTTPServerTransport
class TestClientDisconnectDuringPOST:
    """ClientDisconnect during POST should be handled gracefully.

    Each test drives ``StreamableHTTPServerTransport._handle_post_request``
    with a mocked ``Request`` whose ``body()`` raises ``ClientDisconnect``,
    simulating a client that goes away mid-way through uploading the body.
    """

    def _make_scope(self, headers: dict[str, bytes] | None = None) -> dict[str, Any]:
        """Build a minimal ASGI HTTP scope for ``POST /mcp``.

        Args:
            headers: Optional header mapping. ``str`` names/values are
                latin-1 encoded and names are lower-cased, because the ASGI
                spec requires header names as lower-case byte strings. When
                omitted or empty, JSON content-type and JSON/SSE accept
                headers are used.

        Returns:
            A scope dict suitable for ``_handle_post_request``.
        """

        def _as_bytes(text: str | bytes) -> bytes:
            # Accept both str and bytes so callers can't build an invalid scope.
            return text.encode("latin-1") if isinstance(text, str) else text

        if headers:
            raw_headers = [
                (_as_bytes(name).lower(), _as_bytes(value))
                for name, value in headers.items()
            ]
        else:
            raw_headers = [
                (b"content-type", b"application/json"),
                (b"accept", b"application/json, text/event-stream"),
            ]
        return {
            "type": "http",
            "method": "POST",
            "path": "/mcp",
            "query_string": b"",
            "headers": raw_headers,
        }

    @staticmethod
    def _make_transport() -> StreamableHTTPServerTransport:
        """Create a transport with a mocked ``_read_stream_writer``.

        A dummy writer (MagicMock with an async ``send``) is installed so the
        transport passes its ``None`` check on the writer.
        """
        transport = StreamableHTTPServerTransport(mcp_session_id=None)
        writer = MagicMock()
        writer.send = AsyncMock()
        transport._read_stream_writer = writer
        return transport

    @staticmethod
    def _make_disconnecting_request(scope: dict[str, Any]) -> MagicMock:
        """Return a mocked ``Request`` whose ``body()`` raises ``ClientDisconnect``.

        This simulates the client going away mid-request body upload.
        """
        request = MagicMock(spec=Request)
        request.body = AsyncMock(side_effect=ClientDisconnect())
        request.headers = {
            "content-type": "application/json",
            "accept": "application/json, text/event-stream",
        }
        request.scope = scope
        return request

    @staticmethod
    def _make_asgi_callables() -> tuple[Any, Any, list[Any]]:
        """Return ``(receive, send, sent_messages)`` ASGI stand-ins.

        ``send`` appends every message it is given to ``sent_messages`` so
        tests can inspect what the transport tried to write.
        """
        sent_messages: list[Any] = []

        async def receive() -> dict[str, Any]:
            return {"type": "http.request", "body": b""}

        async def send(message: Any) -> None:
            sent_messages.append(message)

        return receive, send, sent_messages

    @pytest.mark.anyio
    async def test_client_disconnect_logs_warning_not_error(
        self, caplog: pytest.LogCaptureFixture
    ):
        """ClientDisconnect should produce a WARNING, not an ERROR (or worse)."""
        transport = self._make_transport()
        scope = self._make_scope()
        request = self._make_disconnecting_request(scope)
        receive, send, _ = self._make_asgi_callables()

        with caplog.at_level(logging.DEBUG, logger="mcp.server.streamable_http"):
            await transport._handle_post_request(scope, request, receive, send)

        warning_records = [
            r
            for r in caplog.records
            if r.levelno == logging.WARNING and "Client disconnected" in r.getMessage()
        ]
        # >= so CRITICAL records are caught as well as ERROR.
        error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
        assert len(warning_records) == 1, (
            f"Expected exactly 1 WARNING with 'Client disconnected', got {len(warning_records)}"
        )
        assert len(error_records) == 0, (
            f"Expected 0 ERROR logs, got {len(error_records)}: {[r.getMessage() for r in error_records]}"
        )

    @pytest.mark.anyio
    async def test_client_disconnect_sends_response(self):
        """After ClientDisconnect, a 499 (Client Closed Request) response is still
        sent so middleware chains don't raise 'No response returned' (the ASGI
        server drops it if the socket is already closed)."""
        transport = self._make_transport()
        scope = self._make_scope()
        request = self._make_disconnecting_request(scope)
        receive, send, send_calls = self._make_asgi_callables()

        await transport._handle_post_request(scope, request, receive, send)

        # A response IS sent so middleware chains don't blow up.
        assert len(send_calls) >= 1, (
            f"Expected at least 1 ASGI send (response), got {len(send_calls)}"
        )
        # First send should be http.response.start with 499 (Client Closed Request).
        assert send_calls[0]["type"] == "http.response.start"
        assert send_calls[0]["status"] == 499