8
8
from graph_retriever import Content
9
9
from graph_retriever .adapters import Adapter
10
10
from graph_retriever .edges import Edge , IdEdge , MetadataEdge
11
+ from graph_retriever .utils .math import cosine_similarity
11
12
12
13
13
14
def assert_valid_result (content : Content ):
@@ -40,6 +41,48 @@ def assert_ids_any_order(
40
41
assert set (result_ids ) == set (expected ), "should contain exactly expected IDs"
41
42
42
43
44
+ def cosine_similarity_scores (
45
+ adapter : Adapter , query_or_embedding : str | list [float ], ids : list [str ]
46
+ ) -> dict [str , float ]:
47
+ """Return the cosine similarity scores for the given IDs and query embedding."""
48
+ if len (ids ) == 0 :
49
+ return {}
50
+
51
+ docs = adapter .get (ids )
52
+ found_ids = (d .id for d in docs )
53
+ assert set (ids ) == set (found_ids ), "can't find all IDs"
54
+
55
+ if isinstance (query_or_embedding , str ):
56
+ query_embedding = adapter .search_with_embedding (query_or_embedding , k = 0 )[0 ]
57
+ else :
58
+ query_embedding = query_or_embedding
59
+
60
+ scores : list [float ] = cosine_similarity (
61
+ [query_embedding ],
62
+ [d .embedding for d in docs ],
63
+ )[0 ]
64
+
65
+ return {doc .id : score for doc , score in zip (docs , scores )}
66
+
67
+
68
+ def assert_ids_in_cosine_similarity_order (
69
+ results : Iterable [Content ],
70
+ expected : list [str ],
71
+ query_embedding : list [float ],
72
+ adapter : Adapter ,
73
+ ) -> None :
74
+ """Assert the results are valid and in cosine similarity order."""
75
+ assert_valid_results (results )
76
+ result_ids = [r .id for r in results ]
77
+
78
+ similarity_scores = cosine_similarity_scores (adapter , query_embedding , expected )
79
+ expected = sorted (expected , key = lambda id : similarity_scores [id ], reverse = True )
80
+
81
+ assert result_ids == expected , (
82
+ "should contain expected IDs in cosine similarity order"
83
+ )
84
+
85
+
43
86
@dataclass (kw_only = True )
44
87
class AdapterComplianceCase (abc .ABC ):
45
88
"""
@@ -77,8 +120,40 @@ class GetCase(AdapterComplianceCase):
77
120
GetCase (id = "one" , request = ["boar" ], expected = ["boar" ]),
78
121
GetCase (
79
122
id = "many" ,
80
- request = ["boar" , "chinchilla" , "cobra" ],
81
- expected = ["boar" , "chinchilla" , "cobra" ],
123
+ request = [
124
+ "alligator" ,
125
+ "barracuda" ,
126
+ "chameleon" ,
127
+ "cobra" ,
128
+ "crocodile" ,
129
+ "dolphin" ,
130
+ "eel" ,
131
+ "fish" ,
132
+ "gecko" ,
133
+ "iguana" ,
134
+ "jellyfish" ,
135
+ "komodo dragon" ,
136
+ "lizard" ,
137
+ "manatee" ,
138
+ "narwhal" ,
139
+ ],
140
+ expected = [
141
+ "alligator" ,
142
+ "barracuda" ,
143
+ "chameleon" ,
144
+ "cobra" ,
145
+ "crocodile" ,
146
+ "dolphin" ,
147
+ "eel" ,
148
+ "fish" ,
149
+ "gecko" ,
150
+ "iguana" ,
151
+ "jellyfish" ,
152
+ "komodo dragon" ,
153
+ "lizard" ,
154
+ "manatee" ,
155
+ "narwhal" ,
156
+ ],
82
157
),
83
158
GetCase (
84
159
id = "missing" ,
@@ -410,7 +485,7 @@ def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
410
485
411
486
Generally, this should *not* change the expected results, unless the the
412
487
adapter being tested uses wildly different distance metrics or a
413
- different embedding. The `AnimalsEmbedding` is deterimistic and the
488
+ different embedding. The `AnimalsEmbedding` is deterministic and the
414
489
results across vector stores should generally be deterministic and
415
490
consistent.
416
491
@@ -469,7 +544,7 @@ def test_search_with_embedding(
469
544
search_case .query , ** search_case .kwargs
470
545
)
471
546
assert_is_embedding (embedding )
472
- assert_ids_any_order (results , expected )
547
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
473
548
474
549
async def test_asearch_with_embedding (
475
550
self , adapter : Adapter , search_case : SearchCase
@@ -480,21 +555,21 @@ async def test_asearch_with_embedding(
480
555
search_case .query , ** search_case .kwargs
481
556
)
482
557
assert_is_embedding (embedding )
483
- assert_ids_any_order (results , expected )
558
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
484
559
485
560
def test_search (self , adapter : Adapter , search_case : SearchCase ) -> None :
486
561
"""Run tests for `search`."""
487
562
expected = self .expected ("search" , search_case )
488
563
embedding , _ = adapter .search_with_embedding (search_case .query , k = 0 )
489
564
results = adapter .search (embedding , ** search_case .kwargs )
490
- assert_ids_any_order (results , expected )
565
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
491
566
492
567
async def test_asearch (self , adapter : Adapter , search_case : SearchCase ) -> None :
493
568
"""Run tests for `asearch`."""
494
569
expected = self .expected ("asearch" , search_case )
495
570
embedding , _ = await adapter .asearch_with_embedding (search_case .query , k = 0 )
496
571
results = await adapter .asearch (embedding , ** search_case .kwargs )
497
- assert_ids_any_order (results , expected )
572
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
498
573
499
574
def test_adjacent (self , adapter : Adapter , adjacent_case : AdjacentCase ) -> None :
500
575
"""Run tests for `adjacent."""
@@ -506,7 +581,7 @@ def test_adjacent(self, adapter: Adapter, adjacent_case: AdjacentCase) -> None:
506
581
k = adjacent_case .k ,
507
582
filter = adjacent_case .filter ,
508
583
)
509
- assert_ids_any_order (results , expected )
584
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
510
585
511
586
async def test_aadjacent (
512
587
self , adapter : Adapter , adjacent_case : AdjacentCase
@@ -520,4 +595,4 @@ async def test_aadjacent(
520
595
k = adjacent_case .k ,
521
596
filter = adjacent_case .filter ,
522
597
)
523
- assert_ids_any_order (results , expected )
598
+ assert_ids_in_cosine_similarity_order (results , expected , embedding , adapter )
0 commit comments