Skip to content

Commit 92e4c74

Browse files
committed
WIP
1 parent eb54a9f commit 92e4c74

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def mmr_traversal_search(
367367
lambda_mult: float = 0.5,
368368
score_threshold: float = float("-inf"),
369369
metadata_filter: dict[str, Any] = {}, # noqa: B006
370+
tag_filter: set[tuple[str, str]],
370371
) -> Iterable[Node]:
371372
"""Retrieve documents from this graph store using MMR-traversal.
372373
@@ -398,6 +399,7 @@ def mmr_traversal_search(
398399
score_threshold: Only documents with a score greater than or equal
399400
this threshold will be chosen. Defaults to -infinity.
400401
metadata_filter: Optional metadata to filter the results.
402+
tag_filter: Optional tags to filter graph edges to be traversed.
401403
"""
402404
query_embedding = self._embedding.embed_query(query)
403405
helper = MmrHelper(
@@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
444446
new_candidates = {}
445447
for adjacent in adjacents:
446448
if adjacent.target_content_id not in outgoing_tags:
447-
outgoing_tags[adjacent.target_content_id] = (
448-
adjacent.target_link_to_tags
449-
)
449+
if tag_filter.len() == 0:
450+
outgoing_tags[adjacent.target_content_id] = (
451+
adjacent.target_link_to_tags
452+
)
453+
else:
454+
outgoing_tags[adjacent.target_content_id] = (
455+
tag_filter.intersection(adjacent.target_link_to_tags)
456+
)
450457

451458
new_candidates[adjacent.target_content_id] = (
452459
adjacent.target_text_embedding
@@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None:
474481
for row in fetched:
475482
if row.content_id not in outgoing_tags:
476483
candidates[row.content_id] = row.text_embedding
477-
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
484+
if tag_filter.len() == 0:
485+
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
486+
else:
487+
outgoing_tags[row.content_id] = tag_filter.intersection(set(row.link_to_tags or []))
478488
helper.add_candidates(candidates)
479489

480490
if initial_roots:
@@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None:
522532
new_candidates = {}
523533
for adjacent in adjacents:
524534
if adjacent.target_content_id not in outgoing_tags:
525-
outgoing_tags[adjacent.target_content_id] = (
526-
adjacent.target_link_to_tags
527-
)
535+
if tag_filter.len() == 0:
536+
outgoing_tags[adjacent.target_content_id] = (
537+
adjacent.target_link_to_tags
538+
)
539+
else:
540+
outgoing_tags[adjacent.target_content_id] = (
541+
tag_filter.intersection(adjacent.target_link_to_tags)
542+
)
528543
new_candidates[adjacent.target_content_id] = (
529544
adjacent.target_text_embedding
530545
)

libs/knowledge-store/tests/integration_tests/test_graph_store.py

+6
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ def test_mmr_traversal(
211211
results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"])
212212
assert _result_ids(results) == ["v1", "v3", "v2"]
213213

214+
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("explicit", "link")))
215+
assert _result_ids(results) == ["v0", "v2"]
216+
217+
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("no", "match")))
218+
assert _result_ids(results) == []
219+
214220

215221
def test_write_retrieve_keywords(
216222
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],

0 commit comments

Comments
 (0)