Skip to content

Commit 4c83f98

Browse files
authored
updated adapter tests to ensure documents are returned in similarity order (#172)
* updated tests to ensure documents are returned in similarity order * updated test
1 parent a509f0e commit 4c83f98

File tree

11 files changed

+173
-96
lines changed

11 files changed

+173
-96
lines changed

packages/graph-retriever/src/graph_retriever/content.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ class Content:
1818
The content.
1919
embedding :
2020
The embedding of the content.
21-
score :
22-
The similarity of the embedding to the query.
23-
This is optional, and may not be set depending on the content.
2421
metadata :
2522
The metadata associated with the content.
2623
mime_type :
@@ -31,16 +28,14 @@ class Content:
3128
content: str
3229
embedding: list[float]
3330
metadata: dict[str, Any] = dataclasses.field(default_factory=dict)
34-
3531
mime_type: str = "text/plain"
36-
score: float | None = None
3732

3833
@staticmethod
3934
def new(
4035
id: str,
4136
content: str,
4237
embedding: list[float] | Callable[[str], list[float]],
43-
score: float | None = None,
38+
*,
4439
metadata: dict[str, Any] | None = None,
4540
mime_type: str = "text/plain",
4641
) -> Content:
@@ -56,8 +51,6 @@ def new(
5651
embedding :
5752
The embedding, or a function to apply to the content to compute the
5853
embedding.
59-
score :
60-
The similarity of the embedding to the query.
6154
metadata :
6255
The metadata associated with the content.
6356
mime_type :
@@ -72,7 +65,6 @@ def new(
7265
id=id,
7366
content=content,
7467
embedding=embedding(content) if callable(embedding) else embedding,
75-
score=score,
7668
metadata=metadata or {},
7769
mime_type=mime_type,
7870
)

packages/graph-retriever/src/graph_retriever/strategies/scored.py

+7
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,10 @@ def iteration(self, nodes: Iterable[Node], tracker: NodeTracker) -> None:
6666
node = highest.node
6767
node.extra_metadata["_score"] = highest.score
6868
limit -= tracker.select_and_traverse([node])
69+
70+
@override
71+
def finalize_nodes(self, selected):
72+
selected = sorted(
73+
selected, key=lambda node: node.extra_metadata["_score"], reverse=True
74+
)
75+
return super().finalize_nodes(selected)

packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py

+84-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from graph_retriever import Content
99
from graph_retriever.adapters import Adapter
1010
from graph_retriever.edges import Edge, IdEdge, MetadataEdge
11+
from graph_retriever.utils.math import cosine_similarity
1112

1213

1314
def assert_valid_result(content: Content):
@@ -40,6 +41,48 @@ def assert_ids_any_order(
4041
assert set(result_ids) == set(expected), "should contain exactly expected IDs"
4142

4243

44+
def cosine_similarity_scores(
45+
adapter: Adapter, query_or_embedding: str | list[float], ids: list[str]
46+
) -> dict[str, float]:
47+
"""Return the cosine similarity scores for the given IDs and query embedding."""
48+
if len(ids) == 0:
49+
return {}
50+
51+
docs = adapter.get(ids)
52+
found_ids = (d.id for d in docs)
53+
assert set(ids) == set(found_ids), "can't find all IDs"
54+
55+
if isinstance(query_or_embedding, str):
56+
query_embedding = adapter.search_with_embedding(query_or_embedding, k=0)[0]
57+
else:
58+
query_embedding = query_or_embedding
59+
60+
scores: list[float] = cosine_similarity(
61+
[query_embedding],
62+
[d.embedding for d in docs],
63+
)[0]
64+
65+
return {doc.id: score for doc, score in zip(docs, scores)}
66+
67+
68+
def assert_ids_in_cosine_similarity_order(
69+
results: Iterable[Content],
70+
expected: list[str],
71+
query_embedding: list[float],
72+
adapter: Adapter,
73+
) -> None:
74+
"""Assert the results are valid and in cosine similarity order."""
75+
assert_valid_results(results)
76+
result_ids = [r.id for r in results]
77+
78+
similarity_scores = cosine_similarity_scores(adapter, query_embedding, expected)
79+
expected = sorted(expected, key=lambda id: similarity_scores[id], reverse=True)
80+
81+
assert result_ids == expected, (
82+
"should contain expected IDs in cosine similarity order"
83+
)
84+
85+
4386
@dataclass(kw_only=True)
4487
class AdapterComplianceCase(abc.ABC):
4588
"""
@@ -77,8 +120,40 @@ class GetCase(AdapterComplianceCase):
77120
GetCase(id="one", request=["boar"], expected=["boar"]),
78121
GetCase(
79122
id="many",
80-
request=["boar", "chinchilla", "cobra"],
81-
expected=["boar", "chinchilla", "cobra"],
123+
request=[
124+
"alligator",
125+
"barracuda",
126+
"chameleon",
127+
"cobra",
128+
"crocodile",
129+
"dolphin",
130+
"eel",
131+
"fish",
132+
"gecko",
133+
"iguana",
134+
"jellyfish",
135+
"komodo dragon",
136+
"lizard",
137+
"manatee",
138+
"narwhal",
139+
],
140+
expected=[
141+
"alligator",
142+
"barracuda",
143+
"chameleon",
144+
"cobra",
145+
"crocodile",
146+
"dolphin",
147+
"eel",
148+
"fish",
149+
"gecko",
150+
"iguana",
151+
"jellyfish",
152+
"komodo dragon",
153+
"lizard",
154+
"manatee",
155+
"narwhal",
156+
],
82157
),
83158
GetCase(
84159
id="missing",
@@ -410,7 +485,7 @@ def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
410485
411486
Generally, this should *not* change the expected results, unless the
412487
adapter being tested uses wildly different distance metrics or a
413-
different embedding. The `AnimalsEmbedding` is deterimistic and the
488+
different embedding. The `AnimalsEmbedding` is deterministic and the
414489
results across vector stores should generally be deterministic and
415490
consistent.
416491
@@ -469,7 +544,7 @@ def test_search_with_embedding(
469544
search_case.query, **search_case.kwargs
470545
)
471546
assert_is_embedding(embedding)
472-
assert_ids_any_order(results, expected)
547+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)
473548

474549
async def test_asearch_with_embedding(
475550
self, adapter: Adapter, search_case: SearchCase
@@ -480,21 +555,21 @@ async def test_asearch_with_embedding(
480555
search_case.query, **search_case.kwargs
481556
)
482557
assert_is_embedding(embedding)
483-
assert_ids_any_order(results, expected)
558+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)
484559

485560
def test_search(self, adapter: Adapter, search_case: SearchCase) -> None:
486561
"""Run tests for `search`."""
487562
expected = self.expected("search", search_case)
488563
embedding, _ = adapter.search_with_embedding(search_case.query, k=0)
489564
results = adapter.search(embedding, **search_case.kwargs)
490-
assert_ids_any_order(results, expected)
565+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)
491566

492567
async def test_asearch(self, adapter: Adapter, search_case: SearchCase) -> None:
493568
"""Run tests for `asearch`."""
494569
expected = self.expected("asearch", search_case)
495570
embedding, _ = await adapter.asearch_with_embedding(search_case.query, k=0)
496571
results = await adapter.asearch(embedding, **search_case.kwargs)
497-
assert_ids_any_order(results, expected)
572+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)
498573

499574
def test_adjacent(self, adapter: Adapter, adjacent_case: AdjacentCase) -> None:
500575
"""Run tests for `adjacent`."""
@@ -506,7 +581,7 @@ def test_adjacent(self, adapter: Adapter, adjacent_case: AdjacentCase) -> None:
506581
k=adjacent_case.k,
507582
filter=adjacent_case.filter,
508583
)
509-
assert_ids_any_order(results, expected)
584+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)
510585

511586
async def test_aadjacent(
512587
self, adapter: Adapter, adjacent_case: AdjacentCase
@@ -520,4 +595,4 @@ async def test_aadjacent(
520595
k=adjacent_case.k,
521596
filter=adjacent_case.filter,
522597
)
523-
assert_ids_any_order(results, expected)
598+
assert_ids_in_cosine_similarity_order(results, expected, embedding, adapter)

packages/graph-retriever/src/graph_retriever/traversal.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -343,19 +343,18 @@ def _contents_to_new_nodes(
343343
c.id: c for c in contents if c.id not in self._discovered_node_ids
344344
}
345345

346-
# Compute scores (as needed).
347-
if any(c.score is None for c in content_dict.values()):
348-
scores = cosine_similarity(
349-
[self.strategy._query_embedding],
350-
[c.embedding for c in content_dict.values() if c.score is None],
351-
)[0]
352-
else:
353-
scores = []
346+
if len(content_dict) == 0:
347+
return []
348+
349+
# Compute scores.
350+
scores: list[float] = cosine_similarity(
351+
[self.strategy._query_embedding],
352+
[c.embedding for c in content_dict.values()],
353+
)[0]
354354

355355
# Create the nodes
356-
scores_it = iter(scores)
357356
nodes = []
358-
for content in content_dict.values():
357+
for content, score in zip(content_dict.values(), scores):
359358
# Determine incoming/outgoing edges.
360359
edges = self.edge_function(content)
361360

@@ -370,7 +369,6 @@ def _contents_to_new_nodes(
370369
default=0,
371370
)
372371

373-
score = content.score or next(scores_it)
374372
nodes.append(
375373
Node(
376374
id=content.id,
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Iterable
2-
from typing import cast
32

43
from graph_retriever.content import Content
54
from graph_retriever.utils.math import cosine_similarity_top_k
@@ -12,7 +11,7 @@ def top_k(
1211
k: int,
1312
) -> list[Content]:
1413
"""
15-
Select the top-k contents from the given contet.
14+
Select the top-k contents from the given content.
1615
1716
Parameters
1817
----------
@@ -26,35 +25,32 @@ def top_k(
2625
Returns
2726
-------
2827
list[Content]
29-
Top-K by similarity. All results will have their `score` set.
28+
Top-K by similarity.
3029
"""
3130
# TODO: Consider handling specially cases of already-sorted batches (merge).
3231
# TODO: Consider passing threshold here to limit results.
3332

3433
# Use dicts to de-duplicate by ID. This ensures we choose the top K distinct
3534
# content (rather than K copies of the same content).
36-
scored = {c.id: c for c in contents if c.score is not None}
37-
unscored = {c.id: c for c in contents if c.score is None if c.id not in scored}
35+
unscored = {c.id: c for c in contents}
3836

39-
if unscored:
40-
top_unscored = _similarity_sort_top_k(
41-
list(unscored.values()), embedding=embedding, k=k
42-
)
43-
scored.update(top_unscored)
37+
top_scored = _similarity_sort_top_k(
38+
list(unscored.values()), embedding=embedding, k=k
39+
)
4440

45-
sorted = list(scored.values())
41+
sorted = list(top_scored.values())
4642
sorted.sort(key=_score, reverse=True)
4743

48-
return sorted[:k]
44+
return [c[0] for c in sorted]
4945

5046

51-
def _score(content: Content) -> float:
52-
return cast(float, content.score)
47+
def _score(content_with_score: tuple[Content, float]) -> float:
48+
return content_with_score[1]
5349

5450

5551
def _similarity_sort_top_k(
5652
contents: list[Content], *, embedding: list[float], k: int
57-
) -> dict[str, Content]:
53+
) -> dict[str, tuple[Content, float]]:
5854
# Flatten the content and use a dict to deduplicate.
5955
# We need to do this *before* selecting the top_k to ensure we don't
6056
# get duplicates (and fail to produce `k`).
@@ -65,6 +61,5 @@ def _similarity_sort_top_k(
6561
results = {}
6662
for (_x, y), score in zip(top_k, scores):
6763
c = contents[y]
68-
c.score = score
69-
results[c.id] = c
64+
results[c.id] = (c, score)
7065
return results

packages/graph-retriever/tests/strategies/test_eager.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from graph_retriever.strategies import (
77
Eager,
88
)
9+
from graph_retriever.testing.adapter_tests import cosine_similarity_scores
910
from graph_retriever.testing.embeddings import (
1011
ParserEmbeddings,
1112
angular_2d_embedding,
@@ -133,23 +134,17 @@ async def test_animals_habitat(animals: Adapter, sync_or_async: SyncOrAsync):
133134
]
134135

135136

136-
async def test_animals_populates_metrics(animals: Adapter, sync_or_async: SyncOrAsync):
137-
"""Test that score and depth are populated."""
137+
async def test_animals_populates_metrics_and_order(
138+
animals: Adapter, sync_or_async: SyncOrAsync
139+
):
140+
"""Test that score and depth are populated and results are returned in order."""
138141
results = await sync_or_async.traverse(
139142
store=animals,
140143
query=ANIMALS_QUERY,
141144
edges=[("habitat", "habitat")],
142145
strategy=Eager(select_k=100, start_k=2, max_depth=2),
143146
)()
144147

145-
expected_similarity_scores = {
146-
"mongoose": 0.578682,
147-
"bobcat": 0.02297939,
148-
"cobra": 0.01365448699,
149-
"deer": 0.1869947,
150-
"elk": 0.02876833,
151-
"fox": 0.533316,
152-
}
153148
expected_depths = {
154149
"mongoose": 0,
155150
"bobcat": 1,
@@ -159,6 +154,10 @@ async def test_animals_populates_metrics(animals: Adapter, sync_or_async: SyncOr
159154
"fox": 0,
160155
}
161156

157+
expected_similarity_scores = cosine_similarity_scores(
158+
animals, ANIMALS_QUERY, list(expected_depths.keys())
159+
)
160+
162161
for n in results:
163162
assert n.extra_metadata["_similarity_score"] == pytest.approx(
164163
expected_similarity_scores[n.id]
@@ -167,6 +166,15 @@ async def test_animals_populates_metrics(animals: Adapter, sync_or_async: SyncOr
167166
f"incorrect depth for {n.id}"
168167
)
169168

169+
expected_ids_in_order = sorted(
170+
expected_similarity_scores.keys(),
171+
key=lambda id: expected_similarity_scores[id],
172+
reverse=True,
173+
)
174+
assert [n.id for n in results] == expected_ids_in_order, (
175+
"incorrect order of results"
176+
)
177+
170178

171179
async def test_animals_habitat_to_keywords(
172180
animals: Adapter, sync_or_async: SyncOrAsync

0 commit comments

Comments
 (0)