@@ -332,21 +332,57 @@ def replacement_assign(self, x, k, labels=None):
332332 self .assign_c (n , swig_ptr (x ), swig_ptr (labels ), k )
333333 return labels
334334
335- def replacement_train (self , x , numeric_type = faiss .Float32 ):
335+ def replacement_train (
336+ self , x , * , numeric_type = faiss .Float32 , xq_train = None
337+ ):
336338 """Trains the index on a representative set of vectors.
337339 The index must be trained before vectors can be added to it.
340+ Optionally accepts numeric_type to specify the type of
341+ input vectors.
342+ Optionally accepts a set of training query vectors for
343+ out-of-distribution training.
338344
339345 Parameters
340346 ----------
341347 x : array_like
342- Query vectors, shape (n, d) where d is appropriate for the index.
348+ Query vectors, shape (n, d) where d is appropriate
349+ for the index. `dtype` must be float32.
350+ numeric_type : type
351+ Numeric type of the input vectors.
352+ xq_train : array_like, optional
353+ Training query vectors, shape (n_train_q, d) where
354+ d is appropriate for the index.
343355 `dtype` must be float32.
344356 """
357+ # Prepare training data
345358 n , d = x .shape
346359 assert d == self .d
347360 x = np .ascontiguousarray (x , dtype = _numeric_to_str (numeric_type ))
361+
362+ # Prepare training queries if provided
363+ n_train_q , train_q = 0 , None
364+ if xq_train is not None :
365+ if numeric_type != faiss .Float32 :
366+ raise TypeError (
367+ "xq_train is only supported for numeric_type faiss.Float32"
368+ )
369+ n_train_q , d_train = xq_train .shape
370+ assert d_train == self .d
371+ train_q = swig_ptr (
372+ np .ascontiguousarray (
373+ xq_train ,
374+ dtype = _numeric_to_str (numeric_type ),
375+ )
376+ )
377+
378+ # Dispatch to train_c / train_ex
348379 if numeric_type == faiss .Float32 :
349- self .train_c (n , swig_ptr (x ))
380+ if train_q is not None :
381+ self .train_c (
382+ n , swig_ptr (x ), n_train_q , train_q
383+ )
384+ else :
385+ self .train_c (n , swig_ptr (x ))
350386 else :
351387 self .train_ex (n , swig_ptr (x ), numeric_type )
352388
0 commit comments