@@ -48,23 +48,24 @@ class PredParams(pecos.BaseParams):
4848 """Prediction Parameters of PairwiseANN class
4949
5050 Attributes:
51- topk (int): maximum number of candidates (sorted by distances, nearest first) return by the searcher per query
51+ batch_size (int): maximum number of (input, label) pairs te be inference on for the Searchers
52+ only_topk (int): maximum number of candidates (sorted by distances, nearest first) return by kNN
5253 """
5354
54- topk : int = 10
55+ batch_size : int = 1024
56+ only_topk : int = 10
5557
5658 class Searchers (object ):
57- def __init__ (self , model , max_batch_size = 256 , max_only_topk = 10 , num_searcher = 1 ):
59+ def __init__ (self , model , pred_params , num_searcher = 1 ):
5860 self .searchers_ptr = model .fn_dict ["searchers_create" ](
5961 model .model_ptr ,
6062 num_searcher ,
6163 )
6264 self .destruct_fn = model .fn_dict ["searchers_destruct" ]
6365
6466 # searchers also hold the memory of returned np.ndarray
65- self .max_batch_size = max_batch_size
66- self .max_only_topk = max_only_topk
67- max_nnz = max_batch_size * max_only_topk
67+ self .pred_params = pred_params
68+ max_nnz = pred_params .batch_size * pred_params .only_topk
6869 self .Imat = np .zeros (max_nnz , dtype = np .uint32 )
6970 self .Mmat = np .zeros (max_nnz , dtype = np .uint32 )
7071 self .Dmat = np .zeros (max_nnz , dtype = np .float32 )
@@ -214,11 +215,18 @@ def save(self, model_folder):
214215 c_model_dir = f"{ model_folder } /c_model"
215216 self .fn_dict ["save" ](self .model_ptr , c_char_p (c_model_dir .encode ("utf-8" )))
216217
217- def searchers_create (self , max_batch_size = 256 , max_only_topk = 10 , num_searcher = 1 ):
218+ def get_pred_params (self ):
219+ """Return a deep copy of prediction parameters
220+
221+ Returns:
222+ copied_pred_params (dict): Prediction parameters.
223+ """
224+ return copy .deepcopy (self .pred_params )
225+
226+ def searchers_create (self , pred_params = None , num_searcher = 1 ):
218227 """create searchers that pre-allocate intermediate variables (e.g., topk_queue)
219228 Args:
220- max_batch_size (int): the maximum batch size for the input/label pairs to be inference
221- max_only_topk (int): the maximum only topk for the kNN to return
229+ pred_params (Pairwise.PredParams, optional): instance of pecos.ann.pairwise.Pairwise.PredParams
222230 num_searcher: number of searcher for multi-thread inference
223231 Returns:
224232 PairwiseANN.Searchers: the pre-allocated PairwiseANN.Searchers (class object)
@@ -227,31 +235,25 @@ def searchers_create(self, max_batch_size=256, max_only_topk=10, num_searcher=1)
227235 raise ValueError ("self.model_ptr must exist before using searchers_create()" )
228236 if num_searcher <= 0 :
229237 raise ValueError ("num_searcher={} <= 0 is NOT valid" .format (num_searcher ))
230- return PairwiseANN .Searchers (self , max_batch_size , max_only_topk , num_searcher )
231-
232- def get_pred_params (self ):
233- """Return a deep copy of prediction parameters
234-
235- Returns:
236- copied_pred_params (dict): Prediction parameters.
237- """
238- return copy .deepcopy (self .pred_params )
238+ pred_params = self .get_pred_params () if pred_params is None else pred_params
239+ return PairwiseANN .Searchers (self , pred_params , num_searcher )
239240
240- def predict (self , input_feat , label_keys , searchers , pred_params = None , is_same_input = False ):
241+ def predict (self , input_feat , label_keys , searchers , is_same_input = False ):
241242 """predict with multi-thread. The searchers are required to be provided.
242243 Args:
243244 input_feat (numpy.array or smat.csr_matrix): input feature matrix (first key) to find kNN.
244- if is_same_input == False, the shape should be (batch_size, feat_dim)
245- if is_same_input == True, the shape should be (1, feat_dim)
246- label_keys (numpy.array): the label keys (second key) to find kNN. The shape should be (batch_size, ).
247- searchers (c_void_p): pointer to C/C++ vector<pecos::ann::PairwiseANN:Searcher>. Created by PairwiseANN.searchers_create().
248- pred_params (Pairwise.PredParams, optional): instance of pecos.ann.pairwise.Pairwise.PredParams.
245+ if is_same_input == False, the shape should be (batch_size, feat_dim).
246+ if is_same_input == True, the shape should be (1, feat_dim).
247+ label_keys (numpy.array): the label keys (second key) to find kNN.
248+ The shape should be (batch_size, ).
249+ searchers (c_void_p): pointer to C/C++ vector<pecos::ann::PairwiseANN:Searcher>.
250+ Created by PairwiseANN.searchers_create().
249251 is_same_input (bool): whether to use the same first row of X to do prediction.
250252 For real-time inference with same input query, set is_same_input = True.
251253 For batch prediction with varying input querues, set is_same_input = False.
252254 Returns:
253255 Imat (np.array): returned kNN input key indices. Shape of (batch_size, topk)
254- Mmat (np.array): returned kNN masking array. 1/0 mean value is or is not presented. Shape of (batch_size, topk)
256+ Mmat (np.array): returned kNN masking array. 1/0 mean value IS/ISNOT presented. Shape of (batch_size, topk)
255257 Dmat (np.array): returned kNN distance array. Shape of (batch_size, topk)
256258 Vmat (np.array): returned kNN value array. Shape of (batch_size, topk)
257259 """
@@ -273,19 +275,16 @@ def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_i
273275 if not is_same_input and input_feat_py .rows != label_keys .shape [0 ]:
274276 raise ValueError (f"input_feat_py.rows != label_keys.shape[0]" )
275277
276- batch_size = label_keys .shape [0 ]
277- pred_params = self .get_pred_params () if pred_params is None else pred_params
278- only_topk = pred_params .topk
279- cur_nnz = batch_size * only_topk
280- if batch_size > searchers .max_batch_size :
281- raise ValueError (f"cur_batch_size > searchers.max_batch_size" )
282- if only_topk > searchers .max_only_topk :
283- raise ValueError (f"cur_only_topk > searchers.max_only_topk" )
278+ cur_bsz = label_keys .shape [0 ]
279+ if cur_bsz > searchers .pred_params .batch_size :
280+ raise ValueError (f"cur_batch_size > searchers.batch_size!" )
281+ only_topk = searchers .pred_params .only_topk
282+ cur_nnz = cur_bsz * only_topk
284283
285284 searchers .reset (cur_nnz )
286285 self .fn_dict ["predict" ](
287286 searchers .ctypes (),
288- batch_size ,
287+ cur_bsz ,
289288 only_topk ,
290289 input_feat_py ,
291290 label_keys .ctypes .data_as (POINTER (c_uint32 )),
@@ -295,8 +294,8 @@ def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_i
295294 searchers .Vmat .ctypes .data_as (POINTER (c_float )),
296295 c_bool (is_same_input ),
297296 )
298- Imat = searchers .Imat [:cur_nnz ].reshape (batch_size , only_topk )
299- Mmat = searchers .Mmat [:cur_nnz ].reshape (batch_size , only_topk )
300- Dmat = searchers .Dmat [:cur_nnz ].reshape (batch_size , only_topk )
301- Vmat = searchers .Vmat [:cur_nnz ].reshape (batch_size , only_topk )
297+ Imat = searchers .Imat [:cur_nnz ].reshape (cur_bsz , only_topk )
298+ Mmat = searchers .Mmat [:cur_nnz ].reshape (cur_bsz , only_topk )
299+ Dmat = searchers .Dmat [:cur_nnz ].reshape (cur_bsz , only_topk )
300+ Vmat = searchers .Vmat [:cur_nnz ].reshape (cur_bsz , only_topk )
302301 return Imat , Mmat , Dmat , Vmat
0 commit comments