1010
1111from .spectra_merging import ensure_merged_tables # schema with precursor_mz + metadata fields
1212from .database_utils import blob_to_ndarray , ndarray_to_blob
13+ from ms2query .spectral_processing import normalize_spectrum_sum
14+ from ms2query .spectral_processing import normalize_spectrum_sum
1315
1416
1517@dataclass
@@ -64,6 +66,7 @@ def load_model(self):
6466 """Lazy-load MS2DeepScore model."""
6567 if self ._model is None :
6668 self ._model = load_model (self .model_path )
69+ self ._model .eval ()
6770 return self ._model
6871
6972 # ---------- step 2a: embeddings ----------
@@ -289,6 +292,9 @@ def query(
289292 if isinstance (queries , Spectrum ):
290293 queries = [queries ]
291294
295+ # Assure same processing as for index building (here only minimal first variant: normalize to sum=1)
296+ queries = [normalize_spectrum_sum (q ) for q in queries ]
297+
292298 model = self .load_model ()
293299 Q = compute_embedding_array (model , queries ).astype (np .float32 , copy = False )
294300
@@ -311,7 +317,9 @@ def query(
311317 if mid == - 1 :
312318 continue
313319 dist = float (distances [qi , rk ])
314- score = dist if self .faiss_metric .lower () == "ip" else - dist
320+ score = dist if self .faiss_metric .lower () == "ip" else 1 - dist
321+ dist = 1 - dist if self .faiss_metric .lower () == "ip" else dist
322+
315323 item : Dict [str , Any ] = {"rank" : rk + 1 , "merged_id" : mid , "score" : score , "distance" : dist }
316324
317325 if include_metadata or include_peaks or include_sources :
@@ -363,3 +371,47 @@ def query(
363371 df = df [cols_front + cols_rest ]
364372 dfs .append (df )
365373 return dfs
374+
def get_embeddings(
    self,
    ids: Optional[List[int]] = None,
    *,
    normalized: Optional[bool] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fetch embeddings by merged_id (or all of them if ids=None).

    Parameters
    ----------
    ids:
        merged_ids to look up; None selects every row of ``merged_embeddings``.
        Duplicates are collapsed and ids without a stored embedding are
        skipped silently, so the result can be shorter than the input.
        Numpy integer scalars are accepted.
    normalized:
        True  -> L2-normalize every row before returning (independent of the metric).
        False -> return the vectors as read from the database (cast to float32).
        None  -> normalize only if ``faiss_metric`` is "ip" and
                 ``normalize_embeddings`` is set on this instance.

    Returns
    -------
    (ids_array[int64] of shape (n,), embeddings[float32] of shape (n, d)),
    both ordered by ascending merged_id. If no row matches, the arrays have
    shapes (0,) and (0, 0).
    """
    # SQLite caps the number of bound "?" parameters per statement (999 in
    # older builds), so long id lists are looked up in chunks below that cap.
    max_params = 900
    base_sql = "SELECT merged_id, embedding FROM merged_embeddings"

    if ids is None:
        statements = [(f"{base_sql} ORDER BY merged_id ASC;", [])]
    else:
        # sqlite3 cannot bind numpy integer scalars -> coerce to plain int.
        # Sorting and de-duplicating first keeps the concatenated chunk
        # results globally ascending with one lookup per id, which matches
        # what a single "IN (...) ORDER BY merged_id" query returns.
        wanted = sorted({int(i) for i in ids})
        statements = []
        for start in range(0, len(wanted), max_params):
            chunk = wanted[start:start + max_params]
            placeholders = ",".join("?" * len(chunk))
            statements.append(
                (f"{base_sql} WHERE merged_id IN ({placeholders}) ORDER BY merged_id ASC;", chunk)
            )

    mids = []
    vecs = []
    cur = self.conn.cursor()
    try:
        for sql, params in statements:
            cur.execute(sql, params)
            for mid, blob in cur:
                mids.append(int(mid))
                vecs.append(blob_to_ndarray(blob).astype(np.float32, copy=False))
    finally:
        cur.close()

    if not vecs:
        return np.empty((0,), dtype=np.int64), np.empty((0, 0), dtype=np.float32)

    # vstack allocates a fresh matrix, so the in-place normalization below
    # never touches the arrays decoded from the blobs.
    X = np.ascontiguousarray(np.vstack(vecs), dtype=np.float32)
    if normalized is None:
        normalized = (self.faiss_metric.lower() == "ip" and self.normalize_embeddings)
    if normalized:
        faiss.normalize_L2(X)
    return np.asarray(mids, dtype=np.int64), X
407+
def get_embedding_for_id(
    self,
    merged_id: int,
    *,
    normalized: Optional[bool] = None,
) -> Optional[np.ndarray]:
    """
    Fetch the embedding of a single merged_id.

    Parameters
    ----------
    merged_id:
        Id to look up. Numpy integer scalars (e.g. ids taken from search
        results) are accepted and converted to a plain int.
    normalized:
        Forwarded unchanged to ``get_embeddings``.

    Returns
    -------
    The first row of the embedding matrix returned by ``get_embeddings``,
    or None if that matrix has no rows.
    """
    # Plain int: the sqlite3 driver cannot bind numpy integer scalars.
    _, X = self.get_embeddings([int(merged_id)], normalized=normalized)
    if X.shape[0] == 0:
        return None
    return X[0]
0 commit comments