@@ -496,34 +496,39 @@ async def node_distance_reranker(
496496 sorted_uuids = rrf (results )
497497 scores : dict [str , float ] = {}
498498
499- for uuid in sorted_uuids :
500- # Find the shortest path to center node
501- records , _ , _ = await driver .execute_query (
502- """
499+ # Find the shortest path to center node
500+ query = Query ("""
503501 MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
504- MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
505- WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
506- RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
507- """ ,
508- edge_uuid = uuid ,
509- center_uuid = center_node_uuid ,
510- )
511- distance = 0.01
502+ MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
503+ RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
504+ """ )
512505
513- for record in records :
514- if (
515- record ['source_uuid' ] == center_node_uuid
516- or record ['target_uuid' ] == center_node_uuid
517- ):
518- continue
519- distance = record ['score' ]
506+ path_results = await asyncio .gather (
507+ * [
508+ driver .execute_query (
509+ query ,
510+ edge_uuid = uuid ,
511+ center_uuid = center_node_uuid ,
512+ )
513+ for uuid in sorted_uuids
514+ ]
515+ )
516+
517+ for uuid , result in zip (sorted_uuids , path_results ):
518+ records = result [0 ]
519+ record = records [0 ] if len (records ) > 0 else None
520+ distance : float = record ['score' ] if record is not None else float ('inf' )
521+ if record is not None and (
522+ record ['source_uuid' ] == center_node_uuid or record ['target_uuid' ] == center_node_uuid
523+ ):
524+ distance = 0
520525
521526 if uuid in scores :
522- scores [uuid ] = min (1 / distance , scores [uuid ])
527+ scores [uuid ] = min (distance , scores [uuid ])
523528 else :
524- scores [uuid ] = 1 / distance
529+ scores [uuid ] = distance
525530
526531 # rerank on shortest distance
527- sorted_uuids .sort (reverse = True , key = lambda cur_uuid : scores [cur_uuid ])
532+ sorted_uuids .sort (key = lambda cur_uuid : scores [cur_uuid ])
528533
529534 return sorted_uuids
0 commit comments