Commit b823573

add ANN index (faiss)
1 parent a76ca07 commit b823573

1 file changed (+366, −0)

ms2query/database/ann_index.py

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple, Union
import sqlite3
import json
import numpy as np
import pandas as pd
from matchms import Spectrum
from ms2deepscore.models import load_model
from ms2deepscore import compute_embedding_array
import faiss

from .spectra_merging import ensure_merged_tables  # schema with precursor_mz + metadata fields
from .database_utils import blob_to_ndarray, ndarray_to_blob

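# Note (assumption): `blob_to_ndarray` / `ndarray_to_blob` are this package's
# helpers for round-tripping numpy arrays through SQLite BLOBs; the code below
# relies on them preserving 1-D float vectors losslessly.
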
@dataclass
class ANNIndex:
    """
    End-to-end manager for MS2DeepScore ANN indices backed by SQLite.

    Responsibilities
    ----------------
    - Compute and persist MS2DeepScore embeddings for rows in `merged_spectra`
      into `merged_embeddings`.
    - Build a FAISS index whose IDs are `merged_id`.
    - Query the index and resolve results to rich metadata (and optionally peaks/sources).

    Attributes
    ----------
    conn : sqlite3.Connection
        Connection to the same SQLite DB that contains `merged_spectra` and `merged_embeddings`.
    model_path : str
        Path to the MS2DeepScore .pt model.
    faiss_metric : str
        Metric the index is built with ("ip" or "l2"). If "ip" and
        `normalize_embeddings=True`, query embeddings are L2-normalized
        (cosine-like behavior).
    faiss_factory : Optional[str]
        FAISS factory string (e.g. "IVF4096,Flat", "HNSW32"; see the FAISS
        documentation). If None, a flat index matching `faiss_metric` is used.
    normalize_embeddings : bool
        Normalize stored vectors so inner-product search behaves cosine-like.
        Default is True.
    _model : Any
        Lazily loaded MS2DeepScore model.
    _index : Optional[faiss.Index]
        The FAISS ANN index.

    Notes
    -----
    - This class is intentionally state-light: it holds a `sqlite3.Connection`,
      a FAISS index, and a lazily-loaded MS2DeepScore model.
    """
    conn: sqlite3.Connection
    model_path: str
    faiss_metric: str = "ip"
    faiss_factory: Optional[str] = None
    normalize_embeddings: bool = True
    _model: Any = field(default=None, init=False, repr=False)
    _index: Optional[faiss.Index] = field(default=None, init=False, repr=False)

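    # Minimal usage sketch. The DB and model file names ("library.sqlite",
    # "ms2deepscore_model.pt") are illustrative assumptions, not shipped files:
    #
    #     conn = sqlite3.connect("library.sqlite")
    #     ann = ANNIndex(conn, model_path="ms2deepscore_model.pt")
    #     ann.compute_embeddings_to_sqlite()
    #     ann.build_index(index_path="merged.faiss")
    #     hits = ann.query(my_spectrum, k=5)
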
    # ---------- lifecycle ----------
    def ensure_schema(self) -> None:
        """Create (if absent) the merged tables with the richer schema."""
        ensure_merged_tables(self.conn)

    def load_model(self):
        """Lazy-load MS2DeepScore model."""
        if self._model is None:
            self._model = load_model(self.model_path)
        return self._model

    # ---------- step 2a: embeddings ----------
    def compute_embeddings_to_sqlite(
        self,
        *,
        batch_rows: int = 1024,
        only_missing: bool = True,
        commit_every: int = 0,
    ) -> int:
        """
        Compute MS2DeepScore embeddings for `merged_spectra` and write to `merged_embeddings`.

        Parameters
        ----------
        batch_rows : int
            Number of DB rows to embed at once.
        only_missing : bool
            If True, only embed rows not already in `merged_embeddings`.
        commit_every : int
            Forces a commit after this many merged_ids. 0 means commit only per batch.

        Returns
        -------
        int
            Number of embeddings inserted/updated.
        """
        self.ensure_schema()
        cur = self.conn.cursor()
        cur.execute("PRAGMA foreign_keys = ON;")

        if only_missing:
            query = """
                SELECT s.merged_id, s.mz, s.intensities, s.precursor_mz, s.ionmode, s.charge
                FROM merged_spectra s
                LEFT JOIN merged_embeddings e ON s.merged_id = e.merged_id
                WHERE e.merged_id IS NULL
                ORDER BY s.merged_id ASC;
            """
        else:
            query = """
                SELECT merged_id, mz, intensities, precursor_mz, ionmode, charge
                FROM merged_spectra
                ORDER BY merged_id ASC;
            """
        cur.execute(query)

        model = self.load_model()

        inserted = 0
        buf: List[Tuple[int, bytes, bytes, float, str, Optional[int]]] = []
        done_since_commit = 0

        def flush(batch: List[Tuple]) -> int:
            if not batch:
                return 0
            specs: List[Spectrum] = []
            mids: List[int] = []
            for mid, mz_blob, it_blob, prec_mz, ionmode, charge in batch:
                mz = blob_to_ndarray(mz_blob).astype(np.float32, copy=False)
                it = blob_to_ndarray(it_blob).astype(np.float32, copy=False)
                specs.append(Spectrum(mz=mz, intensities=it, metadata={
                    "precursor_mz": float(prec_mz),
                    "ionmode": ionmode,
                    "charge": charge,
                }))
                mids.append(mid)

            emb = compute_embedding_array(model, specs).astype(np.float32, copy=False)
            q = "INSERT OR REPLACE INTO merged_embeddings (merged_id, embedding) VALUES (?, ?);"
            with self.conn:
                for mid, vec in zip(mids, emb):
                    self.conn.execute(q, (mid, sqlite3.Binary(ndarray_to_blob(vec))))
            return len(batch)

        while True:
            rows = cur.fetchmany(batch_rows)
            if not rows:
                break
            buf.extend(rows)
            # process buffer in batch_rows-sized chunks
            while len(buf) >= batch_rows:
                inserted += flush(buf[:batch_rows])
                buf = buf[batch_rows:]
                done_since_commit += batch_rows
                if commit_every and done_since_commit >= commit_every:
                    self.conn.commit()
                    done_since_commit = 0

        inserted += flush(buf)
        return inserted

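    # Sketch: incremental refresh after new merged spectra land in the DB.
    # `only_missing=True` (the default) makes re-running safe; the numbers
    # below are illustrative:
    #
    #     n_new = ann.compute_embeddings_to_sqlite(batch_rows=512, commit_every=5_000)
    #     print(f"embedded {n_new} new merged spectra")
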
    # ---------- step 2b: build faiss ----------
    def build_index(self, *, index_path: Optional[str] = None) -> faiss.Index:
        """
        Build a FAISS index from `merged_embeddings`. Uses IndexIDMap2 (ids=merged_id).
        Optionally saves to disk.

        Returns
        -------
        faiss.Index
        """
        cur = self.conn.cursor()
        cur.execute("SELECT merged_id, embedding FROM merged_embeddings ORDER BY merged_id ASC;")

        first = cur.fetchone()
        if not first:
            raise ValueError("No embeddings present in 'merged_embeddings'.")
        first_id, first_blob = first
        first_vec = blob_to_ndarray(first_blob).astype(np.float32, copy=False)
        d = int(first_vec.shape[-1])

        metric = faiss.METRIC_INNER_PRODUCT if self.faiss_metric.lower() == "ip" else faiss.METRIC_L2
        base = (faiss.index_factory(d, self.faiss_factory, metric)
                if self.faiss_factory else
                (faiss.IndexFlatIP(d) if metric == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(d)))
        if not base.is_trained:
            # Factories such as "IVF4096,Flat" need a training pass before vectors
            # can be added; this streaming builder only supports train-free indices.
            raise ValueError(
                f"faiss_factory={self.faiss_factory!r} requires training; "
                "use a train-free factory such as 'Flat' or 'HNSW32'."
            )
        index = faiss.IndexIDMap2(base)

        # add first vector
        X = first_vec[None, :]
        if self.faiss_metric.lower() == "ip" and self.normalize_embeddings:
            faiss.normalize_L2(X)
        index.add_with_ids(X, np.array([first_id], dtype=np.int64))

        # stream the rest
        BATCH = 8192
        ids_buf, vec_buf = [], []
        while True:
            rows = cur.fetchmany(BATCH)
            if not rows:
                break
            for mid, blob in rows:
                ids_buf.append(mid)
                vec_buf.append(blob_to_ndarray(blob))
            if ids_buf:
                Xb = np.vstack(vec_buf).astype(np.float32, copy=False)
                if self.faiss_metric.lower() == "ip" and self.normalize_embeddings:
                    faiss.normalize_L2(Xb)
                index.add_with_ids(Xb, np.asarray(ids_buf, dtype=np.int64))
                ids_buf.clear()
                vec_buf.clear()

        if index_path:
            faiss.write_index(index, index_path)

        self._index = index
        return index

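    # Why normalize for "ip": for unit vectors the inner product equals the
    # cosine of the angle between them, <u/||u||, v/||v||> = cos(theta), so an
    # inner-product index over normalized vectors ranks by cosine similarity.
    # Quick sanity check (illustrative, not part of the pipeline):
    #
    #     X = np.random.rand(4, 8).astype(np.float32)
    #     faiss.normalize_L2(X)                      # rows now have unit L2 norm
    #     assert np.allclose((X * X).sum(axis=1), 1.0, atol=1e-5)
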
    def load_index(self, index_path: str) -> faiss.Index:
        """Load an existing FAISS index from disk."""
        self._index = faiss.read_index(index_path)
        return self._index

    @property
    def index(self) -> faiss.Index:
        if self._index is None:
            raise RuntimeError("FAISS index not built/loaded. Call build_index() or load_index().")
        return self._index

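    # Sketch: reuse a previously built index in a later session (the path is
    # illustrative):
    #
    #     ann = ANNIndex(conn, model_path="ms2deepscore_model.pt")
    #     ann.load_index("merged.faiss")
    #     dfs = ann.query(spectra, k=10)
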
    # ---------- querying ----------
    def _fetch_merged_rows(
        self,
        ids: List[int],
        *,
        include_peaks: bool = False,
    ) -> Dict[int, Dict[str, Any]]:
        """Bulk-fetch rows from merged_spectra for a list of merged_id values."""
        if not ids:
            return {}
        out: Dict[int, Dict[str, Any]] = {}
        # Query in slices so each statement stays below SQLite's bound-parameter
        # limit (999 on builds older than 3.32).
        CHUNK = 900
        q_base = (
            "SELECT merged_id, comp_id, ionmode, charge, precursor_mz, "
            "smiles, inchikey, inchi, name, instrument_type, adduct, collision_energy, "
            "num_merged, source_spec_ids, mz, intensities "
            "FROM merged_spectra WHERE merged_id IN ({ph});"
        )
        for i in range(0, len(ids), CHUNK):
            chunk = ids[i:i + CHUNK]
            q = q_base.format(ph=",".join("?" for _ in chunk))
            rows = self.conn.execute(q, chunk).fetchall()
            for r in rows:
                (mid, comp_id, ionmode, charge, precursor_mz, smiles, inchikey, inchi, name,
                 instrument_type, adduct, collision_energy, num_merged, source_spec_ids,
                 mz_blob, intens_blob) = r
                row = {
                    "merged_id": mid, "comp_id": comp_id, "ionmode": ionmode, "charge": charge,
                    "precursor_mz": precursor_mz, "smiles": smiles, "inchikey": inchikey,
                    "inchi": inchi, "name": name, "instrument_type": instrument_type,
                    "adduct": adduct, "collision_energy": collision_energy, "num_merged": num_merged,
                    "source_spec_ids": json.loads(source_spec_ids) if source_spec_ids else [],
                }
                if include_peaks:
                    row["mz"] = blob_to_ndarray(mz_blob).astype(np.float32, copy=False)
                    row["intensities"] = blob_to_ndarray(intens_blob).astype(np.float32, copy=False)
                out[mid] = row
        return out

    def query(
        self,
        queries: Union[Spectrum, List[Spectrum]],
        *,
        k: int = 10,
        include_metadata: bool = True,
        include_peaks: bool = False,
        include_sources: bool = False,
        sdb: Optional[Any] = None,
        as_dataframe: bool = True,
    ) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
        """
        Embed one or more query spectra, search FAISS, and resolve hits via SQLite.

        Returns per-query results (DataFrame by default) with columns:
        ['rank', 'merged_id', 'score', 'distance', 'comp_id', 'name', 'ionmode', 'charge',
         'precursor_mz', 'adduct', 'collision_energy', 'num_merged', 'source_spec_ids', ...]
        (plus 'merged_spectrum' / 'source_spectra' if requested).

        Notes
        -----
        - For `faiss_metric="ip"` with normalized vectors, 'score' is cosine-like.
        - For `faiss_metric="l2"`, 'score' is `-distance` (so higher is better).
        """
        if isinstance(queries, Spectrum):
            queries = [queries]

        model = self.load_model()
        Q = compute_embedding_array(model, queries).astype(np.float32, copy=False)

        if self.faiss_metric.lower() == "ip" and self.normalize_embeddings:
            faiss.normalize_L2(Q)

        distances, ids = self.index.search(Q, k)  # both arrays are (nq, k)
        nq = distances.shape[0]
        flat_ids = [int(x) for x in ids.flatten().tolist() if x != -1]

        rows_by_id: Dict[int, Dict[str, Any]] = {}
        if include_metadata or include_peaks or include_sources:
            rows_by_id = self._fetch_merged_rows(flat_ids, include_peaks=include_peaks)

        results_all: List[List[Dict[str, Any]]] = []
        for qi in range(nq):
            one: List[Dict[str, Any]] = []
            for rk in range(k):
                mid = int(ids[qi, rk])
                if mid == -1:
                    continue
                dist = float(distances[qi, rk])
                score = dist if self.faiss_metric.lower() == "ip" else -dist
                item: Dict[str, Any] = {"rank": rk + 1, "merged_id": mid,
                                        "score": score, "distance": dist}

                if include_metadata or include_peaks or include_sources:
                    row = rows_by_id.get(mid)
                    if row:
                        item.update({key: val for key, val in row.items()
                                     if key not in ("mz", "intensities")})
                        if include_peaks:
                            mz = row.get("mz")
                            it = row.get("intensities")
                            if mz is not None and it is not None:
                                item["merged_spectrum"] = Spectrum(mz=mz, intensities=it, metadata={
                                    "precursor_mz": row["precursor_mz"],
                                    "ionmode": row["ionmode"],
                                    "charge": row["charge"],
                                    "comp_id": row["comp_id"],
                                    "num_merged": row["num_merged"],
                                    "source_spec_ids": row["source_spec_ids"],
                                    "name": row.get("name"),
                                    "adduct": row.get("adduct"),
                                })
                            else:
                                item["merged_spectrum"] = None
                        if include_sources:
                            if sdb is None:
                                raise ValueError("include_sources=True requires sdb.")
                            item["source_spectra"] = sdb.get_spectra_by_ids(row.get("source_spec_ids", []))
                    else:
                        # keep columns consistent
                        if include_peaks:
                            item["merged_spectrum"] = None
                        if include_sources:
                            item["source_spectra"] = []

                one.append(item)
            results_all.append(one)

        if not as_dataframe:
            return results_all

        # convert to tidy DataFrames with a nice column order
        dfs: List[pd.DataFrame] = []
        base_order = [
            "rank", "merged_id", "score", "distance", "comp_id", "name",
            "ionmode", "charge", "precursor_mz", "adduct", "collision_energy",
            "num_merged", "source_spec_ids",
        ]
        for one in results_all:
            df = pd.DataFrame(one)
            # move known columns up front
            cols_front = [c for c in base_order if c in df.columns]
            cols_rest = [c for c in df.columns if c not in cols_front]
            dfs.append(df[cols_front + cols_rest])
        return dfs
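
# End-to-end sketch (hedged: file names and the `sdb` source-spectra helper are
# illustrative assumptions, not part of this module):
#
#     import sqlite3
#     import numpy as np
#     from matchms import Spectrum
#     from ms2query.database.ann_index import ANNIndex
#
#     conn = sqlite3.connect("merged_library.sqlite")
#     ann = ANNIndex(conn, model_path="ms2deepscore_model.pt",
#                    faiss_metric="ip", normalize_embeddings=True)
#     ann.compute_embeddings_to_sqlite()
#     ann.build_index(index_path="merged.faiss")
#
#     query_spec = Spectrum(mz=np.array([100.0, 150.0]),
#                           intensities=np.array([0.5, 1.0]),
#                           metadata={"precursor_mz": 300.0, "ionmode": "positive"})
#     dfs = ann.query(query_spec, k=10)
#     print(dfs[0][["rank", "merged_id", "score", "name"]])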
