Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/01_basic_usage_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/02_full_schema_description_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
mcp.mount_http()

if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/03_custom_exposed_endpoints_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
combined_include_mcp.mount_http(mount_path="/combined-include-mcp")

if __name__ == "__main__":
import os
import uvicorn

print("Server is running with multiple MCP endpoints:")
Expand All @@ -68,4 +69,5 @@
print(" - /include-tags-mcp: Only operations with the 'items' tag")
print(" - /exclude-tags-mcp: All operations except those with the 'search' tag")
print(" - /combined-include-mcp: Operations with 'search' tag or delete_item operation")
uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/04_separate_server_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# It still works 🚀
# Your original API is **not exposed**, only via the MCP server.
if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(mcp_app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(mcp_app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/05_reregister_tools_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ async def new_endpoint():


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/06_custom_mcp_router_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/07_configure_http_timeout_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/08_auth_example_token_passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ async def private(token=Depends(token_auth_scheme)):


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
4 changes: 3 additions & 1 deletion examples/09_auth_example_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ async def protected(user_id: str = Depends(get_current_user_id)):


if __name__ == "__main__":
import os
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
# Set HOST=0.0.0.0 to expose on your network.
uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000)
34 changes: 31 additions & 3 deletions fastapi_mcp/transport/sse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from uuid import UUID
import logging
from uuid import UUID
from typing import Union

from anyio.streams.memory import MemoryObjectSendStream
Expand All @@ -13,8 +13,36 @@

logger = logging.getLogger(__name__)

DEFAULT_MAX_BODY_BYTES = 1_000_000


class FastApiSseTransport(SseServerTransport):
def __init__(self, endpoint: str, max_body_bytes: int = DEFAULT_MAX_BODY_BYTES):
super().__init__(endpoint)

max_body_bytes_int = int(max_body_bytes)
if max_body_bytes_int <= 0:
raise ValueError("max_body_bytes must be positive")

self._max_body_bytes = max_body_bytes_int

async def _read_body_with_limit(self, request: Request) -> bytes:
content_length = request.headers.get("content-length")
if content_length is not None:
try:
if int(content_length) > self._max_body_bytes:
raise HTTPException(status_code=413, detail="Payload too large")
except ValueError:
# Ignore invalid Content-Length and fall back to streaming limit.
pass

out = bytearray()
async for chunk in request.stream():
out.extend(chunk)
if len(out) > self._max_body_bytes:
raise HTTPException(status_code=413, detail="Payload too large")
return bytes(out)

async def handle_fastapi_post_message(self, request: Request) -> Response:
"""
A reimplementation of the handle_post_message method of SseServerTransport
Expand Down Expand Up @@ -53,8 +81,8 @@ async def handle_fastapi_post_message(self, request: Request) -> Response:
logger.warning(f"Could not find session for ID: {session_id}")
raise HTTPException(status_code=404, detail="Could not find session")

body = await request.body()
logger.debug(f"Received JSON: {body.decode()}")
body = await self._read_body_with_limit(request)
logger.debug(f"Received JSON: {body.decode(errors='replace')}")

try:
message = JSONRPCMessage.model_validate_json(body)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_sse_mock_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from mcp.types import JSONRPCMessage, JSONRPCError


async def _bytes_stream(data: bytes):
yield data


@pytest.fixture
def mock_transport() -> FastApiSseTransport:
# Initialize transport with a mock endpoint
Expand Down Expand Up @@ -88,7 +92,8 @@ async def test_handle_post_message_validation_error(
# Create a mock request with valid session_id but invalid body
mock_request = MagicMock(spec=Request)
mock_request.query_params = {"session_id": valid_session_id.hex}
mock_request.body = AsyncMock(return_value=b'{"invalid": "json"}')
mock_request.headers = {}
mock_request.stream = lambda: _bytes_stream(b'{"invalid": "json"}')

# Mock BackgroundTasks
with patch("fastapi_mcp.transport.sse.BackgroundTasks") as MockBackgroundTasks:
Expand Down Expand Up @@ -119,7 +124,8 @@ async def test_handle_post_message_general_exception(
# Instead of mocking the body method to raise an exception,
# we'll patch the body method to return a normal value and then
# patch JSONRPCMessage.model_validate_json to raise the exception
mock_request.body = AsyncMock(return_value=b'{"jsonrpc": "2.0", "method": "test", "id": "1"}')
mock_request.headers = {}
mock_request.stream = lambda: _bytes_stream(b'{"jsonrpc": "2.0", "method": "test", "id": "1"}')

# Mock the model_validate_json method to raise an Exception
with patch("mcp.types.JSONRPCMessage.model_validate_json", side_effect=Exception("Test exception")):
Expand All @@ -131,6 +137,22 @@ async def test_handle_post_message_general_exception(
assert "Invalid request body" in excinfo.value.detail


@pytest.mark.anyio
async def test_handle_post_message_payload_too_large(valid_session_id: UUID, mock_writer: AsyncMock) -> None:
transport = FastApiSseTransport("/messages", max_body_bytes=10)
transport._read_stream_writers = {valid_session_id: mock_writer}

mock_request = MagicMock(spec=Request)
mock_request.query_params = {"session_id": valid_session_id.hex}
mock_request.headers = {}
mock_request.stream = lambda: _bytes_stream(b"x" * 11)

with pytest.raises(HTTPException) as excinfo:
await transport.handle_fastapi_post_message(mock_request)

assert excinfo.value.status_code == 413


@pytest.mark.anyio
async def test_send_message_safely_with_validation_error(
mock_transport: FastApiSseTransport, mock_writer: AsyncMock
Expand Down