Skip to content

Commit 4fc66e4

Browse files
authored
Merge pull request #38 from brs96/add-max-flow
Add max flow
2 parents 6f60c4b + b28fc2a commit 4fc66e4

File tree

7 files changed

+187
-1
lines changed

7 files changed

+187
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog - [0.6.0] - yyyy-MM-dd
22

33
### New Features
4+
1. Add new max flow path algorithm tool.
45

56
### Bug Fixes
67

mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,3 +1165,117 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
11651165
nodeLabels=arguments.get("nodeLabels"),
11661166
relTypes=arguments.get("relTypes"),
11671167
)
1168+
1169+
1170+
class MaxFlowHandler(AlgorithmHandler):
    """Handler for the GDS max flow algorithm.

    Resolves source and target node names to internal node ids via a
    case-insensitive CONTAINS match, then streams ``gds.maxFlow`` results
    over a projected graph.
    """

    def _resolve_nodes(self, names: list, node_identifier_property: str):
        """Resolve each name in *names* to an internal node id.

        Matching is a case-insensitive CONTAINS on *node_identifier_property*.
        Only the first match per name is used; ambiguous (multi-node) matches
        are not reported.

        Returns:
            (ids, matched_names, unmatched_names) — parallel lists of resolved
            ids and their stored property values, plus the names with no match.
        """
        # NOTE(review): node_identifier_property is interpolated directly into
        # the Cypher text (property names cannot be parameterized in Cypher).
        # It must come from a trusted caller — confirm upstream validation.
        query = f"""
        MATCH (n)
        WHERE toLower(n.{node_identifier_property}) CONTAINS toLower($name)
        RETURN id(n) as node_id, n.{node_identifier_property} as node_name
        """
        ids, matched, unmatched = [], [], []
        for name in names:
            df = self.gds.run_cypher(query, params={"name": name})
            if df.empty:
                unmatched.append(name)
            else:
                # First match wins; CONTAINS may match more than one node.
                ids.append(int(df["node_id"].iloc[0]))
                matched.append(df["node_name"].iloc[0])
        return ids, matched, unmatched

    def max_flow(
        self,
        source_nodes: list,
        target_nodes: list,
        node_identifier_property: str,
        **kwargs,
    ):
        """Run max flow from *source_nodes* to *target_nodes*.

        Args:
            source_nodes: Node names from which flow originates.
            target_nodes: Node names to which flow is sent.
            node_identifier_property: Node property used to match the names.
            **kwargs: Projection/algorithm options (e.g. capacityProperty,
                nodeLabels, relTypes); nodeLabels/relTypes are consumed by the
                projection and stripped from the algorithm parameters.

        Returns:
            dict with ``found`` plus either ``flows`` (list of per-edge flow
            records) on success or ``message`` describing unresolved names.
        """
        source_ids, _, unmatched_sources = self._resolve_nodes(
            source_nodes, node_identifier_property
        )
        if unmatched_sources:
            return {
                "found": False,
                "message": f"The following source nodes were not found: {', '.join(unmatched_sources)}",
            }
        if not source_ids:
            return {"found": False, "message": "No source nodes found"}

        target_ids, _, unmatched_targets = self._resolve_nodes(
            target_nodes, node_identifier_property
        )
        if unmatched_targets:
            return {
                "found": False,
                "message": f"The following target nodes were not found: {', '.join(unmatched_targets)}",
            }
        if not target_ids:
            return {"found": False, "message": "No target nodes found"}

        with projected_graph_from_params(self.gds, **kwargs) as G:
            params = clean_params(kwargs, ["nodeLabels", "relTypes"])
            logger.info(f"Max Flow parameters: {params}")

            max_flow_data = self.gds.maxFlow.stream(
                G, sourceNodes=source_ids, targetNodes=target_ids, **params
            )

            # Batch-resolve endpoint ids to node objects via the GDS utility.
            # NOTE(review): asNodes returns node objects, not name strings —
            # confirm consumers expect this in the "...NodeName" columns.
            max_flow_data["sourceNodeName"] = self.gds.util.asNodes(
                max_flow_data["source"].tolist()
            )
            max_flow_data["targetNodeName"] = self.gds.util.asNodes(
                max_flow_data["target"].tolist()
            )

            flows = max_flow_data[
                [
                    "source",
                    "target",
                    "sourceNodeName",
                    "targetNodeName",
                    "flow",
                ]
            ].to_dict("records")

            return {
                "found": True,
                "flows": flows,
            }

    def execute(self, arguments: Dict[str, Any]) -> Any:
        """Dispatch entry point: unpack tool arguments and run max_flow."""
        return self.max_flow(
            arguments.get("sourceNodes"),
            arguments.get("targetNodes"),
            arguments.get("nodeIdentifierProperty"),
            capacityProperty=arguments.get("capacityProperty"),
            nodeLabels=arguments.get("nodeLabels"),
            relTypes=arguments.get("relTypes"),
        )

mcp_server/src/mcp_server_neo4j_gds/path_algorithm_specs.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,4 +609,48 @@
609609
"required": [],
610610
},
611611
),
612+
types.Tool(
613+
name="max_flow",
614+
description="Given source nodes, target nodes and relationships with capacity constraints, the max-flow algorithm assigns a flow to each relationship to achieve maximal transport from source to target. "
615+
"The flow is a scalar property for each relationship and must satisfy 1) Flow into a node equals flow out of a node (preservation). 2) Flow is restricted by the capacity of a relationship",
616+
inputSchema={
617+
"type": "object",
618+
"properties": {
619+
"sourceNodes": {
620+
"type": "array",
621+
"items": {"type": "string"},
622+
"description": "List of source node names from which flow originates.",
623+
},
624+
"targetNodes": {
625+
"type": "array",
626+
"items": {"type": "string"},
627+
"description": "List of target node names to which flow is sent.",
628+
},
629+
"nodeIdentifierProperty": {
630+
"type": "string",
631+
"description": "Property name to use for identifying nodes (e.g., 'name', 'Name', 'title'). Use get_node_properties_keys to find available properties.",
632+
},
633+
"capacityProperty": {
634+
"type": "string",
635+
"description": "Name of the relationship property that specifies the maximum flow capacity for each edge.",
636+
},
637+
"nodeLabels": {
638+
"type": "array",
639+
"items": {"type": "string"},
640+
"description": "The node labels used to project and run max flow on. Nodes with different node labels will be ignored. Do not specify to run for all nodes",
641+
},
642+
"relTypes": {
643+
"type": "array",
644+
"items": {"type": "string"},
645+
"description": "The relationships types used to project and run max flow on. Relationship types of different type will be ignored. Do not specify to run for all relationship types",
646+
},
647+
},
648+
"required": [
649+
"sourceNodes",
650+
"targetNodes",
651+
"nodeIdentifierProperty",
652+
"capacityProperty",
653+
],
654+
},
655+
),
612656
]

mcp_server/src/mcp_server_neo4j_gds/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DepthFirstSearchHandler,
5252
BellmanFordSingleSourceShortestPathHandler,
5353
LongestPathHandler,
54+
MaxFlowHandler,
5455
)
5556

5657

@@ -103,6 +104,7 @@ class AlgorithmRegistry:
103104
"depth_first_search": DepthFirstSearchHandler,
104105
"bellman_ford_single_source_shortest_path": BellmanFordSingleSourceShortestPathHandler,
105106
"longest_path": LongestPathHandler,
107+
"max_flow": MaxFlowHandler,
106108
}
107109

108110
@classmethod

mcp_server/tests/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
services:
22
neo4j:
3-
image: neo4j:2025.05.0
3+
image: neo4j:2025.11.2
44
ports:
55
- "7474"
66
- "7687"

mcp_server/tests/test_basic_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ async def test_list_tools(mcp_client):
7272
"depth_first_search",
7373
"bellman_ford_single_source_shortest_path",
7474
"longest_path",
75+
"max_flow",
7576
# similarity
7677
"node_similarity",
7778
"k_nearest_neighbors",

mcp_server/tests/test_path_algorithms.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,27 @@ async def test_longest_path(mcp_client):
625625
filtered_paths = result_filtered_data["paths"]
626626
assert len(filtered_paths) == 1
627627
assert result_filtered_data["paths"][0]["costs"] == [0.0, 3.0, 6.0, 10.0, 13.0]
628+
629+
630+
@pytest.mark.asyncio
async def test_max_flow(mcp_client):
    """Stream max flow from Baker Street to four stations and verify the
    response contains exactly seven per-edge flow records."""
    target_stations = [
        "Bond Street",
        "Euston Square",
        "Paddington",
        "Wembley Park",
    ]
    arguments = {
        "sourceNodes": ["Baker Street"],
        "targetNodes": target_stations,
        "nodeIdentifierProperty": "name",
        "capacityProperty": "time",
    }

    result = await mcp_client.call_tool("max_flow", arguments)

    assert len(result) == 1
    payload = json.loads(result[0]["text"])
    assert len(payload.get("flows")) == 7

0 commit comments

Comments
 (0)