|
1 | 1 | import logging |
2 | 2 | from typing import Dict, Any |
3 | 3 |
|
4 | | -from graphdatascience import GraphDataScience |
5 | | - |
6 | 4 | from .algorithm_handler import AlgorithmHandler |
7 | 5 | from .gds import projected_graph |
8 | 6 |
|
@@ -201,21 +199,97 @@ def execute(self, arguments: Dict[str, Any]) -> Any: |
201 | 199 |
|
202 | 200 |
|
203 | 201 | class FilteredKNearestNeighborsHandler(AlgorithmHandler): |
204 | | - def filtered_k_nearest_neighbors( |
205 | | - self, db_url: str, username: str, password: str, **kwargs |
| 202 | + def handle_input_nodes( |
| 203 | + self, |
| 204 | + input_nodes, |
| 205 | + input_nodes_variable_name, |
| 206 | + node_identifier_property, |
| 207 | + call_params, |
206 | 208 | ): |
207 | | - gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False) |
208 | | - with projected_graph(gds) as G: |
| 209 | + # Handle input nodes - convert names to IDs if nodeIdentifierProperty is provided |
| 210 | + if input_nodes is not None and node_identifier_property is not None: |
| 211 | + if isinstance(input_nodes, list): |
| 212 | + # Handle list of node names |
| 213 | + query = f""" |
| 214 | + UNWIND $names AS name |
| 215 | + MATCH (s) |
| 216 | + WHERE toLower(s.{node_identifier_property}) CONTAINS toLower(name) |
| 217 | + RETURN id(s) as node_id |
| 218 | + """ |
| 219 | + df = self.gds.run_cypher( |
| 220 | + query, |
| 221 | + params={ |
| 222 | + "names": input_nodes, |
| 223 | + }, |
| 224 | + ) |
| 225 | + input_node_ids = df["node_id"].tolist() |
| 226 | + call_params[input_nodes_variable_name] = input_node_ids |
| 227 | + else: |
| 228 | + # Handle single node name |
| 229 | + query = f""" |
| 230 | + MATCH (s) |
| 231 | + WHERE toLower(s.{node_identifier_property}) CONTAINS toLower($name) |
| 232 | + RETURN id(s) as node_id |
| 233 | + """ |
| 234 | + df = self.gds.run_cypher( |
| 235 | + query, |
| 236 | + params={ |
| 237 | + "name": input_nodes, |
| 238 | + }, |
| 239 | + ) |
| 240 | + if not df.empty: |
| 241 | + call_params[input_nodes_variable_name] = int(df["node_id"].iloc[0]) |
| 242 | + elif input_nodes is not None: |
| 243 | + # If input_nodes provided but no nodeIdentifierProperty, pass through as-is |
| 244 | + call_params[input_nodes_variable_name] = input_nodes |
| 245 | + |
| 246 | + def filtered_k_nearest_neighbors(self, **kwargs): |
| 247 | + with projected_graph(self.gds) as G: |
| 248 | + params = { |
| 249 | + k: v |
| 250 | + for k, v in kwargs.items() |
| 251 | + if v is not None |
| 252 | + and k |
| 253 | + not in [ |
| 254 | + "nodeIdentifierProperty", |
| 255 | + "sourceNodeFilter", |
| 256 | + "targetNodeFilter", |
| 257 | + ] |
| 258 | + } |
| 259 | + node_identifier_property = kwargs.get("nodeIdentifierProperty") |
| 260 | + source_nodes = kwargs.get("sourceNodeFilter", None) |
| 261 | + target_nodes = kwargs.get("targetNodeFilter", None) |
| 262 | + self.handle_input_nodes( |
| 263 | + source_nodes, "sourceNodeFilter", node_identifier_property, params |
| 264 | + ) |
| 265 | + self.handle_input_nodes( |
| 266 | + target_nodes, "targetNodeFilter", node_identifier_property, params |
| 267 | + ) |
| 268 | + |
209 | 269 | logger.info(f"Filtered K-Nearest Neighbors parameters: {kwargs}") |
210 | | - filtered_k_nearest_neighbors_result = gds.knn.filtered.stream(G, **kwargs) |
| 270 | + filtered_k_nearest_neighbors_result = self.gds.knn.filtered.stream( |
| 271 | + G, **params |
| 272 | + ) |
| 273 | + |
| 274 | + # Add node names to the results if nodeIdentifierProperty is provided |
| 275 | + node_identifier_property = kwargs.get("nodeIdentifierProperty") |
| 276 | + if node_identifier_property is not None: |
| 277 | + node1_name_values = [ |
| 278 | + self.gds.util.asNode(node_id).get(node_identifier_property) |
| 279 | + for node_id in filtered_k_nearest_neighbors_result["node1"] |
| 280 | + ] |
| 281 | + node2_name_values = [ |
| 282 | + self.gds.util.asNode(node_id).get(node_identifier_property) |
| 283 | + for node_id in filtered_k_nearest_neighbors_result["node2"] |
| 284 | + ] |
| 285 | + filtered_k_nearest_neighbors_result["node1Name"] = node1_name_values |
| 286 | + filtered_k_nearest_neighbors_result["node2Name"] = node2_name_values |
211 | 287 |
|
212 | 288 | return filtered_k_nearest_neighbors_result |
213 | 289 |
|
214 | 290 | def execute(self, arguments: Dict[str, Any]) -> Any: |
215 | 291 | return self.filtered_k_nearest_neighbors( |
216 | | - self.db_url, |
217 | | - self.username, |
218 | | - self.password, |
| 292 | + nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"), |
219 | 293 | sourceNodeFilter=arguments.get("sourceNodeFilter"), |
220 | 294 | targetNodeFilter=arguments.get("targetNodeFilter"), |
221 | 295 | nodeProperties=arguments.get("nodeProperties"), |
|
0 commit comments