Skip to content

Commit ccf7e6e

Browse files
Replace with existing methods
1 parent 709fdce commit ccf7e6e

File tree

1 file changed

+84
-142
lines changed

1 file changed

+84
-142
lines changed

mcp_server/src/mcp_server_neo4j_gds/similarity_algorithm_handlers.py

Lines changed: 84 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
from .algorithm_handler import AlgorithmHandler
55
from .gds import projected_graph
6+
from .node_translator import (
7+
translate_ids_to_identifiers,
8+
translate_identifiers_to_ids,
9+
)
610

711
logger = logging.getLogger("mcp_server_neo4j_gds")
812

@@ -20,18 +24,20 @@ def node_similarity(self, **kwargs):
2024

2125
# Add node names to the results if nodeIdentifierProperty is provided
2226
node_identifier_property = kwargs.get("nodeIdentifierProperty")
23-
if node_identifier_property is not None:
24-
node1_name_values = [
25-
self.gds.util.asNode(node_id).get(node_identifier_property)
26-
for node_id in node_similarity_result["node1"]
27-
]
28-
node2_name_values = [
29-
self.gds.util.asNode(node_id).get(node_identifier_property)
30-
for node_id in node_similarity_result["node2"]
31-
]
32-
node_similarity_result["node1Name"] = node1_name_values
33-
node_similarity_result["node2Name"] = node2_name_values
34-
27+
translate_ids_to_identifiers(
28+
self.gds,
29+
node_identifier_property,
30+
node_similarity_result,
31+
"node1",
32+
"node1Name",
33+
)
34+
translate_ids_to_identifiers(
35+
self.gds,
36+
node_identifier_property,
37+
node_similarity_result,
38+
"node2",
39+
"node2Name",
40+
)
3541
return node_similarity_result
3642

3743
def execute(self, arguments: Dict[str, Any]) -> Any:
@@ -51,50 +57,6 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
5157

5258

5359
class FilteredNodeSimilarityHandler(AlgorithmHandler):
54-
def handle_input_nodes(
55-
self,
56-
input_nodes,
57-
input_nodes_variable_name,
58-
node_identifier_property,
59-
call_params,
60-
):
61-
# Handle input nodes - convert names to IDs if nodeIdentifierProperty is provided
62-
if input_nodes is not None and node_identifier_property is not None:
63-
if isinstance(input_nodes, list):
64-
# Handle list of node names
65-
query = f"""
66-
UNWIND $names AS name
67-
MATCH (s)
68-
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower(name)
69-
RETURN id(s) as node_id
70-
"""
71-
df = self.gds.run_cypher(
72-
query,
73-
params={
74-
"names": input_nodes,
75-
},
76-
)
77-
input_node_ids = df["node_id"].tolist()
78-
call_params[input_nodes_variable_name] = input_node_ids
79-
else:
80-
# Handle single node name
81-
query = f"""
82-
MATCH (s)
83-
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower($name)
84-
RETURN id(s) as node_id
85-
"""
86-
df = self.gds.run_cypher(
87-
query,
88-
params={
89-
"name": input_nodes,
90-
},
91-
)
92-
if not df.empty:
93-
call_params[input_nodes_variable_name] = int(df["node_id"].iloc[0])
94-
elif input_nodes is not None:
95-
# If input_nodes provided but no nodeIdentifierProperty, pass through as-is
96-
call_params[input_nodes_variable_name] = input_nodes
97-
9860
def filtered_node_similarity(self, **kwargs):
9961
with projected_graph(self.gds) as G:
10062
params = {
@@ -111,11 +73,19 @@ def filtered_node_similarity(self, **kwargs):
11173
node_identifier_property = kwargs.get("nodeIdentifierProperty")
11274
source_nodes = kwargs.get("sourceNodeFilter", None)
11375
target_nodes = kwargs.get("targetNodeFilter", None)
114-
self.handle_input_nodes(
115-
source_nodes, "sourceNodeFilter", node_identifier_property, params
76+
translate_identifiers_to_ids(
77+
self.gds,
78+
source_nodes,
79+
"sourceNodeFilter",
80+
node_identifier_property,
81+
params,
11682
)
117-
self.handle_input_nodes(
118-
target_nodes, "targetNodeFilter", node_identifier_property, params
83+
translate_ids_to_identifiers(
84+
self.gds,
85+
target_nodes,
86+
"targetNodeFilter",
87+
node_identifier_property,
88+
params,
11989
)
12090
logger.info(f"Filtered Node Similarity parameters: {params}")
12191
filtered_node_similarity_result = self.gds.nodeSimilarity.filtered.stream(
@@ -124,18 +94,20 @@ def filtered_node_similarity(self, **kwargs):
12494

12595
# Add node names to the results if nodeIdentifierProperty is provided
12696
node_identifier_property = kwargs.get("nodeIdentifierProperty")
127-
if node_identifier_property is not None:
128-
node1_name_values = [
129-
self.gds.util.asNode(node_id).get(node_identifier_property)
130-
for node_id in filtered_node_similarity_result["node1"]
131-
]
132-
node2_name_values = [
133-
self.gds.util.asNode(node_id).get(node_identifier_property)
134-
for node_id in filtered_node_similarity_result["node2"]
135-
]
136-
filtered_node_similarity_result["node1Name"] = node1_name_values
137-
filtered_node_similarity_result["node2Name"] = node2_name_values
138-
97+
translate_ids_to_identifiers(
98+
self.gds,
99+
node_identifier_property,
100+
filtered_node_similarity_result,
101+
"node1",
102+
"node1Name",
103+
)
104+
translate_ids_to_identifiers(
105+
self.gds,
106+
node_identifier_property,
107+
filtered_node_similarity_result,
108+
"node2",
109+
"node2Name",
110+
)
139111
return filtered_node_similarity_result
140112

141113
def execute(self, arguments: Dict[str, Any]) -> Any:
@@ -169,17 +141,20 @@ def k_nearest_neighbors(self, **kwargs):
169141

170142
# Add node names to the results if nodeIdentifierProperty is provided
171143
node_identifier_property = kwargs.get("nodeIdentifierProperty")
172-
if node_identifier_property is not None:
173-
node1_name_values = [
174-
self.gds.util.asNode(node_id).get(node_identifier_property)
175-
for node_id in k_nearest_neighbors_result["node1"]
176-
]
177-
node2_name_values = [
178-
self.gds.util.asNode(node_id).get(node_identifier_property)
179-
for node_id in k_nearest_neighbors_result["node2"]
180-
]
181-
k_nearest_neighbors_result["node1Name"] = node1_name_values
182-
k_nearest_neighbors_result["node2Name"] = node2_name_values
144+
translate_ids_to_identifiers(
145+
self.gds,
146+
node_identifier_property,
147+
k_nearest_neighbors_result,
148+
"node1",
149+
"node1Name",
150+
)
151+
translate_ids_to_identifiers(
152+
self.gds,
153+
node_identifier_property,
154+
k_nearest_neighbors_result,
155+
"node2",
156+
"node2Name",
157+
)
183158

184159
return k_nearest_neighbors_result
185160

@@ -199,50 +174,6 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
199174

200175

201176
class FilteredKNearestNeighborsHandler(AlgorithmHandler):
202-
def handle_input_nodes(
203-
self,
204-
input_nodes,
205-
input_nodes_variable_name,
206-
node_identifier_property,
207-
call_params,
208-
):
209-
# Handle input nodes - convert names to IDs if nodeIdentifierProperty is provided
210-
if input_nodes is not None and node_identifier_property is not None:
211-
if isinstance(input_nodes, list):
212-
# Handle list of node names
213-
query = f"""
214-
UNWIND $names AS name
215-
MATCH (s)
216-
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower(name)
217-
RETURN id(s) as node_id
218-
"""
219-
df = self.gds.run_cypher(
220-
query,
221-
params={
222-
"names": input_nodes,
223-
},
224-
)
225-
input_node_ids = df["node_id"].tolist()
226-
call_params[input_nodes_variable_name] = input_node_ids
227-
else:
228-
# Handle single node name
229-
query = f"""
230-
MATCH (s)
231-
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower($name)
232-
RETURN id(s) as node_id
233-
"""
234-
df = self.gds.run_cypher(
235-
query,
236-
params={
237-
"name": input_nodes,
238-
},
239-
)
240-
if not df.empty:
241-
call_params[input_nodes_variable_name] = int(df["node_id"].iloc[0])
242-
elif input_nodes is not None:
243-
# If input_nodes provided but no nodeIdentifierProperty, pass through as-is
244-
call_params[input_nodes_variable_name] = input_nodes
245-
246177
def filtered_k_nearest_neighbors(self, **kwargs):
247178
with projected_graph(self.gds) as G:
248179
params = {
@@ -259,11 +190,19 @@ def filtered_k_nearest_neighbors(self, **kwargs):
259190
node_identifier_property = kwargs.get("nodeIdentifierProperty")
260191
source_nodes = kwargs.get("sourceNodeFilter", None)
261192
target_nodes = kwargs.get("targetNodeFilter", None)
262-
self.handle_input_nodes(
263-
source_nodes, "sourceNodeFilter", node_identifier_property, params
193+
translate_identifiers_to_ids(
194+
self.gds,
195+
source_nodes,
196+
"sourceNodeFilter",
197+
node_identifier_property,
198+
params,
264199
)
265-
self.handle_input_nodes(
266-
target_nodes, "targetNodeFilter", node_identifier_property, params
200+
translate_ids_to_identifiers(
201+
self.gds,
202+
target_nodes,
203+
"targetNodeFilter",
204+
node_identifier_property,
205+
params,
267206
)
268207

269208
logger.info(f"Filtered K-Nearest Neighbors parameters: {kwargs}")
@@ -273,17 +212,20 @@ def filtered_k_nearest_neighbors(self, **kwargs):
273212

274213
# Add node names to the results if nodeIdentifierProperty is provided
275214
node_identifier_property = kwargs.get("nodeIdentifierProperty")
276-
if node_identifier_property is not None:
277-
node1_name_values = [
278-
self.gds.util.asNode(node_id).get(node_identifier_property)
279-
for node_id in filtered_k_nearest_neighbors_result["node1"]
280-
]
281-
node2_name_values = [
282-
self.gds.util.asNode(node_id).get(node_identifier_property)
283-
for node_id in filtered_k_nearest_neighbors_result["node2"]
284-
]
285-
filtered_k_nearest_neighbors_result["node1Name"] = node1_name_values
286-
filtered_k_nearest_neighbors_result["node2Name"] = node2_name_values
215+
translate_ids_to_identifiers(
216+
self.gds,
217+
node_identifier_property,
218+
filtered_k_nearest_neighbors_result,
219+
"node1",
220+
"node1Name",
221+
)
222+
translate_ids_to_identifiers(
223+
self.gds,
224+
node_identifier_property,
225+
filtered_k_nearest_neighbors_result,
226+
"node2",
227+
"node2Name",
228+
)
287229

288230
return filtered_k_nearest_neighbors_result
289231

0 commit comments

Comments
 (0)