Skip to content

Commit 2c35eac

Browse files
authored
feat: Added scores to query threshold (#54)
* Changed query threshold logic * Updated method * Used mask
1 parent 2533a86 commit 2c35eac

10 files changed

Lines changed: 60 additions & 45 deletions

File tree

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vicinity/backends/annoy.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,10 @@ def delete(self, indices: list[int]) -> None:
126126
"""Delete vectors from the backend."""
127127
raise NotImplementedError("Deletion is not supported in Annoy backend.")
128128

129-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
129+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
130130
"""Threshold the backend."""
131-
out: list[npt.NDArray] = []
132-
for x, y in self.query(vectors, 100):
133-
out.append(x[y < threshold])
131+
out: QueryResult = []
132+
for x, y in self.query(vectors, max_k):
133+
mask = y < threshold
134+
out.append((x[mask], y[mask]))
134135
return out

vicinity/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def delete(self, indices: list[int]) -> None:
9292
raise NotImplementedError()
9393

9494
@abstractmethod
95-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
95+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
9696
"""Threshold the backend."""
9797
raise NotImplementedError()
9898

vicinity/backends/basic.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,25 @@ def threshold(
150150
self,
151151
vectors: npt.NDArray,
152152
threshold: float,
153-
) -> list[npt.NDArray]:
153+
max_k: int,
154+
) -> QueryResult:
154155
"""
155156
Batched distance thresholding.
156157
157158
:param vectors: The vectors to threshold.
158159
:param threshold: The threshold to use.
159-
:return: A list of lists of indices of vectors that are below the threshold
160+
:param max_k: The maximum number of neighbors to consider.
161+
:return: A list of tuples with the indices and distances.
160162
"""
161-
out: list[npt.NDArray] = []
163+
out: QueryResult = []
162164
for i in range(0, len(vectors), 1024):
163165
batch = vectors[i : i + 1024]
164166
distances = self._dist(batch)
165167
for dists in distances:
166-
indices = np.flatnonzero(dists <= threshold)
167-
sorted_indices = indices[np.argsort(dists[indices])]
168-
out.append(sorted_indices)
168+
mask = dists <= threshold
169+
indices = np.flatnonzero(mask)
170+
filtered_distances = dists[mask]
171+
out.append((indices, filtered_distances))
169172
return out
170173

171174
def query(

vicinity/backends/faiss.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ def delete(self, indices: list[int]) -> None:
164164
"""Delete vectors from the backend."""
165165
raise NotImplementedError("Deletion is not supported in FAISS backends.")
166166

167-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
167+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
168168
"""Query vectors within a distance threshold, using range_search if supported."""
169-
out: list[npt.NDArray] = []
169+
out: QueryResult = []
170170
if self.arguments.metric == "cosine":
171171
vectors = normalize(vectors)
172172

@@ -179,13 +179,15 @@ def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]
179179
dist = D[start:end]
180180
if self.arguments.metric == "cosine":
181181
dist = 1 - dist
182-
out.append(idx[dist < threshold])
182+
mask = dist < threshold
183+
out.append((idx[mask], dist[mask]))
183184
else:
184-
distances, indices = self.index.search(vectors, 100)
185+
distances, indices = self.index.search(vectors, max_k)
185186
for dist, idx in zip(distances, indices):
186187
if self.arguments.metric == "cosine":
187188
dist = 1 - dist
188-
out.append(idx[dist < threshold])
189+
mask = dist < threshold
190+
out.append((idx[mask], dist[mask]))
189191

190192
return out
191193

vicinity/backends/hnsw.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,11 @@ def delete(self, indices: list[int]) -> None:
104104
"""Delete vectors from the backend."""
105105
raise NotImplementedError("Deletion is not supported in HNSW backend.")
106106

107-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
107+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
108108
"""Threshold the backend."""
109-
out: list[npt.NDArray] = []
110-
for x, y in self.query(vectors, 100):
111-
out.append(x[y < threshold])
109+
out: QueryResult = []
110+
for x, y in self.query(vectors, max_k):
111+
mask = y < threshold
112+
out.append((x[mask], y[mask]))
112113

113114
return out

vicinity/backends/pynndescent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ def delete(self, indices: list[int]) -> None:
8080
"""Delete vectors from the backend."""
8181
raise NotImplementedError("Deletion is not supported in PyNNDescent backend.")
8282

83-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
83+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
8484
"""Find neighbors within a distance threshold."""
8585
normalized_vectors = normalize_or_copy(vectors)
86-
indices, distances = self.index.query(normalized_vectors, k=100)
87-
result = []
86+
indices, distances = self.index.query(normalized_vectors, k=max_k)
87+
out: QueryResult = []
8888
for idx, dist in zip(indices, distances):
89-
within_threshold = idx[dist < threshold]
90-
result.append(within_threshold)
91-
return result
89+
mask = dist < threshold
90+
out.append((idx[mask], dist[mask]))
91+
return out
9292

9393
def save(self, base_path: Path) -> None:
9494
"""Save the vectors and configuration to a specified path."""

vicinity/backends/usearch.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,12 @@ def delete(self, indices: list[int]) -> None:
129129
"""Delete vectors from the index (not supported by Usearch)."""
130130
raise NotImplementedError("Dynamic deletion is not supported in Usearch.")
131131

132-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
133-
"""Threshold the backend and return filtered keys."""
134-
return [
135-
np.array(keys_row)[np.array(distances_row, dtype=np.float32) < threshold]
136-
for keys_row, distances_row in self.query(vectors, 100)
137-
]
132+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
133+
"""Query vectors within a distance threshold and return keys and distances."""
134+
out: QueryResult = []
135+
for keys_row, distances_row in self.query(vectors, max_k):
136+
keys_row = np.array(keys_row)
137+
distances_row = np.array(distances_row, dtype=np.float32)
138+
mask = distances_row < threshold
139+
out.append((keys_row[mask], distances_row[mask]))
140+
return out

vicinity/backends/voyager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def delete(self, indices: list[int]) -> None:
9494
"""Delete vectors from the backend."""
9595
raise NotImplementedError("Deletion is not supported in Voyager backend.")
9696

97-
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
97+
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
9898
"""Threshold the backend."""
99-
out: list[npt.NDArray] = []
100-
for x, y in self.query(vectors, len(self)):
101-
out.append(x[y < threshold])
99+
out: list[tuple[npt.NDArray, npt.NDArray]] = []
100+
for x, y in self.query(vectors, max_k):
101+
mask = y < threshold
102+
out.append((x[mask], y[mask]))
102103

103104
return out
104105

vicinity/vicinity.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from vicinity import Metric
1616
from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class
17-
from vicinity.datatypes import Backend, PathLike
17+
from vicinity.datatypes import Backend, PathLike, QueryResult
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -114,7 +114,7 @@ def query(
114114
self,
115115
vectors: npt.NDArray,
116116
k: int = 10,
117-
) -> list[list[tuple[str, float]]]:
117+
) -> list[QueryResult]:
118118
"""
119119
Find the nearest neighbors to some arbitrary vector.
120120
@@ -140,22 +140,26 @@ def query_threshold(
140140
self,
141141
vectors: npt.NDArray,
142142
threshold: float = 0.5,
143-
) -> list[list[str]]:
143+
max_k: int = 100,
144+
) -> list[QueryResult]:
144145
"""
145-
Find the nearest neighbors to some arbitrary vector with some threshold.
146+
Find the nearest neighbors to some arbitrary vector with some threshold. Note: the output is not sorted.
146147
147148
:param vectors: The vectors to find the most similar vectors to.
148149
:param threshold: The threshold to use.
150+
:param max_k: The maximum number of neighbors to consider for the threshold query.
149151
150-
:return: For each item in the input, all items above the threshold are returned.
152+
:return: For each item in the input, the items above the threshold are returned in the form of
153+
(NAME, SIMILARITY) tuples.
151154
"""
152-
vectors = np.array(vectors)
155+
vectors = np.asarray(vectors)
153156
if np.ndim(vectors) == 1:
154157
vectors = vectors[None, :]
155158

156159
out = []
157-
for indexes in self.backend.threshold(vectors, threshold):
158-
out.append([self.items[idx] for idx in indexes])
160+
for indices, distances in self.backend.threshold(vectors, threshold, max_k=max_k):
161+
distances.clip(min=0, out=distances)
162+
out.append([(self.items[idx], dist) for idx, dist in zip(indices, distances)])
159163

160164
return out
161165

0 commit comments

Comments
 (0)