Skip to content

Commit 7a0b6cf

Browse files
Enhance KNNModel to support both single and batch embedding queries with backward compatibility
1 parent 966fa37 commit 7a0b6cf

1 file changed

Lines changed: 37 additions & 8 deletions

File tree

src/sdialog/util.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -749,19 +749,48 @@ def neighbors(self, target_emb, k=None):
749749
"""
750750
Retrieve k nearest neighbors by cosine distance.
751751
752-
:param target_emb: Query embedding vector.
753-
:type target_emb: Sequence[float]
752+
Accepts either a single embedding vector or a list/array of embedding vectors.
753+
Maintains backward compatibility: single embeddings return a single result list,
754+
batch embeddings return a list of result lists.
755+
756+
:param target_emb: Query embedding vector (1D) or list/array of vectors (2D).
757+
If 1D, returns a single result list.
758+
If 2D, returns a list of result lists.
759+
:type target_emb: Union[Sequence[float], np.ndarray, List]
754760
:param k: Override number of neighbors (defaults to self.k).
755761
:type k: int
756-
:return: List of (item_id, distance) pairs ordered by proximity.
757-
:rtype: List[Tuple[Any, float]]
762+
:return: If target_emb is 1D: List of (item_id, distance) tuples (backward compatible).
763+
If target_emb is 2D: List of lists of (item_id, distance) tuples.
764+
:rtype: Union[List[Tuple[Any, float]], List[List[Tuple[Any, float]]]]
758765
"""
759766
k = k or self.k
760-
dists, indexes = self.model.kneighbors([target_emb],
761-
min(k, len(self.model.ix2id)),
767+
k_neighbors = min(k, len(self.model.ix2id))
768+
769+
# Detect if input is single embedding (1D) or batch (2D)
770+
is_single = False
771+
if isinstance(target_emb, np.ndarray):
772+
is_single = target_emb.ndim == 1
773+
query_embeddings = target_emb if target_emb.ndim == 2 else [target_emb]
774+
else:
775+
# Convert to numpy array to check dimensionality
776+
target_emb_array = np.array(target_emb)
777+
is_single = target_emb_array.ndim == 1
778+
query_embeddings = target_emb_array if target_emb_array.ndim == 2 else [target_emb_array]
779+
780+
# Query all embeddings at once
781+
dists, indexes = self.model.kneighbors(query_embeddings,
782+
k_neighbors,
762783
return_distance=True)
763-
dists, indexes = dists[0], indexes[0]
764-
return [(self.model.ix2id[indexes[ix]], dist) for ix, dist in enumerate(dists)]
784+
785+
# Convert results to list of neighbor lists
786+
results = []
787+
for i in range(len(dists)):
788+
neighbors_list = [(self.model.ix2id[indexes[i][j]], dists[i][j])
789+
for j in range(len(dists[i]))]
790+
results.append(neighbors_list)
791+
792+
# If input was single, return single result (backward compatible)
793+
return results[0] if is_single else results
765794

766795
__call__ = neighbors
767796

0 commit comments

Comments
 (0)