diff --git a/CHANGELOG.md b/CHANGELOG.md index 88eb66b..0f0152d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ # Changelog - [0.2.0] - yyyy-mm-dd ### New Features +1. Project all (non-string) node properties using the appropriate (integer/float) types. + + +### Bug Fixes +1. Fix GDS call for path algorithms and clean up returned data format. + + +### Other Changes +1. Removed minimum_weight_k_spanning_tree since it is write mode only, which may modify the database unexpectedly. \ No newline at end of file diff --git a/mcp_server/src/mcp_server_neo4j_gds/gds.py b/mcp_server/src/mcp_server_neo4j_gds/gds.py index 4763cab..2a91dd5 100644 --- a/mcp_server/src/mcp_server_neo4j_gds/gds.py +++ b/mcp_server/src/mcp_server_neo4j_gds/gds.py @@ -56,22 +56,95 @@ def projected_graph(gds, undirected=False): """ graph_name = f"temp_graph_{uuid.uuid4().hex[:8]}" try: + # Get relationship properties (non-string) rel_properties = gds.run_cypher( "MATCH (n)-[r]-(m) RETURN DISTINCT keys(properties(r))" )["keys(properties(r))"][0] - # Include all properties that are not STRING - valid_properties = {} + valid_rel_properties = {} for i in range(len(rel_properties)): pi = gds.run_cypher( f"MATCH (n)-[r]-(m) RETURN distinct r.{rel_properties[i]} IS :: STRING AS ISSTRING" ) if pi.shape[0] == 1 and bool(pi["ISSTRING"][0]) is False: - valid_properties[rel_properties[i]] = f"r.{rel_properties[i]}" - prop_map = ", ".join(f"{prop}: r.{prop}" for prop in valid_properties) + valid_rel_properties[rel_properties[i]] = f"r.{rel_properties[i]}" + rel_prop_map = ", ".join(f"{prop}: r.{prop}" for prop in valid_rel_properties) + + # Get node properties (non-string, compatible with GDS) + node_properties = gds.run_cypher( + "MATCH (n) RETURN DISTINCT keys(properties(n))" + )["keys(properties(n))"][0] + valid_node_properties = {} + for i in range(len(node_properties)): + # Check property types and whether all values are whole numbers + type_check = gds.run_cypher( + f""" + MATCH (n) + WHERE n.{node_properties[i]} IS NOT NULL + WITH n.{node_properties[i]} AS prop + RETURN + prop IS :: STRING AS ISSTRING, + CASE + WHEN prop IS :: STRING THEN null + ELSE prop % 1 = 0 + END AS IS_WHOLE_NUMBER + LIMIT 10 + """ + ) + if not type_check.empty: + # Check if any value is a string - if so, skip this property + has_strings = any(type_check["ISSTRING"]) + + if not has_strings: + # All values are numeric, check if all are whole numbers + whole_numbers = type_check["IS_WHOLE_NUMBER"].dropna() + if len(whole_numbers) > 0 and all(whole_numbers): + # All values are whole numbers - use as integer + valid_node_properties[node_properties[i]] = ( + f"n.{node_properties[i]}" + ) + else: + # Has decimal values - use as float + valid_node_properties[node_properties[i]] = ( + f"toFloat(n.{node_properties[i]})" + ) + + node_prop_map = ", ".join( + f"{prop}: {expr}" for prop, expr in valid_node_properties.items() + ) + logger.info(f"Node property map: '{node_prop_map}'") # Configure graph projection based on undirected parameter + # Create data configuration (node/relationship structure) + data_config_parts = [ + "sourceNodeLabels: labels(n)", + "targetNodeLabels: labels(m)", + "relationshipType: type(r)", + ] + + if node_prop_map: + data_config_parts.extend( + [ + f"sourceNodeProperties: {{{node_prop_map}}}", + f"targetNodeProperties: {{{node_prop_map}}}", + ] + ) + + if rel_prop_map: + data_config_parts.append(f"relationshipProperties: {{{rel_prop_map}}}") + + data_config = ", ".join(data_config_parts) + + # Create additional configuration + additional_config_parts = [] if undirected: - # For undirected graphs, use undirectedRelationshipTypes: ['*'] to make all relationships undirected + additional_config_parts.append("undirectedRelationshipTypes: ['*']") + + additional_config = ( + ", ".join(additional_config_parts) if additional_config_parts else "" + ) + + # Use separate data and additional configuration parameters + if additional_config: G, _ = gds.graph.cypher.project( f""" MATCH (n)-[r]-(m) @@ -80,19 +153,13 @@ def projected_graph(gds, undirected=False): $graph_name, n, m, - {{ - sourceNodeLabels: labels(n), - targetNodeLabels: labels(m), - relationshipType: type(r), - relationshipProperties: {{{prop_map}}}, - undirectedRelationshipTypes: ['*'] - }} + {{{data_config}}}, + {{{additional_config}}} ) """, graph_name=graph_name, ) else: - # Default directed projection G, _ = gds.graph.cypher.project( f""" MATCH (n)-[r]-(m) @@ -101,12 +168,7 @@ def projected_graph(gds, undirected=False): $graph_name, n, m, - {{ - sourceNodeLabels: labels(n), - targetNodeLabels: labels(m), - relationshipType: type(r), - relationshipProperties: {{{prop_map}}} - }} + {{{data_config}}} ) """, graph_name=graph_name, diff --git a/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py b/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py index d2e1f14..e66ccca 100644 --- a/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py +++ b/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py @@ -98,7 +98,7 @@ def delta_stepping_shortest_path( params = {k: v for k, v in kwargs.items() if v is not None} logger.info(f"Delta-Stepping shortest path parameters: {params}") - path_data = self.gds.shortestPath.deltaStepping.stream( + path_data = self.gds.allShortestPaths.delta.stream( G, sourceNode=source_node_id, **params ) @@ -111,22 +111,42 @@ def delta_stepping_shortest_path( # Convert to native Python types as needed result_data = [] for _, row in path_data.iterrows(): - node_id = int(row["targetNode"]) - cost = float(row["cost"]) + target_node_id = int(row["targetNode"]) + total_cost = float(row["totalCost"]) - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Get the path details + node_ids = row["nodeIds"] + costs = row["costs"] + path = row["path"] + + # Convert to native Python types if needed + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + if hasattr(costs, "tolist"): + costs = costs.tolist() + + # Get node names using GDS utility function + target_node_name = self.gds.util.asNode(target_node_id) + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] result_data.append( - {"targetNodeId": node_id, "targetNodeName": node_name, "cost": cost} + { + "targetNode": target_node_id, + "targetNodeName": target_node_name, + "totalCost": total_cost, + "nodeIds": node_ids, + "nodeNames": node_names, + "costs": costs, + "path": path, + } ) + # Do we need to return the sourceNodeId and sourceNodeName? return { "found": True, "sourceNodeId": source_node_id, "sourceNodeName": self.gds.util.asNode(source_node_id), - "paths": result_data, - "totalPaths": len(result_data), + "results": result_data, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -160,7 +180,7 @@ def dijkstra_single_source_shortest_path( params = {k: v for k, v in kwargs.items() if v is not None} logger.info(f"Dijkstra single-source shortest path parameters: {params}") - path_data = self.gds.shortestPath.dijkstra.stream( + path_data = self.gds.allShortestPaths.dijkstra.stream( G, sourceNode=source_node_id, **params ) @@ -173,22 +193,41 @@ def dijkstra_single_source_shortest_path( # Convert to native Python types as needed result_data = [] for _, row in path_data.iterrows(): - node_id = int(row["targetNode"]) - cost = float(row["cost"]) + target_node_id = int(row["targetNode"]) + total_cost = float(row["totalCost"]) - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Get the path details + node_ids = row["nodeIds"] + costs = row["costs"] + path = row["path"] + + # Convert to native Python types if needed + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + if hasattr(costs, "tolist"): + costs = costs.tolist() + + # Get node names using GDS utility function + target_node_name = self.gds.util.asNode(target_node_id) + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] result_data.append( - {"targetNodeId": node_id, "targetNodeName": node_name, "cost": cost} + { + "targetNode": target_node_id, + "targetNodeName": target_node_name, + "totalCost": total_cost, + "nodeIds": node_ids, + "nodeNames": node_names, + "costs": costs, + "path": path, + } ) return { "found": True, "sourceNodeId": source_node_id, "sourceNodeName": self.gds.util.asNode(source_node_id), - "paths": result_data, - "totalPaths": len(result_data), + "results": result_data, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -346,8 +385,8 @@ def yens_shortest_paths( "sourceNodeName": self.gds.util.asNode(source_node_id), "targetNodeId": target_node_id, "targetNodeName": self.gds.util.asNode(target_node_id), - "paths": result_data, - "totalPaths": len(result_data), + "results": result_data, + "totalResults": len(result_data), } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -393,36 +432,38 @@ def minimum_weight_spanning_tree( } # Convert to native Python types as needed - result_data = [] + edges = [] total_weight = 0.0 for _, row in mst_data.iterrows(): - source_id = int(row["sourceNode"]) - target_id = int(row["targetNode"]) - weight = float(row["cost"]) + node_id = int(row["nodeId"]) + parent_id = int(row["parentId"]) + weight = float(row["weight"]) + + # Skip the root node (where nodeId == parentId) + if node_id == parent_id: + continue + total_weight += weight # Get node names using GDS utility function - source_name = self.gds.util.asNode(source_id) - target_name = self.gds.util.asNode(target_id) + parent_name = self.gds.util.asNode(parent_id) + node_name = self.gds.util.asNode(node_id) - result_data.append( + edges.append( { - "sourceNodeId": source_id, - "sourceNodeName": source_name, - "targetNodeId": target_id, - "targetNodeName": target_name, - "cost": weight, + "nodeId": node_id, + "parentId": parent_id, + "nodeName": node_name, + "parentName": parent_name, + "weight": weight, } ) return { "found": True, - "sourceNodeId": source_node_id, - "sourceNodeName": self.gds.util.asNode(source_node_id), "totalWeight": total_weight, - "relationships": result_data, - "totalRelationships": len(result_data), + "edges": edges, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -434,40 +475,6 @@ def execute(self, arguments: Dict[str, Any]) -> Any: ) -class MinimumWeightKSpanningTreeHandler(AlgorithmHandler): - def minimum_weight_k_spanning_tree(self, write_property: str, k: int, **kwargs): - with projected_graph(self.gds, undirected=True) as G: - # If any optional parameter is not None, use that parameter - params = {k: v for k, v in kwargs.items() if v is not None} - logger.info(f"Minimum Weight K-Spanning Tree parameters: {params}") - - # Run the k-spanning tree algorithm - result = self.gds.kSpanningTree.write( - G, writeProperty=write_property, k=k, **params - ) - - # The write procedure returns performance metrics and effectiveNodeCount - # The results are written to the database with the specified writeProperty - return { - "found": True, - "writeProperty": write_property, - "k": k, - "effectiveNodeCount": int(result["effectiveNodeCount"]), - "preProcessingMillis": int(result["preProcessingMillis"]), - "computeMillis": int(result["computeMillis"]), - "writeMillis": int(result["writeMillis"]), - "message": f"K-spanning tree with {result['effectiveNodeCount']} nodes written to property '{write_property}'", - } - - def execute(self, arguments: Dict[str, Any]) -> Any: - return self.minimum_weight_k_spanning_tree( - arguments.get("writeProperty"), - arguments.get("k"), - relationshipWeightProperty=arguments.get("relationshipWeightProperty"), - objective=arguments.get("objective"), - ) - - class MinimumDirectedSteinerTreeHandler(AlgorithmHandler): def minimum_directed_steiner_tree( self, @@ -541,38 +548,32 @@ def minimum_directed_steiner_tree( } # Convert to native Python types as needed - result_data = [] + edges = [] total_weight = 0.0 for _, row in steiner_data.iterrows(): - source_id = int(row["sourceNode"]) - target_id = int(row["targetNode"]) - weight = float(row["cost"]) - total_weight += weight + node_id = int(row["nodeId"]) + parent_id = int(row["parentId"]) + weight = float(row["weight"]) - # Get node names using GDS utility function - source_name = self.gds.util.asNode(source_id) - target_name = self.gds.util.asNode(target_id) + # Skip the root node (where nodeId == parentId) + if node_id == parent_id: + continue - result_data.append( + total_weight += weight + + edges.append( { - "sourceNodeId": source_id, - "sourceNodeName": source_name, - "targetNodeId": target_id, - "targetNodeName": target_name, - "cost": weight, + "nodeId": node_id, + "parentId": parent_id, + "weight": weight, } ) return { "found": True, - "sourceNodeId": source_node_id, - "sourceNodeName": self.gds.util.asNode(source_node_id), - "targetNodes": target_node_names, - "targetNodeIds": target_node_ids, "totalWeight": total_weight, - "relationships": result_data, - "totalRelationships": len(result_data), + "edges": edges, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -588,7 +589,7 @@ def execute(self, arguments: Dict[str, Any]) -> Any: class PrizeCollectingSteinerTreeHandler(AlgorithmHandler): def prize_collecting_steiner_tree(self, **kwargs): - with projected_graph(self.gds) as G: + with projected_graph(self.gds, undirected=True) as G: # If any optional parameter is not None, use that parameter params = {k: v for k, v in kwargs.items() if v is not None} logger.info(f"Prize-Collecting Steiner Tree parameters: {params}") @@ -603,48 +604,32 @@ def prize_collecting_steiner_tree(self, **kwargs): } # Convert to native Python types as needed - result_data = [] + edges = [] total_weight = 0.0 - total_prize = 0.0 for _, row in steiner_data.iterrows(): - source_id = int(row["sourceNode"]) - target_id = int(row["targetNode"]) - weight = float(row["cost"]) - total_weight += weight + node_id = int(row["nodeId"]) + parent_id = int(row["parentId"]) + weight = float(row["weight"]) - # Get node names using GDS utility function - source_name = self.gds.util.asNode(source_id) - target_name = self.gds.util.asNode(target_id) + # Skip the root node (where nodeId == parentId) + if node_id == parent_id: + continue - result_data.append( + total_weight += weight + + edges.append( { - "sourceNodeId": source_id, - "sourceNodeName": source_name, - "targetNodeId": target_id, - "targetNodeName": target_name, - "cost": weight, + "nodeId": node_id, + "parentId": parent_id, + "weight": weight, } ) - # Calculate total prize from nodes in the tree - if "prizeProperty" in params: - prize_query = f""" - MATCH (n) - WHERE n.{params["prizeProperty"]} IS NOT NULL - RETURN sum(n.{params["prizeProperty"]}) as totalPrize - """ - prize_df = self.gds.run_cypher(prize_query) - if not prize_df.empty: - total_prize = float(prize_df["totalPrize"].iloc[0]) - return { "found": True, "totalWeight": total_weight, - "totalPrize": total_prize, - "netValue": total_prize - total_weight, - "relationships": result_data, - "totalRelationships": len(result_data), + "edges": edges, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -668,47 +653,24 @@ def all_pairs_shortest_paths(self, **kwargs): return {"found": False, "message": "No shortest paths found"} # Convert to native Python types as needed - result_data = [] - finite_paths = 0 - infinite_paths = 0 + paths = [] for _, row in apsp_data.iterrows(): - source_id = int(row["sourceNode"]) - target_id = int(row["targetNode"]) - cost = row["cost"] - - # Check if the cost is finite (not infinity) - is_finite = self.gds.util.isFinite(cost) + source_id = int(row["sourceNodeId"]) + target_id = int(row["targetNodeId"]) + distance = float(row["distance"]) - if is_finite: - finite_paths += 1 - cost_value = float(cost) - else: - infinite_paths += 1 - cost_value = float("inf") - - # Get node names using GDS utility function - source_name = self.gds.util.asNode(source_id) - target_name = self.gds.util.asNode(target_id) - - result_data.append( + paths.append( { "sourceNodeId": source_id, - "sourceNodeName": source_name, "targetNodeId": target_id, - "targetNodeName": target_name, - "cost": cost_value, - "isFinite": is_finite, + "distance": distance, } ) return { "found": True, - "totalPairs": len(result_data), - "finitePaths": finite_paths, - "infinitePaths": infinite_paths, - "paths": result_data, - "message": f"Found {finite_paths} finite paths and {infinite_paths} infinite paths between all pairs of nodes", + "paths": paths, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -744,8 +706,12 @@ def random_walk(self, **kwargs): source_node_ids.append(int(source_df["source_id"].iloc[0])) with projected_graph(self.gds) as G: - # Prepare parameters for the random walk algorithm - params = {k: v for k, v in kwargs.items() if v is not None} + # Prepare parameters for the random walk algorithm, excluding our internal parameters + params = { + k: v + for k, v in kwargs.items() + if v is not None and k != "nodeIdentifierProperty" + } # Add source nodes if found if source_node_ids: @@ -760,29 +726,28 @@ def random_walk(self, **kwargs): return {"found": False, "message": "No random walks generated"} # Convert to native Python types as needed - result_data = [] - total_walks = 0 + walks = [] for _, row in walk_data.iterrows(): - walk_id = int(row["walkId"]) - path = row["path"] + node_ids = row["nodeIds"] + # Convert node_ids to list if it's not already + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() - # Convert path to list of node names - node_names = [] - for node_id in path: - node_name = self.gds.util.asNode(node_id) - node_names.append(node_name) + # Get node names using GDS utility function + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] - result_data.append( - {"walkId": walk_id, "path": node_names, "pathLength": len(path)} + walks.append( + { + "nodeIds": node_ids, + "nodeNames": node_names, + "walkLength": len(node_ids), + } ) - total_walks += 1 return { "found": True, - "totalWalks": total_walks, - "walks": result_data, - "message": f"Generated {total_walks} random walks", + "walks": walks, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -836,8 +801,12 @@ def breadth_first_search( target_node_ids.append(int(target_df["target_id"].iloc[0])) with projected_graph(self.gds) as G: - # Prepare parameters for the BFS algorithm - params = {k: v for k, v in kwargs.items() if v is not None} + # Prepare parameters for the BFS algorithm, excluding our internal parameters + params = { + k: v + for k, v in kwargs.items() + if v is not None and k != "nodeIdentifierProperty" + } # Add target nodes if found if target_node_ids: @@ -855,34 +824,31 @@ def breadth_first_search( } # Convert to native Python types as needed - result_data = [] - visited_nodes = 0 + traversals = [] for _, row in bfs_data.iterrows(): - node_id = int(row["nodeId"]) - depth = int(row["depth"]) + source_node = int(row["sourceNode"]) + node_ids = row["nodeIds"] - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Convert node_ids to list if it's not already + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() - result_data.append( - {"nodeId": node_id, "nodeName": node_name, "depth": depth} - ) - visited_nodes += 1 + # Get node names using GDS utility function + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] - # Sort by depth for better readability - result_data.sort(key=lambda x: x["depth"]) + traversals.append( + { + "sourceNode": source_node, + "nodeIds": node_ids, + "nodeNames": node_names, + "visitedNodes": len(node_ids), + } + ) return { "found": True, - "sourceNodeId": source_node_id, - "sourceNodeName": self.gds.util.asNode(source_node_id), - "visitedNodes": visited_nodes, - "nodes": result_data, - "maxDepthReached": max([node["depth"] for node in result_data]) - if result_data - else 0, - "message": f"Visited {visited_nodes} nodes starting from '{self.gds.util.asNode(source_node_id)}'", + "traversals": traversals, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -932,8 +898,12 @@ def depth_first_search( target_node_ids.append(int(target_df["target_id"].iloc[0])) with projected_graph(self.gds) as G: - # Prepare parameters for the DFS algorithm - params = {k: v for k, v in kwargs.items() if v is not None} + # Prepare parameters for the DFS algorithm, excluding our internal parameters + params = { + k: v + for k, v in kwargs.items() + if v is not None and k != "nodeIdentifierProperty" + } # Add target nodes if found if target_node_ids: @@ -951,34 +921,31 @@ def depth_first_search( } # Convert to native Python types as needed - result_data = [] - visited_nodes = 0 + traversals = [] for _, row in dfs_data.iterrows(): - node_id = int(row["nodeId"]) - depth = int(row["depth"]) + source_node = int(row["sourceNode"]) + node_ids = row["nodeIds"] - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Convert node_ids to list if it's not already + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() - result_data.append( - {"nodeId": node_id, "nodeName": node_name, "depth": depth} - ) - visited_nodes += 1 + # Get node names using GDS utility function + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] - # Sort by depth for better readability - result_data.sort(key=lambda x: x["depth"]) + traversals.append( + { + "sourceNode": source_node, + "nodeIds": node_ids, + "nodeNames": node_names, + "visitedNodes": len(node_ids), + } + ) return { "found": True, - "sourceNodeId": source_node_id, - "sourceNodeName": self.gds.util.asNode(source_node_id), - "visitedNodes": visited_nodes, - "nodes": result_data, - "maxDepthReached": max([node["depth"] for node in result_data]) - if result_data - else 0, - "message": f"Visited {visited_nodes} nodes starting from '{self.gds.util.asNode(source_node_id)}'", + "traversals": traversals, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -1011,8 +978,12 @@ def bellman_ford_single_source_shortest_path( source_node_id = int(source_df["source_id"].iloc[0]) with projected_graph(self.gds) as G: - # If any optional parameter is not None, use that parameter - params = {k: v for k, v in kwargs.items() if v is not None} + # Prepare parameters for the Bellman-Ford algorithm, excluding our internal parameters + params = { + k: v + for k, v in kwargs.items() + if v is not None and k != "nodeIdentifierProperty" + } logger.info( f"Bellman-Ford Single-Source Shortest Path parameters: {params}" ) @@ -1029,64 +1000,42 @@ def bellman_ford_single_source_shortest_path( } # Convert to native Python types as needed - result_data = [] - negative_cycles = [] + paths = [] for _, row in bellman_ford_data.iterrows(): - node_id = int(row["targetNode"]) - cost = row["cost"] - - # Check if the cost is finite (not infinity) - is_finite = self.gds.util.isFinite(cost) + index = int(row["index"]) + source_node = int(row["sourceNode"]) + target_node = int(row["targetNode"]) + total_cost = float(row["totalCost"]) + node_ids = row["nodeIds"] + costs = row["costs"] + is_negative_cycle = bool(row["isNegativeCycle"]) - if is_finite: - cost_value = float(cost) - else: - cost_value = float("inf") + # Convert arrays to lists if needed + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + if hasattr(costs, "tolist"): + costs = costs.tolist() - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Get node names using GDS utility function + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] - result_data.append( + paths.append( { - "targetNodeId": node_id, - "targetNodeName": node_name, - "cost": cost_value, - "isFinite": is_finite, + "index": index, + "sourceNode": source_node, + "targetNode": target_node, + "totalCost": total_cost, + "nodeIds": node_ids, + "nodeNames": node_names, + "costs": costs, + "isNegativeCycle": is_negative_cycle, } ) - # Check for negative cycles in the result - # If there are negative cycles, the algorithm might return them instead of shortest paths - if "negativeCycle" in bellman_ford_data.columns: - for _, row in bellman_ford_data.iterrows(): - if "negativeCycle" in row and row["negativeCycle"] is not None: - cycle = row["negativeCycle"] - if hasattr(cycle, "tolist"): - cycle = cycle.tolist() - - # Convert cycle node IDs to names - cycle_names = [ - self.gds.util.asNode(node_id) for node_id in cycle - ] - negative_cycles.append( - {"cycle": cycle_names, "cycleLength": len(cycle)} - ) - return { "found": True, - "sourceNodeId": source_node_id, - "sourceNodeName": self.gds.util.asNode(source_node_id), - "paths": result_data, - "totalPaths": len(result_data), - "negativeCycles": negative_cycles, - "hasNegativeCycles": len(negative_cycles) > 0, - "message": f"Found {len(result_data)} paths from '{self.gds.util.asNode(source_node_id)}'" - + ( - f" and {len(negative_cycles)} negative cycles" - if negative_cycles - else "" - ), + "paths": paths, } def execute(self, arguments: Dict[str, Any]) -> Any: @@ -1100,8 +1049,12 @@ def execute(self, arguments: Dict[str, Any]) -> Any: class LongestPathHandler(AlgorithmHandler): def longest_path(self, **kwargs): with projected_graph(self.gds) as G: - # If any optional parameter is not None, use that parameter - params = {k: v for k, v in kwargs.items() if v is not None} + # Prepare parameters for the longest path algorithm, excluding our internal parameters + params = { + k: v + for k, v in kwargs.items() + if v is not None and k != "nodeIdentifierProperty" + } logger.info(f"Longest Path parameters: {params}") # Run the longest path algorithm @@ -1114,55 +1067,40 @@ def longest_path(self, **kwargs): } # Convert to native Python types as needed - result_data = [] - total_weight = 0.0 + paths = [] for _, row in longest_path_data.iterrows(): - node_id = int(row["nodeId"]) - cost = row["cost"] - - # Check if the cost is finite (not infinity) - is_finite = self.gds.util.isFinite(cost) + index = int(row["index"]) + source_node = int(row["sourceNode"]) + target_node = int(row["targetNode"]) + total_cost = float(row["totalCost"]) + node_ids = row["nodeIds"] + costs = row["costs"] - if is_finite: - cost_value = float(cost) - total_weight += cost_value - else: - cost_value = float("inf") + # Convert arrays to lists if needed + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + if hasattr(costs, "tolist"): + costs = costs.tolist() - # Get node name using GDS utility function - node_name = self.gds.util.asNode(node_id) + # Get node names using GDS utility function + node_names = [self.gds.util.asNode(node_id) for node_id in node_ids] - result_data.append( + paths.append( { - "nodeId": node_id, - "nodeName": node_name, - "longestPathCost": cost_value, - "isFinite": is_finite, + "index": index, + "sourceNode": source_node, + "targetNode": target_node, + "totalCost": total_cost, + "nodeIds": node_ids, + "nodeNames": node_names, + "costs": costs, } ) - # Sort by cost for better readability (highest first) - result_data.sort( - key=lambda x: x["longestPathCost"] if x["isFinite"] else float("inf"), - reverse=True, - ) - return { "found": True, - "nodes": result_data, - "totalNodes": len(result_data), - "totalWeight": total_weight, - "maxLongestPath": max( - [ - node["longestPathCost"] - for node in result_data - if node["isFinite"] - ] - ) - if any(node["isFinite"] for node in result_data) - else 0, - "message": f"Found longest paths for {len(result_data)} nodes in DAG components", + "paths": paths, } def execute(self, arguments: Dict[str, Any]) -> Any: diff --git a/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_specs.py b/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_specs.py index 074de60..d9bebd2 100644 --- a/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_specs.py +++ b/mcp_server/src/mcp_server_neo4j_gds/path_algorithm_specs.py @@ -122,7 +122,13 @@ "description": "Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.", }, }, - "required": ["sourceNode", "targetNode", "nodeIdentifierProperty"], + "required": [ + "sourceNode", + "targetNode", + "nodeIdentifierProperty", + "latitudeProperty", + "longitudeProperty", + ], }, ), types.Tool( @@ -197,39 +203,6 @@ "required": ["sourceNode", "nodeIdentifierProperty"], }, ), - types.Tool( - name="minimum_weight_k_spanning_tree", - description="Sometimes, we might require a spanning tree(a tree where its nodes are connected with each via a single path) that does not necessarily span all nodes in the graph. " - "The K-Spanning tree heuristic algorithm returns a tree with k nodes and k − 1 relationships. " - "Our heuristic processes the result found by Prim's algorithm for the Minimum Weight Spanning Tree problem. " - "Like Prim, it starts from a given source node, finds a spanning tree for all nodes and then removes nodes using heuristics to produce a tree with 'k' nodes. " - "Note that the source node will not be necessarily included in the final output as the heuristic tries to find a globally good tree. " - "The Minimum weight k-Spanning Tree is NP-Hard. The algorithm in the Neo4j GDS Library is therefore not guaranteed to find the optimal answer, but should hopefully return a good approximation in practice. " - "Like Prim algorithm, the algorithm focuses only on the component of the source node. If that component has fewer than k nodes, it will not look into other components, but will instead return the component.", - inputSchema={ - "type": "object", - "properties": { - "writeProperty": { - "type": "string", - "description": "The node property in the Neo4j database to which the spanning tree is written.", - }, - "k": { - "type": "integer", - "description": "The size of the tree to be returned.", - }, - "relationshipWeightProperty": { - "type": "string", - "description": "Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.", - }, - "objective": { - "type": "string", - "enum": ["minimum", "maximum"], - "description": "If specified, the parameter dictates whether to seek a minimum or the maximum weight k-spanning tree. By default, the procedure looks for a minimum weight k-spanning tree. Permitted values are 'minimum' and 'maximum'.", - }, - }, - "required": ["writeProperty", "k"], - }, - ), types.Tool( name="minimum_directed_steiner_tree", description="Given a source node and a list of target nodes, a directed spanning tree in which there exists a path from the source node to each of the target nodes is called a Directed Steiner Tree. " diff --git a/mcp_server/src/mcp_server_neo4j_gds/registry.py b/mcp_server/src/mcp_server_neo4j_gds/registry.py index 2e54ca6..f0dea86 100644 --- a/mcp_server/src/mcp_server_neo4j_gds/registry.py +++ b/mcp_server/src/mcp_server_neo4j_gds/registry.py @@ -45,7 +45,6 @@ AStarShortestPathHandler, YensShortestPathsHandler, MinimumWeightSpanningTreeHandler, - MinimumWeightKSpanningTreeHandler, MinimumDirectedSteinerTreeHandler, PrizeCollectingSteinerTreeHandler, AllPairsShortestPathsHandler, @@ -100,7 +99,6 @@ class AlgorithmRegistry: "a_star_shortest_path": AStarShortestPathHandler, "yens_shortest_paths": YensShortestPathsHandler, "minimum_weight_spanning_tree": MinimumWeightSpanningTreeHandler, - "minimum_weight_k_spanning_tree": MinimumWeightKSpanningTreeHandler, "minimum_directed_steiner_tree": MinimumDirectedSteinerTreeHandler, "prize_collecting_steiner_tree": PrizeCollectingSteinerTreeHandler, "all_pairs_shortest_paths": AllPairsShortestPathsHandler, diff --git a/mcp_server/tests/conftest.py b/mcp_server/tests/conftest.py index 4ef62bc..a63c9d3 100644 --- a/mcp_server/tests/conftest.py +++ b/mcp_server/tests/conftest.py @@ -164,6 +164,7 @@ async def mcp_server_process(import_test_data): stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + limit=1024 * 1024 * 10, # 10MB buffer limit ) # Wait a moment for the server to initialize @@ -210,7 +211,15 @@ async def send_request(self, method, params=None): response_line = await self.process.stdout.readline() if not response_line: raise RuntimeError("No response from MCP server") - response = json.loads(response_line.decode().strip()) + + # Decode with error handling for large responses + try: + response_text = response_line.decode().strip() + response = json.loads(response_text) + except UnicodeDecodeError as e: + raise RuntimeError(f"Failed to decode response: {e}") + except json.JSONDecodeError as e: + raise RuntimeError(f"Failed to parse JSON response: {e}") return response diff --git a/mcp_server/tests/test_tools.py b/mcp_server/tests/test_basic_tools.py similarity index 62% rename from mcp_server/tests/test_tools.py rename to mcp_server/tests/test_basic_tools.py index 093b7a8..103da14 100644 --- a/mcp_server/tests/test_tools.py +++ b/mcp_server/tests/test_basic_tools.py @@ -2,58 +2,6 @@ import json -@pytest.mark.asyncio -async def test_find_shortest_path(mcp_client): - result = await mcp_client.call_tool( - "find_shortest_path", - { - "start_node": "Canada Water", - "end_node": "Tower Hill", - "nodeIdentifierProperty": "name", - "relationship_property": "time", - }, - ) - - assert len(result) == 1 - result_text = result[0]["text"] - result_data = json.loads(result_text) - - assert "nodeNames" in result_data - assert result_data["totalCost"] == 9.0 - expected_node_ids = [292, 188, 243, 196, 261, 2, 230] - assert result_data["nodeIds"] == expected_node_ids - - node_names = result_data["nodeNames"] - assert len(node_names) == 7 - assert "Canada Water" in node_names[0] - assert "Tower Hill" in node_names[-1] - expected_stations = [ - "Canada Water", - "Rotherhithe", - "Wapping", - "Shadwell", - "Whitechapel", - "Aldgate East", - "Tower Hill", - ] - for i, expected_station in enumerate(expected_stations): - assert expected_station in node_names[i] - - # Test with stations that should not have a path - result = await mcp_client.call_tool( - "find_shortest_path", - { - "start_node": "NonExistentStation1", - "end_node": "NonExistentStation2", - "nodeIdentifierProperty": "name", - }, - ) - - result_text = result[0]["text"] - result_data = json.loads(result_text) - assert result_data["found"] is False - - @pytest.mark.asyncio async def test_count_nodes(mcp_client): result = await mcp_client.call_tool("count_nodes") @@ -113,7 +61,6 @@ async def test_list_tools(mcp_client): "a_star_shortest_path", "yens_shortest_paths", "minimum_weight_spanning_tree", - "minimum_weight_k_spanning_tree", "minimum_directed_steiner_tree", "prize_collecting_steiner_tree", "all_pairs_shortest_paths", @@ -130,3 +77,23 @@ async def test_list_tools(mcp_client): assert expected_tool in tool_names, ( f"Expected tool '{expected_tool}' not found in tool list" ) + + +@pytest.mark.asyncio +async def test_get_node_properties_keys(mcp_client): + result = await mcp_client.call_tool("get_node_properties_keys") + + assert len(result) == 1 + result_text = result[0]["text"] + properties_keys = json.loads(result_text) + + assert properties_keys == [ + "zone", + "rail", + "latitude", + "name", + "total_lines", + "id", + "display_name", + "longitude", + ] diff --git a/mcp_server/tests/test_path_algorithms.py b/mcp_server/tests/test_path_algorithms.py new file mode 100644 index 0000000..a84b6c2 --- /dev/null +++ b/mcp_server/tests/test_path_algorithms.py @@ -0,0 +1,608 @@ +import pytest +import json + + +@pytest.mark.asyncio +async def test_find_shortest_path(mcp_client): + result = await mcp_client.call_tool( + "find_shortest_path", + { + "start_node": "Canada Water", + "end_node": "Tower Hill", + "nodeIdentifierProperty": "name", + "relationship_property": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert "nodeNames" in result_data + assert result_data["totalCost"] == 9.0 + expected_node_ids = [292, 188, 243, 196, 261, 2, 230] + assert result_data["nodeIds"] == expected_node_ids + + node_names = result_data["nodeNames"] + assert len(node_names) == 7 + assert "Canada Water" in node_names[0] + assert "Tower Hill" in node_names[-1] + expected_stations = [ + "Canada Water", + "Rotherhithe", + "Wapping", + "Shadwell", + "Whitechapel", + "Aldgate East", + "Tower Hill", + ] + for i, expected_station in enumerate(expected_stations): + assert expected_station in node_names[i] + + # Test with stations that should not have a path + result = await mcp_client.call_tool( + "find_shortest_path", + { + "start_node": "NonExistentStation1", + "end_node": "NonExistentStation2", + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_delta_stepping_shortest_path(mcp_client): + result = await mcp_client.call_tool( + "delta_stepping_shortest_path", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "delta": 2.0, + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_data = json.loads(result[0]["text"]) + + assert result_data["found"] is True + assert "sourceNodeId" in result_data + assert "sourceNodeName" in result_data + assert "results" in result_data + + assert "Canada Water" in result_data["sourceNodeName"] + + results = result_data["results"] + assert len(results) == 302 + # Verify structure of a result entry + assert "targetNode" in results[42] + assert "targetNodeName" in results[42] + assert "totalCost" in results[42] + assert "nodeIds" in results[42] + assert "nodeNames" in results[42] + assert "costs" in results[42] + assert "path" in results[42] + + result = await mcp_client.call_tool( + "delta_stepping_shortest_path", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + "delta": 1.0, + }, + ) + + result_data = json.loads(result[0]["text"]) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_dijkstra_single_source_shortest_path(mcp_client): + result = await mcp_client.call_tool( + "dijkstra_single_source_shortest_path", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_data = json.loads(result[0]["text"]) + + assert result_data["found"] is True + assert "sourceNodeId" in result_data + assert "sourceNodeName" in result_data + assert "results" in result_data + + assert "Canada Water" in result_data["sourceNodeName"] + + results = result_data["results"] + assert len(results) == 302 + # Verify structure of a result entry + assert "targetNode" in results[42] + assert "targetNodeName" in results[42] + assert "totalCost" in results[42] + assert "nodeIds" in results[42] + assert "nodeNames" in results[42] + assert "costs" in results[42] + assert "path" in results[42] + + result = await mcp_client.call_tool( + "dijkstra_single_source_shortest_path", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + }, + ) + + result_data = json.loads(result[0]["text"]) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_a_star_shortest_path(mcp_client): + result = await mcp_client.call_tool( + "a_star_shortest_path", + { + "sourceNode": "Canada Water", + "targetNode": "Tower Hill", + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + "latitudeProperty": "latitude", + "longitudeProperty": "longitude", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert "nodeNames" in result_data + assert result_data["totalCost"] == 9.0 + expected_node_ids = [292, 188, 243, 196, 261, 2, 230] + assert result_data["nodeIds"] == expected_node_ids + + node_names = result_data["nodeNames"] + assert len(node_names) == 7 + assert "Canada Water" in node_names[0] + assert "Tower Hill" in node_names[-1] + expected_stations = [ + "Canada Water", + "Rotherhithe", + "Wapping", + "Shadwell", + "Whitechapel", + "Aldgate East", + "Tower Hill", + ] + for i, expected_station in enumerate(expected_stations): + assert expected_station in node_names[i] + + # Test with stations that should not have a path + result = await mcp_client.call_tool( + "a_star_shortest_path", + { + "sourceNode": "NonExistentStation1", + "targetNode": "NonExistentStation2", + "nodeIdentifierProperty": "name", + "latitudeProperty": "latitude", + "longitudeProperty": "longitude", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_yens_shortest_paths(mcp_client): + result = await mcp_client.call_tool( + "yens_shortest_paths", + { + "sourceNode": "Canada Water", + "targetNode": "Tower Hill", + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + "k": 3, + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "sourceNodeId" in result_data + assert "targetNodeId" in result_data + assert "sourceNodeName" in result_data + assert "targetNodeName" in result_data + assert "results" in result_data + assert "totalResults" in result_data + + assert "Canada Water" in result_data["sourceNodeName"] + assert "Tower Hill" in result_data["targetNodeName"] + + results = result_data["results"] + assert 1 <= len(results) <= 3 + assert result_data["totalResults"] == len(results) + + first_result = results[0] + assert "index" in first_result + assert "totalCost" in first_result + assert "nodeIds" in first_result + assert "nodeNames" in first_result + assert "costs" in first_result + assert "path" in first_result + + # First path should be the optimal path (same as basic shortest path) + assert first_result["totalCost"] == 9.0 + expected_node_ids = [292, 188, 243, 196, 261, 2, 230] + assert first_result["nodeIds"] == expected_node_ids + + # Test with non-existent stations + result = await mcp_client.call_tool( + "yens_shortest_paths", + { + "sourceNode": "NonExistentStation1", + "targetNode": "NonExistentStation2", + "nodeIdentifierProperty": "name", + "k": 2, + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_minimum_weight_spanning_tree(mcp_client): + result = await mcp_client.call_tool( + "minimum_weight_spanning_tree", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "totalWeight" in result_data + assert "edges" in result_data + + edges = result_data["edges"] + assert len(edges) == 301 + assert result_data["totalWeight"] > 0 + + first_edge = edges[0] + assert "nodeId" in first_edge + assert "parentId" in first_edge + assert "nodeName" in first_edge + assert "parentName" in first_edge + assert "weight" in first_edge + assert first_edge["weight"] > 0 + + # Test with non-existent source node + result = await mcp_client.call_tool( + "minimum_weight_spanning_tree", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_minimum_directed_steiner_tree(mcp_client): + result = await mcp_client.call_tool( + "minimum_directed_steiner_tree", + { + "sourceNode": "Canada Water", + "targetNodes": ["Tower Hill", "King's Cross St. Pancras", "London Bridge"], + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "totalWeight" in result_data + assert "edges" in result_data + + assert result_data["totalWeight"] > 0 + edges = result_data["edges"] + assert len(edges) > 0 + + first_edge = edges[0] + assert "nodeId" in first_edge + assert "parentId" in first_edge + assert "weight" in first_edge + assert first_edge["weight"] > 0 + + # Test with non-existent source node + result = await mcp_client.call_tool( + "minimum_directed_steiner_tree", + { + "sourceNode": "NonExistentStation", + "targetNodes": ["Tower Hill"], + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_prize_collecting_steiner_tree(mcp_client): + result = await mcp_client.call_tool( + "prize_collecting_steiner_tree", + { + "relationshipWeightProperty": "time", + "prizeProperty": "zone", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "totalWeight" in result_data + assert "edges" in result_data + assert result_data["totalWeight"] > 0 + + edges = result_data["edges"] + assert len(edges) > 0 + + first_edge = edges[0] + assert "nodeId" in first_edge + assert "parentId" in first_edge + assert "weight" in first_edge + assert first_edge["weight"] > 0 + + +@pytest.mark.asyncio +async def test_all_pairs_shortest_paths(mcp_client): + result = await mcp_client.call_tool( + "all_pairs_shortest_paths", + { + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "paths" in result_data + + paths = result_data["paths"] + assert len(paths) == 302 * 302 + + first_path = paths[0] + assert "sourceNodeId" in first_path + assert "targetNodeId" in first_path + assert "distance" in first_path + + distance = first_path["distance"] + assert isinstance(distance, (int, float)) + assert distance >= 0 or distance == float("inf") + finite_distances = [ + path["distance"] for path in paths if path["distance"] != float("inf") + ] + assert len(finite_distances) > 0 # Should have at least some connected node pairs + + +@pytest.mark.asyncio +async def test_random_walk(mcp_client): + result = await mcp_client.call_tool( + "random_walk", + { + "sourceNodes": ["Canada Water"], + "nodeIdentifierProperty": "name", + "walkLength": 5, + "walksPerNode": 3, + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + walks = result_data["walks"] + assert len(walks) == 3 + + # Test with no source nodes specified (should use all nodes) + result = await mcp_client.call_tool( + "random_walk", + { + "walkLength": 3, + "walksPerNode": 1, + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is True + + walks = result_data["walks"] + # Should have 302 walks (1 walk per node for 302 nodes) + assert len(walks) == 302 + + +@pytest.mark.asyncio +async def test_breadth_first_search(mcp_client): + result = await mcp_client.call_tool( + "breadth_first_search", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "maxDepth": 3, + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "traversals" in result_data + + traversals = result_data["traversals"] + assert len(traversals) > 0 + + first_traversal = traversals[0] + assert "sourceNode" in first_traversal + assert "nodeIds" in first_traversal + assert "nodeNames" in first_traversal + assert "visitedNodes" in first_traversal + assert first_traversal["visitedNodes"] > 0 + assert "Canada Water" in first_traversal["nodeNames"][0] + + # Test with non-existent source node + result = await mcp_client.call_tool( + "breadth_first_search", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_depth_first_search(mcp_client): + result = await mcp_client.call_tool( + "depth_first_search", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "maxDepth": 3, + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "traversals" in result_data + + traversals = result_data["traversals"] + assert len(traversals) > 0 + + first_traversal = traversals[0] + assert "sourceNode" in first_traversal + assert "nodeIds" in first_traversal + assert "nodeNames" in first_traversal + assert "visitedNodes" in first_traversal + assert first_traversal["visitedNodes"] > 0 + assert "Canada Water" in first_traversal["nodeNames"][0] + + # Test with non-existent source node + result = await mcp_client.call_tool( + "depth_first_search", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_bellman_ford_single_source_shortest_path(mcp_client): + result = await mcp_client.call_tool( + "bellman_ford_single_source_shortest_path", + { + "sourceNode": "Canada Water", + "nodeIdentifierProperty": "name", + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + + assert result_data["found"] is True + assert "paths" in result_data + + paths = result_data["paths"] + assert len(paths) > 0 + + first_path = paths[0] + assert "index" in first_path + assert "sourceNode" in first_path + assert "targetNode" in first_path + assert "totalCost" in first_path + assert "nodeIds" in first_path + assert "nodeNames" in first_path + assert "costs" in first_path + assert "isNegativeCycle" in first_path + + source_node_id = first_path["sourceNode"] + for path in paths[:10]: # Check first 10 paths + assert path["sourceNode"] == source_node_id + + assert len(first_path["nodeIds"]) == len(first_path["nodeNames"]) + assert len(first_path["nodeIds"]) == len(first_path["costs"]) + assert "Canada Water" in first_path["nodeNames"][0] + + # Test with non-existent source node + result = await mcp_client.call_tool( + "bellman_ford_single_source_shortest_path", + { + "sourceNode": "NonExistentStation", + "nodeIdentifierProperty": "name", + }, + ) + + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data["found"] is False + + +@pytest.mark.asyncio +async def test_longest_path(mcp_client): + result = await mcp_client.call_tool( + "longest_path", + { + "relationshipWeightProperty": "time", + }, + ) + + assert len(result) == 1 + result_text = result[0]["text"] + result_data = json.loads(result_text) + assert result_data == { + "found": False, + "message": "No longest paths found. The graph may contain cycles or be empty.", + }