Skip to content

Commit 405c365

Browse files
committed
fix bugs and add method
1 parent 647aab9 commit 405c365

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

ms2query/database/ann_index.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from .spectra_merging import ensure_merged_tables # schema with precursor_mz + metadata fields
1212
from .database_utils import blob_to_ndarray, ndarray_to_blob
13+
from ms2query.spectral_processing import normalize_spectrum_sum
14+
from ms2query.spectral_processing import normalize_spectrum_sum
1315

1416

1517
@dataclass
@@ -64,6 +66,7 @@ def load_model(self):
6466
"""Lazy-load MS2DeepScore model."""
6567
if self._model is None:
6668
self._model = load_model(self.model_path)
69+
self._model.eval()
6770
return self._model
6871

6972
# ---------- step 2a: embeddings ----------
@@ -289,6 +292,9 @@ def query(
289292
if isinstance(queries, Spectrum):
290293
queries = [queries]
291294

295+
# Ensure the same processing as used for index building (for now only a minimal first variant: normalize to sum=1)
296+
queries = [normalize_spectrum_sum(q) for q in queries]
297+
292298
model = self.load_model()
293299
Q = compute_embedding_array(model, queries).astype(np.float32, copy=False)
294300

@@ -311,7 +317,9 @@ def query(
311317
if mid == -1:
312318
continue
313319
dist = float(distances[qi, rk])
314-
score = dist if self.faiss_metric.lower() == "ip" else -dist
320+
score = dist if self.faiss_metric.lower() == "ip" else 1 - dist
321+
dist = 1 - dist if self.faiss_metric.lower() == "ip" else dist
322+
315323
item: Dict[str, Any] = {"rank": rk+1, "merged_id": mid, "score": score, "distance": dist}
316324

317325
if include_metadata or include_peaks or include_sources:
@@ -363,3 +371,47 @@ def query(
363371
df = df[cols_front + cols_rest]
364372
dfs.append(df)
365373
return dfs
374+
375+
def get_embeddings(
    self,
    ids: Optional[List[int]] = None,
    *,
    normalized: Optional[bool] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fetch stored embeddings by merged_id (or all of them if ids is None).

    Parameters
    ----------
    ids:
        merged_ids to look up in the ``merged_embeddings`` table. ``None``
        selects every row. Duplicate ids are collapsed, and ids that have no
        stored embedding are skipped, so the result can be shorter than ``ids``.
    normalized:
        ``True``  -> L2-normalize the returned vectors (whatever the metric).
        ``False`` -> return the vectors exactly as stored.
        ``None``  -> normalize only when ``faiss_metric`` is "ip" and
                     ``normalize_embeddings`` is set on this instance.

    Returns
    -------
    (ids_array[int64], embeddings[float32 of shape (n, d)]), with rows ordered
    by ascending merged_id. If nothing matches, arrays of shape (0,) and
    (0, 0) are returned.
    """
    # Stay well below SQLite's bound-parameter limit (999 in older builds).
    # NOTE(review): assumes self.conn is a sqlite3 connection -- confirm.
    max_params = 900
    base_sql = "SELECT merged_id, embedding FROM merged_embeddings"

    if ids is None:
        batches = [(f"{base_sql} ORDER BY merged_id ASC;", ())]
    else:
        # int() makes NumPy integers bindable; sorting the de-duplicated ids
        # before chunking keeps the concatenated result globally ascending.
        wanted = sorted({int(i) for i in ids})
        batches = []
        for start in range(0, len(wanted), max_params):
            chunk = wanted[start:start + max_params]
            ph = ",".join("?" for _ in chunk)
            batches.append(
                (f"{base_sql} WHERE merged_id IN ({ph}) ORDER BY merged_id ASC;", chunk)
            )

    cur = self.conn.cursor()
    mids: List[int] = []
    vecs: List[np.ndarray] = []
    for sql, params in batches:
        cur.execute(sql, params)
        for mid, blob in cur:
            mids.append(int(mid))
            vecs.append(blob_to_ndarray(blob).astype(np.float32, copy=False))

    if not vecs:
        return np.empty((0,), dtype=np.int64), np.empty((0, 0), dtype=np.float32)

    X = np.vstack(vecs).astype(np.float32, copy=False)
    if normalized is None:
        normalized = (self.faiss_metric.lower() == "ip" and self.normalize_embeddings)
    if normalized:
        # Normalizes the rows of X in place.
        faiss.normalize_L2(X)
    return np.asarray(mids, dtype=np.int64), X
407+
408+
def get_embedding_for_id(
    self,
    merged_id: int,
    *,
    normalized: Optional[bool] = None,
) -> Optional[np.ndarray]:
    """
    Return the embedding stored for a single merged_id.

    Delegates to ``get_embeddings`` with a one-element id list and forwards
    ``normalized`` unchanged. Returns the first row of the resulting matrix,
    or ``None`` when the lookup yields no rows.
    """
    _, vectors = self.get_embeddings([merged_id], normalized=normalized)
    return vectors[0] if vectors.shape[0] != 0 else None

0 commit comments

Comments
 (0)