@@ -367,6 +367,7 @@ def mmr_traversal_search(
367
367
lambda_mult : float = 0.5 ,
368
368
score_threshold : float = float ("-inf" ),
369
369
metadata_filter : dict [str , Any ] = {}, # noqa: B006
370
+ tag_filter : set [tuple [str , str ]],
370
371
) -> Iterable [Node ]:
371
372
"""Retrieve documents from this graph store using MMR-traversal.
372
373
@@ -398,6 +399,7 @@ def mmr_traversal_search(
398
399
score_threshold: Only documents with a score greater than or equal
399
400
this threshold will be chosen. Defaults to -infinity.
400
401
metadata_filter: Optional metadata to filter the results.
402
+ tag_filter: Optional tags to filter graph edges to be traversed.
401
403
"""
402
404
query_embedding = self ._embedding .embed_query (query )
403
405
helper = MmrHelper (
@@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
444
446
new_candidates = {}
445
447
for adjacent in adjacents :
446
448
if adjacent .target_content_id not in outgoing_tags :
447
- outgoing_tags [adjacent .target_content_id ] = (
448
- adjacent .target_link_to_tags
449
- )
449
+ if tag_filter .len () == 0 :
450
+ outgoing_tags [adjacent .target_content_id ] = (
451
+ adjacent .target_link_to_tags
452
+ )
453
+ else :
454
+ outgoing_tags [adjacent .target_content_id ] = (
455
+ tag_filter .intersection (adjacent .target_link_to_tags )
456
+ )
450
457
451
458
new_candidates [adjacent .target_content_id ] = (
452
459
adjacent .target_text_embedding
@@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None:
474
481
for row in fetched :
475
482
if row .content_id not in outgoing_tags :
476
483
candidates [row .content_id ] = row .text_embedding
477
- outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
484
+ if tag_filter .len () == 0 :
485
+ outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
486
+ else :
487
+ outgoing_tags [row .content_id ] = tag_filter .intersection (set (row .link_to_tags or []))
478
488
helper .add_candidates (candidates )
479
489
480
490
if initial_roots :
@@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None:
522
532
new_candidates = {}
523
533
for adjacent in adjacents :
524
534
if adjacent .target_content_id not in outgoing_tags :
525
- outgoing_tags [adjacent .target_content_id ] = (
526
- adjacent .target_link_to_tags
527
- )
535
+ if tag_filter .len () == 0 :
536
+ outgoing_tags [adjacent .target_content_id ] = (
537
+ adjacent .target_link_to_tags
538
+ )
539
+ else :
540
+ outgoing_tags [adjacent .target_content_id ] = (
541
+ tag_filter .intersection (adjacent .target_link_to_tags )
542
+ )
528
543
new_candidates [adjacent .target_content_id ] = (
529
544
adjacent .target_text_embedding
530
545
)
0 commit comments