Skip to content

Commit 33ccc47

Browse files
Support filtered knn
1 parent 28608b1 commit 33ccc47

File tree

3 files changed

+165
-14
lines changed

3 files changed

+165
-14
lines changed

mcp_server/src/mcp_server_neo4j_gds/similarity_algorithm_handlers.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import logging
22
from typing import Dict, Any
33

4-
from graphdatascience import GraphDataScience
5-
64
from .algorithm_handler import AlgorithmHandler
75
from .gds import projected_graph
86

@@ -201,21 +199,97 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
201199

202200

203201
class FilteredKNearestNeighborsHandler(AlgorithmHandler):
204-
def filtered_k_nearest_neighbors(
205-
self, db_url: str, username: str, password: str, **kwargs
202+
def handle_input_nodes(
203+
self,
204+
input_nodes,
205+
input_nodes_variable_name,
206+
node_identifier_property,
207+
call_params,
206208
):
207-
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
208-
with projected_graph(gds) as G:
209+
# Handle input nodes - convert names to IDs if nodeIdentifierProperty is provided
210+
if input_nodes is not None and node_identifier_property is not None:
211+
if isinstance(input_nodes, list):
212+
# Handle list of node names
213+
query = f"""
214+
UNWIND $names AS name
215+
MATCH (s)
216+
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower(name)
217+
RETURN id(s) as node_id
218+
"""
219+
df = self.gds.run_cypher(
220+
query,
221+
params={
222+
"names": input_nodes,
223+
},
224+
)
225+
input_node_ids = df["node_id"].tolist()
226+
call_params[input_nodes_variable_name] = input_node_ids
227+
else:
228+
# Handle single node name
229+
query = f"""
230+
MATCH (s)
231+
WHERE toLower(s.{node_identifier_property}) CONTAINS toLower($name)
232+
RETURN id(s) as node_id
233+
"""
234+
df = self.gds.run_cypher(
235+
query,
236+
params={
237+
"name": input_nodes,
238+
},
239+
)
240+
if not df.empty:
241+
call_params[input_nodes_variable_name] = int(df["node_id"].iloc[0])
242+
elif input_nodes is not None:
243+
# If input_nodes provided but no nodeIdentifierProperty, pass through as-is
244+
call_params[input_nodes_variable_name] = input_nodes
245+
246+
def filtered_k_nearest_neighbors(self, **kwargs):
247+
with projected_graph(self.gds) as G:
248+
params = {
249+
k: v
250+
for k, v in kwargs.items()
251+
if v is not None
252+
and k
253+
not in [
254+
"nodeIdentifierProperty",
255+
"sourceNodeFilter",
256+
"targetNodeFilter",
257+
]
258+
}
259+
node_identifier_property = kwargs.get("nodeIdentifierProperty")
260+
source_nodes = kwargs.get("sourceNodeFilter", None)
261+
target_nodes = kwargs.get("targetNodeFilter", None)
262+
self.handle_input_nodes(
263+
source_nodes, "sourceNodeFilter", node_identifier_property, params
264+
)
265+
self.handle_input_nodes(
266+
target_nodes, "targetNodeFilter", node_identifier_property, params
267+
)
268+
209269
logger.info(f"Filtered K-Nearest Neighbors parameters: {kwargs}")
210-
filtered_k_nearest_neighbors_result = gds.knn.filtered.stream(G, **kwargs)
270+
filtered_k_nearest_neighbors_result = self.gds.knn.filtered.stream(
271+
G, **params
272+
)
273+
274+
# Add node names to the results if nodeIdentifierProperty is provided
275+
node_identifier_property = kwargs.get("nodeIdentifierProperty")
276+
if node_identifier_property is not None:
277+
node1_name_values = [
278+
self.gds.util.asNode(node_id).get(node_identifier_property)
279+
for node_id in filtered_k_nearest_neighbors_result["node1"]
280+
]
281+
node2_name_values = [
282+
self.gds.util.asNode(node_id).get(node_identifier_property)
283+
for node_id in filtered_k_nearest_neighbors_result["node2"]
284+
]
285+
filtered_k_nearest_neighbors_result["node1Name"] = node1_name_values
286+
filtered_k_nearest_neighbors_result["node2Name"] = node2_name_values
211287

212288
return filtered_k_nearest_neighbors_result
213289

214290
def execute(self, arguments: Dict[str, Any]) -> Any:
215291
return self.filtered_k_nearest_neighbors(
216-
self.db_url,
217-
self.username,
218-
self.password,
292+
nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"),
219293
sourceNodeFilter=arguments.get("sourceNodeFilter"),
220294
targetNodeFilter=arguments.get("targetNodeFilter"),
221295
nodeProperties=arguments.get("nodeProperties"),

mcp_server/src/mcp_server_neo4j_gds/similarity_algorithm_specs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
"nodeIdentifierProperty": {
118118
"type": "string",
119119
"description": "Property name to use for identifying nodes (e.g., 'name', 'Name', 'title'). Use get_node_properties_keys to find available properties.",
120-
}
120+
},
121121
},
122122
},
123123
),
@@ -185,7 +185,7 @@
185185
"nodeIdentifierProperty": {
186186
"type": "string",
187187
"description": "Property name to use for identifying nodes (e.g., 'name', 'Name', 'title'). Use get_node_properties_keys to find available properties.",
188-
}
188+
},
189189
},
190190
"required": ["nodeProperties"],
191191
},
@@ -243,12 +243,12 @@
243243
},
244244
"seedTargetNodes": {
245245
"type": "boolean",
246-
"description": "Enable seeding of target nodes.",
246+
"description": "Enable seeding of target nodes. If seeded, every node picks some of the target nodes initially. This guarantees that for every node we can avoid empty result (when the algorithm did not find for it any similar neighbors from the target set). Can only be used if targetNodeFilter is set.",
247247
},
248248
"nodeIdentifierProperty": {
249249
"type": "string",
250250
"description": "Property name to use for identifying nodes (e.g., 'name', 'Name', 'title'). Use get_node_properties_keys to find available properties.",
251-
}
251+
},
252252
},
253253
"required": ["sourceNodeFilter", "targetNodeFilter", "nodeProperties"],
254254
},

mcp_server/tests/test_similarity_algorithms.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,80 @@ async def test_k_nearest_neighbors(mcp_client):
117117
lines = result_text.strip().split("\n")
118118
data_lines = [line for line in lines[1:] if line.strip()]
119119
assert len(data_lines) == 302 * 3
120+
121+
122+
@pytest.mark.asyncio
123+
async def test_filtered_knn(mcp_client):
124+
# test source-filter only
125+
result = await mcp_client.call_tool(
126+
"filtered_k_nearest_neighbors",
127+
{
128+
"nodeIdentifierProperty": "name",
129+
"topK": 3,
130+
"sourceNodeFilter": ["Acton Town"],
131+
"nodeProperties": "rail",
132+
},
133+
)
134+
135+
assert len(result) == 1
136+
result_text = result[0]["text"]
137+
# Verify structure of a result entry
138+
assert "node1" in result_text
139+
assert "node2" in result_text
140+
assert "node1Name" in result_text
141+
assert "node2Name" in result_text
142+
assert "similarity" in result_text
143+
lines = result_text.strip().split("\n")
144+
data_lines = [line for line in lines[1:] if line.strip()]
145+
assert len(data_lines) == 3
146+
assert "Acton Town" in data_lines[0]
147+
148+
# test target-filter alone
149+
150+
result = await mcp_client.call_tool(
151+
"filtered_k_nearest_neighbors",
152+
{
153+
"nodeIdentifierProperty": "name",
154+
"topK": 3,
155+
"targetNodeFilter": "Stamford Brook",
156+
"nodeProperties": "rail",
157+
},
158+
)
159+
assert len(result) == 1
160+
result_text = result[0]["text"]
161+
# Verify structure of a result entry
162+
assert "node1" in result_text
163+
assert "node2" in result_text
164+
assert "node1Name" in result_text
165+
assert "node2Name" in result_text
166+
assert "similarity" in result_text
167+
lines = result_text.strip().split("\n")
168+
data_lines = [line for line in lines[1:] if line.strip()]
169+
assert len(data_lines) > 0
170+
assert "Stamford Brook" in data_lines[0]
171+
172+
# test combination of filters
173+
result = await mcp_client.call_tool(
174+
"filtered_node_similarity",
175+
{
176+
"nodeIdentifierProperty": "name",
177+
"topK": 3,
178+
"sourceNodeFilter": ["Acton Town"],
179+
"targetNodeFilter": ["Stamford Brook"],
180+
"seedTargetNodes": True, # k-nn filtering is a bit special, it might not necessarily find answer if this is not specified (at least for this small example graph)
181+
},
182+
)
183+
184+
assert len(result) == 1
185+
result_text = result[0]["text"]
186+
# Verify structure of a result entry
187+
assert "node1" in result_text
188+
assert "node2" in result_text
189+
assert "node1Name" in result_text
190+
assert "node2Name" in result_text
191+
assert "similarity" in result_text
192+
lines = result_text.strip().split("\n")
193+
data_lines = [line for line in lines[1:] if line.strip()]
194+
assert len(data_lines) == 1
195+
assert "Acton Town" in data_lines[0]
196+
assert "Stamford Brook" in data_lines[0]

0 commit comments

Comments
 (0)