@@ -749,19 +749,48 @@ def neighbors(self, target_emb, k=None):
749749 """
750750 Retrieve k nearest neighbors by cosine distance.
751751
752- :param target_emb: Query embedding vector.
753- :type target_emb: Sequence[float]
752+ Accepts either a single embedding vector or a list/array of embedding vectors.
753+ Maintains backward compatibility: single embeddings return a single result list,
754+ batch embeddings return a list of result lists.
755+
756+ :param target_emb: Query embedding vector (1D) or list/array of vectors (2D).
757+ If 1D, returns a single result list.
758+ If 2D, returns a list of result lists.
759+ :type target_emb: Union[Sequence[float], np.ndarray, List]
754760 :param k: Override number of neighbors (defaults to self.k).
755761 :type k: int
756- :return: List of (item_id, distance) pairs ordered by proximity.
757- :rtype: List[Tuple[Any, float]]
762+ :return: If target_emb is 1D: List of (item_id, distance) tuples (backward compatible).
763+ If target_emb is 2D: List of lists of (item_id, distance) tuples.
764+ :rtype: Union[List[Tuple[Any, float]], List[List[Tuple[Any, float]]]]
758765 """
759766 k = k or self .k
760- dists , indexes = self .model .kneighbors ([target_emb ],
761- min (k , len (self .model .ix2id )),
767+ k_neighbors = min (k , len (self .model .ix2id ))
768+
769+ # Detect if input is single embedding (1D) or batch (2D)
770+ is_single = False
771+ if isinstance (target_emb , np .ndarray ):
772+ is_single = target_emb .ndim == 1
773+ query_embeddings = target_emb if target_emb .ndim == 2 else [target_emb ]
774+ else :
775+ # Convert to numpy array to check dimensionality
776+ target_emb_array = np .array (target_emb )
777+ is_single = target_emb_array .ndim == 1
778+ query_embeddings = target_emb_array if target_emb_array .ndim == 2 else [target_emb_array ]
779+
780+ # Query all embeddings at once
781+ dists , indexes = self .model .kneighbors (query_embeddings ,
782+ k_neighbors ,
762783 return_distance = True )
763- dists , indexes = dists [0 ], indexes [0 ]
764- return [(self .model .ix2id [indexes [ix ]], dist ) for ix , dist in enumerate (dists )]
784+
785+ # Convert results to list of neighbor lists
786+ results = []
787+ for i in range (len (dists )):
788+ neighbors_list = [(self .model .ix2id [indexes [i ][j ]], dists [i ][j ])
789+ for j in range (len (dists [i ]))]
790+ results .append (neighbors_list )
791+
792+ # If input was single, return single result (backward compatible)
793+ return results [0 ] if is_single else results
765794
766795 __call__ = neighbors
767796
0 commit comments