Skip to content

Commit b85fc51

Browse files
committed
add tests using fastmcp
1 parent a8c2f75 commit b85fc51

File tree

4 files changed

+139
-22
lines changed

4 files changed

+139
-22
lines changed

mcp_server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ target-version = "py311"
2323
dev = [
2424
"ruff>=0.12.2",
2525
"pytest>=8.4.1",
26+
"pytest-asyncio>=1.0.0",
2627
]

mcp_server/src/mcp_server_neo4j_gds/server.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,58 @@ def serialize_result(result: Any) -> str:
4444

4545

4646
def create_algorithm_tool(mcp: FastMCP, tool_name: str, gds: GraphDataScience):
47-
"""Create and register an algorithm tool with proper closure"""
48-
49-
async def algorithm_tool(parameters: Dict[str, Any] = None) -> str:
50-
"""Execute algorithm tool with parameters dictionary"""
51-
try:
52-
handler = AlgorithmRegistry.get_handler(tool_name, gds)
53-
result = handler.execute(parameters or {})
54-
return serialize_result(result)
55-
except Exception as e:
56-
return f"Error executing {tool_name}: {str(e)}"
57-
58-
# Set the function name and docstring
59-
algorithm_tool.__name__ = tool_name
60-
algorithm_tool.__doc__ = f"Execute {tool_name} algorithm with parameters dictionary"
61-
62-
# Register the tool with the server
63-
mcp.tool(algorithm_tool)
64-
65-
66-
def main(db_url: str, username: str, password: str, database: str = None):
67-
"""Main function that sets up and runs the FastMCP server"""
68-
logger.info(f"Starting MCP Server for {db_url} with username {username}")
47+
"""Create and register an algorithm tool using a single parameters dict approach"""
48+
handler = AlgorithmRegistry.get_handler(tool_name, gds)
49+
50+
# Get tool definition from specs to extract parameter information
51+
all_tool_definitions = (
52+
centrality_tool_definitions
53+
+ community_tool_definitions
54+
+ path_tool_definitions
55+
+ similarity_tool_definitions
56+
)
57+
58+
tool_def = None
59+
for tool in all_tool_definitions:
60+
if tool.name == tool_name:
61+
tool_def = tool
62+
break
63+
64+
if tool_def:
65+
# Create a function that accepts a single parameters dict
66+
# This is the only approach that works consistently with FastMCP
67+
async def algorithm_tool(parameters: Dict[str, Any] = None) -> str:
68+
"""Execute algorithm tool with parameters dictionary"""
69+
try:
70+
result = handler.execute(parameters or {})
71+
return serialize_result(result)
72+
except Exception as e:
73+
return f"Error executing {tool_name}: {str(e)}"
74+
75+
# Set the function name and docstring
76+
algorithm_tool.__name__ = tool_name
77+
algorithm_tool.__doc__ = tool_def.description
78+
79+
# Register the tool with the server
80+
mcp.tool(algorithm_tool)
81+
else:
82+
# Fallback for tools without specs
83+
async def algorithm_tool(parameters: Dict[str, Any] = None) -> str:
84+
"""Execute algorithm tool with parameters dictionary"""
85+
try:
86+
result = handler.execute(parameters or {})
87+
return serialize_result(result)
88+
except Exception as e:
89+
return f"Error executing {tool_name}: {str(e)}"
90+
91+
algorithm_tool.__name__ = tool_name
92+
algorithm_tool.__doc__ = f"Execute {tool_name} algorithm with parameters dictionary"
93+
mcp.tool(algorithm_tool)
94+
95+
96+
def setup_server(db_url: str, username: str, password: str, database: str = None) -> FastMCP:
97+
"""Set up the FastMCP server with all tools registered"""
98+
logger.info(f"Setting up MCP Server for {db_url} with username {username}")
6999
if database:
70100
logger.info(f"Connecting to database: {database}")
71101

@@ -142,6 +172,12 @@ async def list_tools() -> str:
142172
for tool_name in algorithm_names:
143173
create_algorithm_tool(mcp, tool_name, gds)
144174

175+
return mcp
176+
177+
178+
def main(db_url: str, username: str, password: str, database: str = None):
179+
"""Main function that sets up and runs the FastMCP server"""
180+
mcp = setup_server(db_url, username, password, database)
145181
# Run the server - FastMCP will handle the event loop
146182
mcp.run()
147183

mcp_server/tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pytest
2+
import asyncio
3+
4+
5+
# Configure pytest-asyncio
6+
pytest_plugins = ["pytest_asyncio"]
7+
8+
9+
@pytest.fixture(scope="session")
10+
def event_loop():
11+
"""Create an instance of the default event loop for the test session."""
12+
loop = asyncio.get_event_loop_policy().new_event_loop()
13+
yield loop
14+
loop.close()

mcp_server/tests/test_basic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,67 @@
1+
import pytest
2+
import asyncio
3+
import json
4+
import os
5+
from dotenv import load_dotenv
6+
from fastmcp import Client
17

8+
9+
@pytest.fixture
10+
def mcp_server():
11+
"""Create a FastMCP server instance using the actual server setup"""
12+
# Load environment variables
13+
load_dotenv("../../../.env")
14+
15+
# Import the actual server setup function
16+
from mcp_server_neo4j_gds.server import setup_server
17+
18+
# Get environment variables
19+
db_url = os.environ.get("NEO4J_URI")
20+
username = os.environ.get("NEO4J_USERNAME", "neo4j")
21+
password = os.environ.get("NEO4J_PASSWORD")
22+
database = os.environ.get("NEO4J_DATABASE")
23+
24+
# Use the actual server setup function
25+
server = setup_server(db_url, username, password, database)
26+
27+
return server
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_find_shortest_path(mcp_server):
32+
"""Test the find_shortest_path tool with correct parameters using the actual server"""
33+
# Pass the server directly to the Client constructor for in-memory testing
34+
async with Client(mcp_server) as client:
35+
# Test the find_shortest_path tool with correct parameter names
36+
result = await client.call_tool(
37+
"find_shortest_path",
38+
{
39+
"start_node": "Tower Hill",
40+
"end_node": "Paddington"
41+
}
42+
)
43+
44+
# Parse the result
45+
result_data = json.loads(result.data)
46+
breakpoint()
47+
48+
# Assertions
49+
assert "totalCost" in result_data, "Result should contain totalCost"
50+
assert "nodeIds" in result_data, "Result should contain nodeIds"
51+
assert "nodeNames" in result_data, "Result should contain nodeNames"
52+
assert "path" in result_data, "Result should contain path"
53+
assert "costs" in result_data, "Result should contain costs"
54+
55+
# Check that we got a valid path
56+
assert result_data["totalCost"] > 0, "Total cost should be positive"
57+
assert len(result_data["nodeIds"]) > 0, "Should have at least one node in path"
58+
assert len(result_data["nodeNames"]) > 0, "Should have at least one node name"
59+
60+
print(f"✅ find_shortest_path test passed!")
61+
print(f"Path from {result_data['nodeNames'][0]} to {result_data['nodeNames'][-1]}")
62+
print(f"Total cost: {result_data['totalCost']}")
63+
print(f"Number of nodes: {len(result_data['nodeIds'])}")
64+
65+
66+
if __name__ == "__main__":
67+
asyncio.run(test_find_shortest_path())

0 commit comments

Comments
 (0)