1414
1515from vicinity import Metric
1616from vicinity .backends import AbstractBackend , BasicBackend , BasicVectorStore , get_backend_class
17- from vicinity .datatypes import Backend , PathLike
17+ from vicinity .datatypes import Backend , PathLike , QueryResult
1818
1919logger = logging .getLogger (__name__ )
2020
@@ -114,7 +114,7 @@ def query(
114114 self ,
115115 vectors : npt .NDArray ,
116116 k : int = 10 ,
117- ) -> list [list [ tuple [ str , float ]] ]:
117+ ) -> list [QueryResult ]:
118118 """
119119 Find the nearest neighbors to some arbitrary vector.
120120
@@ -140,22 +140,26 @@ def query_threshold(
140140 self ,
141141 vectors : npt .NDArray ,
142142 threshold : float = 0.5 ,
143- ) -> list [list [str ]]:
143+ max_k : int = 100 ,
144+ ) -> list [QueryResult ]:
144145 """
145- Find the nearest neighbors to some arbitrary vector with some threshold.
146+ Find the nearest neighbors to some arbitrary vector with some threshold. Note: the output is not sorted.
146147
147148 :param vectors: The vectors to find the most similar vectors to.
148149 :param threshold: The threshold to use.
150+ :param max_k: The maximum number of neighbors to consider for the threshold query.
149151
150- :return: For each item in the input, all items above the threshold are returned.
152+ :return: For each item in the input, the items above the threshold are returned in the form of
153+ (NAME, SIMILARITY) tuples.
151154 """
152- vectors = np .array (vectors )
155+ vectors = np .asarray (vectors )
153156 if np .ndim (vectors ) == 1 :
154157 vectors = vectors [None , :]
155158
156159 out = []
157- for indexes in self .backend .threshold (vectors , threshold ):
158- out .append ([self .items [idx ] for idx in indexes ])
160+ for indices , distances in self .backend .threshold (vectors , threshold , max_k = max_k ):
161+ distances .clip (min = 0 , out = distances )
162+ out .append ([(self .items [idx ], dist ) for idx , dist in zip (indices , distances )])
159163
160164 return out
161165
0 commit comments