Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions import_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,27 @@ def import_tube_data(uri, username, password, data_file):

session.run("""
UNWIND $stations AS station
CREATE (s:UndergroundStation {
id: station.id,
name: station.name,
display_name: CASE station.display_name
MERGE (s:UndergroundStation {id: station.id})
SET s.name = station.name,
s.display_name = CASE station.display_name
WHEN 'NULL' THEN station.name
ELSE station.display_name
END,
latitude: toFloat(station.latitude),
longitude: toFloat(station.longitude),
zone: CASE
s.latitude = toFloat(station.latitude),
s.longitude = toFloat(station.longitude),
s.zone = CASE
WHEN station.zone CONTAINS '.' THEN toFloat(station.zone)
ELSE toInteger(station.zone)
END,
total_lines: toInteger(station.total_lines),
rail: toInteger(station.rail)
})
s.total_lines = toInteger(station.total_lines),
s.rail = toInteger(station.rail)
""", {'stations': data['stations']})

session.run("""
UNWIND $connections AS conn
MATCH (s1:UndergroundStation {id: conn.station1})
MATCH (s2:UndergroundStation {id: conn.station2})
CREATE (s1)-[r:LINK {
MERGE (s1)-[r:LINK {
line: conn.line,
time: toInteger(conn.time),
distance: toInteger(conn.time)
Expand Down
8 changes: 7 additions & 1 deletion mcp_server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,11 @@ target-version = "py311"
[dependency-groups]
dev = [
"ruff>=0.12.2",
"pytest>=8.4.1"
"pytest>=8.4.1",
"pytest-asyncio>=0.23.0",
"pytest-docker>=3.2.3",
"neo4j>=5.0.0"
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
7 changes: 7 additions & 0 deletions mcp_server/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[tool:pytest]
addopts = --docker-compose=tests/docker-compose.yml
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
132 changes: 90 additions & 42 deletions mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@


class DijkstraShortestPathHandler(AlgorithmHandler):
def find_shortest_path(self, start_node: str, end_node: str, **kwargs):
query = """
def find_shortest_path(
self, start_node: str, end_node: str, node_identifier_property: str, **kwargs
):
query = f"""
MATCH (start)
WHERE toLower(start.name) CONTAINS toLower($start_name)
WHERE toLower(start.{node_identifier_property}) CONTAINS toLower($start_name)
MATCH (end)
WHERE toLower(end.name) CONTAINS toLower($end_name)
WHERE toLower(end.{node_identifier_property}) CONTAINS toLower($end_name)
RETURN id(start) as start_id, id(end) as end_id
"""

Expand Down Expand Up @@ -74,15 +76,18 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
return self.find_shortest_path(
arguments.get("start_node"),
arguments.get("end_node"),
arguments.get("nodeIdentifierProperty"),
relationshipWeightProperty=arguments.get("relationship_property"),
)


class DeltaSteppingShortestPathHandler(AlgorithmHandler):
def delta_stepping_shortest_path(self, source_node: str, **kwargs):
query = """
def delta_stepping_shortest_path(
self, source_node: str, node_identifier_property: str, **kwargs
):
query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand Down Expand Up @@ -132,16 +137,19 @@ def delta_stepping_shortest_path(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.delta_stepping_shortest_path(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
delta=arguments.get("delta"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
)


class DijkstraSingleSourceShortestPathHandler(AlgorithmHandler):
def dijkstra_single_source_shortest_path(self, source_node: str, **kwargs):
query = """
def dijkstra_single_source_shortest_path(
self, source_node: str, node_identifier_property: str, **kwargs
):
query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand Down Expand Up @@ -191,17 +199,24 @@ def dijkstra_single_source_shortest_path(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.dijkstra_single_source_shortest_path(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
)


class AStarShortestPathHandler(AlgorithmHandler):
def a_star_shortest_path(self, source_node: str, target_node: str, **kwargs):
query = """
def a_star_shortest_path(
self,
source_node: str,
target_node: str,
node_identifier_property: str,
**kwargs,
):
query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
MATCH (target)
WHERE toLower(target.name) CONTAINS toLower($target_name)
WHERE toLower(target.{node_identifier_property}) CONTAINS toLower($target_name)
RETURN id(source) as source_id, id(target) as target_id
"""

Expand Down Expand Up @@ -255,19 +270,26 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
return self.a_star_shortest_path(
arguments.get("sourceNode"),
arguments.get("targetNode"),
arguments.get("nodeIdentifierProperty"),
latitudeProperty=arguments.get("latitudeProperty"),
longitudeProperty=arguments.get("longitudeProperty"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
)


class YensShortestPathsHandler(AlgorithmHandler):
def yens_shortest_paths(self, source_node: str, target_node: str, **kwargs):
query = """
def yens_shortest_paths(
self,
source_node: str,
target_node: str,
node_identifier_property: str,
**kwargs,
):
query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
MATCH (target)
WHERE toLower(target.name) CONTAINS toLower($target_name)
WHERE toLower(target.{node_identifier_property}) CONTAINS toLower($target_name)
RETURN id(source) as source_id, id(target) as target_id
"""

Expand Down Expand Up @@ -337,16 +359,19 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
return self.yens_shortest_paths(
arguments.get("sourceNode"),
arguments.get("targetNode"),
arguments.get("nodeIdentifierProperty"),
k=arguments.get("k"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
)


class MinimumWeightSpanningTreeHandler(AlgorithmHandler):
def minimum_weight_spanning_tree(self, source_node: str, **kwargs):
query = """
def minimum_weight_spanning_tree(
self, source_node: str, node_identifier_property: str, **kwargs
):
query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand Down Expand Up @@ -408,6 +433,7 @@ def minimum_weight_spanning_tree(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.minimum_weight_spanning_tree(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
objective=arguments.get("objective"),
)
Expand Down Expand Up @@ -449,12 +475,16 @@ def execute(self, arguments: Dict[str, Any]) -> Any:

class MinimumDirectedSteinerTreeHandler(AlgorithmHandler):
def minimum_directed_steiner_tree(
self, source_node: str, target_nodes: list, **kwargs
self,
source_node: str,
target_nodes: list,
node_identifier_property: str,
**kwargs,
):
# Find source node ID
source_query = """
source_query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand All @@ -473,10 +503,10 @@ def minimum_directed_steiner_tree(
unmatched_targets = []

for target_name in target_nodes:
target_query = """
target_query = f"""
MATCH (target)
WHERE toLower(target.name) CONTAINS toLower($target_name)
RETURN id(target) as target_id, target.name as target_name
WHERE toLower(target.{node_identifier_property}) CONTAINS toLower($target_name)
RETURN id(target) as target_id, target.{node_identifier_property} as target_name
"""

target_df = self.gds.run_cypher(
Expand Down Expand Up @@ -554,6 +584,7 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
return self.minimum_directed_steiner_tree(
arguments.get("sourceNode"),
arguments.get("targetNodes"),
arguments.get("nodeIdentifierProperty"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
delta=arguments.get("delta"),
applyRerouting=arguments.get("applyRerouting"),
Expand Down Expand Up @@ -696,10 +727,17 @@ def random_walk(self, **kwargs):
# Process source nodes if provided
source_node_ids = []
if "sourceNodes" in kwargs and kwargs["sourceNodes"]:
node_identifier_property = kwargs.get("nodeIdentifierProperty")
if not node_identifier_property:
return {
"found": False,
"message": "nodeIdentifierProperty is required when sourceNodes are provided",
}

for source_name in kwargs["sourceNodes"]:
source_query = """
source_query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand Down Expand Up @@ -755,6 +793,7 @@ def random_walk(self, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.random_walk(
sourceNodes=arguments.get("sourceNodes"),
nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"),
walkLength=arguments.get("walkLength"),
walksPerNode=arguments.get("walksPerNode"),
inOutFactor=arguments.get("inOutFactor"),
Expand All @@ -765,11 +804,13 @@ def execute(self, arguments: Dict[str, Any]) -> Any:


class BreadthFirstSearchHandler(AlgorithmHandler):
def breadth_first_search(self, source_node: str, **kwargs):
def breadth_first_search(
self, source_node: str, node_identifier_property: str, **kwargs
):
# Find source node ID
source_query = """
source_query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand All @@ -786,9 +827,9 @@ def breadth_first_search(self, source_node: str, **kwargs):
target_node_ids = []
if "targetNodes" in kwargs and kwargs["targetNodes"]:
for target_name in kwargs["targetNodes"]:
target_query = """
target_query = f"""
MATCH (target)
WHERE toLower(target.name) CONTAINS toLower($target_name)
WHERE toLower(target.{node_identifier_property}) CONTAINS toLower($target_name)
RETURN id(target) as target_id
"""

Expand Down Expand Up @@ -852,17 +893,20 @@ def breadth_first_search(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.breadth_first_search(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
targetNodes=arguments.get("targetNodes"),
maxDepth=arguments.get("maxDepth"),
)


class DepthFirstSearchHandler(AlgorithmHandler):
def depth_first_search(self, source_node: str, **kwargs):
def depth_first_search(
self, source_node: str, node_identifier_property: str, **kwargs
):
# Find source node ID
source_query = """
source_query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand All @@ -879,9 +923,9 @@ def depth_first_search(self, source_node: str, **kwargs):
target_node_ids = []
if "targetNodes" in kwargs and kwargs["targetNodes"]:
for target_name in kwargs["targetNodes"]:
target_query = """
target_query = f"""
MATCH (target)
WHERE toLower(target.name) CONTAINS toLower($target_name)
WHERE toLower(target.{node_identifier_property}) CONTAINS toLower($target_name)
RETURN id(target) as target_id
"""

Expand Down Expand Up @@ -945,17 +989,20 @@ def depth_first_search(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.depth_first_search(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
targetNodes=arguments.get("targetNodes"),
maxDepth=arguments.get("maxDepth"),
)


class BellmanFordSingleSourceShortestPathHandler(AlgorithmHandler):
def bellman_ford_single_source_shortest_path(self, source_node: str, **kwargs):
def bellman_ford_single_source_shortest_path(
self, source_node: str, node_identifier_property: str, **kwargs
):
# Find source node ID
source_query = """
source_query = f"""
MATCH (source)
WHERE toLower(source.name) CONTAINS toLower($source_name)
WHERE toLower(source.{node_identifier_property}) CONTAINS toLower($source_name)
RETURN id(source) as source_id
"""

Expand Down Expand Up @@ -1050,6 +1097,7 @@ def bellman_ford_single_source_shortest_path(self, source_node: str, **kwargs):
def execute(self, arguments: Dict[str, Any]) -> Any:
return self.bellman_ford_single_source_shortest_path(
arguments.get("sourceNode"),
arguments.get("nodeIdentifierProperty"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
)

Expand Down
Loading