Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### New Features
1. Add a new get_relationship_properties_keys tool.
2. Add targetNode filtering for longest_path.
3. Add support for similarity algorithms.

### Bug Fixes
1. Return node names in several path algorithms that only returned node ids.
Expand Down
4 changes: 0 additions & 4 deletions mcp_server/src/mcp_server_neo4j_gds/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
)
from .similarity_algorithm_handlers import (
NodeSimilarityHandler,
FilteredNodeSimilarityHandler,
KNearestNeighborsHandler,
FilteredKNearestNeighborsHandler,
)
from .path_algorithm_handlers import (
DijkstraShortestPathHandler,
Expand Down Expand Up @@ -89,9 +87,7 @@ class AlgorithmRegistry:
"speaker_listener_label_propagation": SpeakerListenerLabelPropagationHandler,
# Similarity algorithms
"node_similarity": NodeSimilarityHandler,
"filtered_node_similarity": FilteredNodeSimilarityHandler,
"k_nearest_neighbors": KNearestNeighborsHandler,
"filtered_k_nearest_neighbors": FilteredKNearestNeighborsHandler,
# Path finding algorithms
"find_shortest_path": DijkstraShortestPathHandler,
"delta_stepping_shortest_path": DeltaSteppingShortestPathHandler,
Expand Down
3 changes: 3 additions & 0 deletions mcp_server/src/mcp_server_neo4j_gds/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pandas as pd
import json
from graphdatascience import GraphDataScience

from .similarity_algorithm_specs import similarity_tool_definitions
from .centrality_algorithm_specs import centrality_tool_definitions
from .community_algorithm_specs import community_tool_definitions
from .path_algorithm_specs import path_tool_definitions
Expand Down Expand Up @@ -91,6 +93,7 @@ async def handle_list_tools() -> list[types.Tool]:
+ centrality_tool_definitions
+ community_tool_definitions
+ path_tool_definitions
+ similarity_tool_definitions
)
logger.info(f"Returning {len(tools)} tools")
return tools
Expand Down
181 changes: 105 additions & 76 deletions mcp_server/src/mcp_server_neo4j_gds/similarity_algorithm_handlers.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,73 @@
import logging
from typing import Dict, Any

from graphdatascience import GraphDataScience

from .algorithm_handler import AlgorithmHandler
from .gds import projected_graph
from .node_translator import (
translate_ids_to_identifiers,
translate_identifiers_to_ids,
)

logger = logging.getLogger("mcp_server_neo4j_gds")


class NodeSimilarityHandler(AlgorithmHandler):
def node_similarity(self, db_url: str, username: str, password: str, **kwargs):
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
with projected_graph(gds) as G:
logger.info(f"Node Similarity parameters: {kwargs}")
node_similarity_result = gds.nodeSimilarity.stream(G, **kwargs)
def node_similarity(self, **kwargs):
with projected_graph(self.gds) as G:
params = {
k: v
for k, v in kwargs.items()
if v is not None
and k
not in [
"nodeIdentifierProperty",
"sourceNodeFilter",
"targetNodeFilter",
]
}
node_identifier_property = kwargs.get("nodeIdentifierProperty")
source_nodes = kwargs.get("sourceNodeFilter", None)
target_nodes = kwargs.get("targetNodeFilter", None)
translate_identifiers_to_ids(
self.gds,
source_nodes,
"sourceNodeFilter",
node_identifier_property,
params,
)
translate_identifiers_to_ids(
self.gds,
target_nodes,
"targetNodeFilter",
node_identifier_property,
params,
)
logger.info(f"Node Similarity parameters: {params}")
node_similarity_result = self.gds.nodeSimilarity.filtered.stream(
G, **params
)

# Add node names to the results if nodeIdentifierProperty is provided
node_identifier_property = kwargs.get("nodeIdentifierProperty")
translate_ids_to_identifiers(
self.gds,
node_identifier_property,
node_similarity_result,
"node1",
"node1Name",
)
translate_ids_to_identifiers(
self.gds,
node_identifier_property,
node_similarity_result,
"node2",
"node2Name",
)
return node_similarity_result

def execute(self, arguments: Dict[str, Any]) -> Any:
return self.node_similarity(
self.db_url,
self.username,
self.password,
similarityCutoff=arguments.get("similarityCutoff"),
degreeCutoff=arguments.get("degreeCutoff"),
upperDegreeCutoff=arguments.get("upperDegreeCutoff"),
topK=arguments.get("topK"),
bottomK=arguments.get("bottomK"),
topN=arguments.get("topN"),
bottomN=arguments.get("bottomN"),
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
similarityMetric=arguments.get("similarityMetric"),
useComponents=arguments.get("useComponents"),
)


class FilteredNodeSimilarityHandler(AlgorithmHandler):
def filtered_node_similarity(
self, db_url: str, username: str, password: str, **kwargs
):
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
with projected_graph(gds) as G:
logger.info(f"Filtered Node Similarity parameters: {kwargs}")
filtered_node_similarity_result = gds.nodeSimilarity.filtered.stream(
G, **kwargs
)

return filtered_node_similarity_result

def execute(self, arguments: Dict[str, Any]) -> Any:
return self.filtered_node_similarity(
self.db_url,
self.username,
self.password,
nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"),
sourceNodeFilter=arguments.get("sourceNodeFilter"),
targetNodeFilter=arguments.get("targetNodeFilter"),
similarityCutoff=arguments.get("similarityCutoff"),
Expand All @@ -70,47 +84,62 @@ def execute(self, arguments: Dict[str, Any]) -> Any:


class KNearestNeighborsHandler(AlgorithmHandler):
def k_nearest_neighbors(self, db_url: str, username: str, password: str, **kwargs):
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
with projected_graph(gds) as G:
def k_nearest_neighbors(self, **kwargs):
with projected_graph(self.gds) as G:
params = {
k: v
for k, v in kwargs.items()
if v is not None
and k
not in [
"nodeIdentifierProperty",
"sourceNodeFilter",
"targetNodeFilter",
]
}
node_identifier_property = kwargs.get("nodeIdentifierProperty")
source_nodes = kwargs.get("sourceNodeFilter", None)
target_nodes = kwargs.get("targetNodeFilter", None)
translate_identifiers_to_ids(
self.gds,
source_nodes,
"sourceNodeFilter",
node_identifier_property,
params,
)
translate_identifiers_to_ids(
self.gds,
target_nodes,
"targetNodeFilter",
node_identifier_property,
params,
)

logger.info(f"K-Nearest Neighbors parameters: {kwargs}")
k_nearest_neighbors_result = gds.knn.stream(G, **kwargs)
k_nearest_neighbors_result = self.gds.knn.filtered.stream(G, **params)

# Add node names to the results if nodeIdentifierProperty is provided
node_identifier_property = kwargs.get("nodeIdentifierProperty")
translate_ids_to_identifiers(
self.gds,
node_identifier_property,
k_nearest_neighbors_result,
"node1",
"node1Name",
)
translate_ids_to_identifiers(
self.gds,
node_identifier_property,
k_nearest_neighbors_result,
"node2",
"node2Name",
)

return k_nearest_neighbors_result

def execute(self, arguments: Dict[str, Any]) -> Any:
return self.k_nearest_neighbors(
self.db_url,
self.username,
self.password,
nodeProperties=arguments.get("nodeProperties"),
topK=arguments.get("topK"),
sampleRate=arguments.get("sampleRate"),
deltaThreshold=arguments.get("deltaThreshold"),
maxIterations=arguments.get("maxIterations"),
randomJoins=arguments.get("randomJoins"),
initialSampler=arguments.get("initialSampler"),
similarityCutoff=arguments.get("similarityCutoff"),
perturbationRate=arguments.get("perturbationRate"),
)


class FilteredKNearestNeighborsHandler(AlgorithmHandler):
def filtered_k_nearest_neighbors(
self, db_url: str, username: str, password: str, **kwargs
):
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
with projected_graph(gds) as G:
logger.info(f"Filtered K-Nearest Neighbors parameters: {kwargs}")
filtered_k_nearest_neighbors_result = gds.knn.filtered.stream(G, **kwargs)

return filtered_k_nearest_neighbors_result

def execute(self, arguments: Dict[str, Any]) -> Any:
return self.filtered_k_nearest_neighbors(
self.db_url,
self.username,
self.password,
nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"),
sourceNodeFilter=arguments.get("sourceNodeFilter"),
targetNodeFilter=arguments.get("targetNodeFilter"),
nodeProperties=arguments.get("nodeProperties"),
Expand Down
Loading