
Commit 80c1670

add and update tests
1 parent 8453185 commit 80c1670

2 files changed: +190 -1 lines changed

tests/test_ann_index.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
import io
import json
import sqlite3
from typing import List, Tuple

import numpy as np
import pytest
from matchms import Spectrum

from ms2query.database import ANNIndex
from ms2query.database.spectra_merging import ensure_merged_tables

# --- small helpers for array <-> BLOB used in tests (mirrors the production helpers) ---

def _ndarray_to_blob(arr: np.ndarray) -> bytes:
    with io.BytesIO() as f:
        np.save(f, arr, allow_pickle=False)
        return f.getvalue()

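
# (Sketch, not part of this commit.) The matching decoder, presumably provided by
# the production helpers mirrored above, would read a BLOB back like so:
def _blob_to_ndarray(blob: bytes) -> np.ndarray:
    with io.BytesIO(blob) as f:
        return np.load(f, allow_pickle=False)
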

@pytest.fixture()
def conn() -> sqlite3.Connection:
    # In-memory DB for tests
    return sqlite3.connect(":memory:")


@pytest.fixture()
def ann(conn) -> ANNIndex:
    # Instantiate with dummy model path; we’ll monkeypatch load_model.
    return ANNIndex(
        conn=conn,
        model_path="dummy_model.pt",
        faiss_metric="ip",
        faiss_factory=None,
        normalize_embeddings=True,
    )


def test_ensure_schema_creates_tables(ann: ANNIndex):
    """Schema should be created with all required columns."""
    ann.ensure_schema()
    cur = ann.conn.cursor()

    # Check both tables exist
    cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='merged_spectra';")
    assert cur.fetchone() is not None
    cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='merged_embeddings';")
    assert cur.fetchone() is not None

    # Check a few critical columns exist
    cur.execute("PRAGMA table_info('merged_spectra');")
    cols = {row[1] for row in cur.fetchall()}
    for required in ("merged_id", "comp_id", "precursor_mz", "mz", "intensities", "num_merged"):
        assert required in cols


def _insert_synthetic_merged_rows(conn: sqlite3.Connection) -> Tuple[int, int]:
    """
    Insert two tiny merged_spectra rows with minimal viable metadata.
    Returns their merged_ids (sqlite autoincrement).
    """
    cur = conn.cursor()
    ensure_merged_tables(conn)

    # Synthetic peaks
    mz1 = np.array([100.0, 150.0, 200.0], dtype=np.float64)
    it1 = np.array([0.2, 0.3, 0.5], dtype=np.float32)

    mz2 = np.array([101.0, 151.0, 201.0], dtype=np.float64)
    it2 = np.array([0.4, 0.1, 0.5], dtype=np.float32)

    base_cols = (
        "comp_id, ionmode, charge, precursor_mz, smiles, inchikey, inchi, name, "
        "instrument_type, adduct, collision_energy, num_merged, source_spec_ids, mz, intensities"
    )
    q = f"INSERT INTO merged_spectra ({base_cols}) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"

    cur.execute(
        q,
        (
            "C1", "positive", 1, 300.123, "C(CO)O", "AAAA-BBBB-CCCC", "InChI=1S/...", "Compound A",
            "QTOF", "[M+H]+", "NCE 20", 3, json.dumps([11, 12, 13]),
            sqlite3.Binary(_ndarray_to_blob(mz1)), sqlite3.Binary(_ndarray_to_blob(it1)),
        ),
    )
    id1 = cur.lastrowid

    cur.execute(
        q,
        (
            "C2", "positive", 1, 450.5, "CCN(CC)CC", "XXXX-YYYY-ZZZZ", "InChI=1S/...", "Compound B",
            "Orbitrap", "[M+H]+", "NCE 25", 2, json.dumps([21, 22]),
            sqlite3.Binary(_ndarray_to_blob(mz2)), sqlite3.Binary(_ndarray_to_blob(it2)),
        ),
    )
    id2 = cur.lastrowid

    conn.commit()
    return id1, id2

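
# (Sketch, not part of this commit.) Decoding the stored BLOBs back would allow a
# round-trip check against mz1/it1; table and column names follow the schema
# asserted in test_ensure_schema_creates_tables.
def _read_peaks(conn: sqlite3.Connection, merged_id: int) -> Tuple[np.ndarray, np.ndarray]:
    row = conn.execute(
        "SELECT mz, intensities FROM merged_spectra WHERE merged_id = ?;", (merged_id,)
    ).fetchone()
    return (np.load(io.BytesIO(row[0]), allow_pickle=False),
            np.load(io.BytesIO(row[1]), allow_pickle=False))
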

def test_compute_embeddings_inserts_rows(ann: ANNIndex, monkeypatch):
    """Embeddings should be computed and written; a rerun with only_missing inserts 0 new rows."""
    id1, id2 = _insert_synthetic_merged_rows(ann.conn)

    # Monkeypatch model loading and embedding to be deterministic & light.
    class _DummyModel:
        pass

    def fake_load_model(_path):
        return _DummyModel()

    def fake_compute_embedding_array(model, specs: List[Spectrum]) -> np.ndarray:
        # Simple deterministic embedding: [precursor_mz, charge, sum(intensities), n_peaks]
        out = []
        for s in specs:
            pmz = float(s.metadata["precursor_mz"])
            charge = float(s.metadata.get("charge") or 0)
            intens_sum = float(np.sum(s.peaks.intensities))
            n_peaks = float(len(s.peaks.mz))
            out.append([pmz, charge, intens_sum, n_peaks])
        return np.asarray(out, dtype=np.float32)

    monkeypatch.setattr("ms2query.database.ann_index.load_model", fake_load_model)
    monkeypatch.setattr("ms2query.database.ann_index.compute_embedding_array", fake_compute_embedding_array)

    inserted = ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=True)
    assert inserted == 2

    # Confirm the rows exist
    cur = ann.conn.cursor()
    cur.execute("SELECT COUNT(1) FROM merged_embeddings;")
    assert cur.fetchone()[0] == 2

    # Re-run with only_missing: should insert 0
    inserted2 = ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=True)
    assert inserted2 == 0


def test_build_index_and_query(ann: ANNIndex, monkeypatch):
    """Build an index from stored embeddings and query it; the top-1 hit should be the intended nearest row."""
    id1, id2 = _insert_synthetic_merged_rows(ann.conn)

    # Same monkeypatching as in the previous test (model + embeddings)
    class _DummyModel:
        pass

    def fake_load_model(_path):
        return _DummyModel()

    def fake_compute_embedding_array(model, specs: List[Spectrum]) -> np.ndarray:
        # Embedding consistent with test_compute_embeddings_inserts_rows
        out = []
        for s in specs:
            pmz = float(s.metadata["precursor_mz"])
            charge = float(s.metadata.get("charge") or 0)
            intens_sum = float(np.sum(s.peaks.intensities))
            n_peaks = float(len(s.peaks.mz))
            out.append([pmz, charge, intens_sum, n_peaks])
        return np.asarray(out, dtype=np.float32)

    monkeypatch.setattr("ms2query.database.ann_index.load_model", fake_load_model)
    monkeypatch.setattr("ms2query.database.ann_index.compute_embedding_array", fake_compute_embedding_array)

    # Compute embeddings
    ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=False)

    # Build FAISS index
    index = ann.build_index()
    assert index.ntotal == 2

    # Prepare a query spectrum that should be closest to the 2nd row: the fake
    # embedding is dominated by precursor_mz, so matching precursor_mz=450.5
    # pulls the query toward that row.
    q_mz = np.array([100.0, 200.0], dtype=np.float32)
    q_it = np.array([0.5, 0.5], dtype=np.float32)
    q_spec = Spectrum(mz=q_mz, intensities=q_it, metadata={"precursor_mz": 450.5, "ionmode": "positive", "charge": 1})

    # Query
    results = ann.query(q_spec, k=2, include_metadata=True, as_dataframe=True)
    assert isinstance(results, list) and len(results) == 1
    df = results[0]
    assert {"rank", "merged_id", "score", "distance", "comp_id", "name"}.issubset(df.columns)

    # Top-1 hit should be the row with precursor_mz=450.5 (id2)
    top1 = df.iloc[0]
    assert int(top1["merged_id"]) == id2
    assert top1["comp_id"] == "C2"
    assert top1["name"] == "Compound B"

    # Scores should be non-increasing by rank
    assert np.all(df["score"].values[:-1] >= df["score"].values[1:])
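
Both tests duplicate the load_model / compute_embedding_array monkeypatching. A shared pytest fixture could factor that out; the sketch below is illustrative only (the fixture name patched_embeddings is not part of this commit):

@pytest.fixture()
def patched_embeddings(monkeypatch):
    # Same dummy model and deterministic
    # [precursor_mz, charge, sum(intensities), n_peaks] embedding as above.
    class _DummyModel:
        pass

    def fake_compute_embedding_array(model, specs):
        return np.asarray(
            [
                [
                    float(s.metadata["precursor_mz"]),
                    float(s.metadata.get("charge") or 0),
                    float(np.sum(s.peaks.intensities)),
                    float(len(s.peaks.mz)),
                ]
                for s in specs
            ],
            dtype=np.float32,
        )

    monkeypatch.setattr("ms2query.database.ann_index.load_model", lambda _path: _DummyModel())
    monkeypatch.setattr("ms2query.database.ann_index.compute_embedding_array", fake_compute_embedding_array)

Each test would then simply accept patched_embeddings as a parameter instead of repeating the setup.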

tests/test_compound_database.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import pytest

 # >>> adjust to your package/module path
-from ms2query.compound_database import (
+from ms2query.database.compound_database import (
     CompoundDatabase,
     SpecToCompoundMap,
     map_from_spectraldb_metadata,
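
This change points the import at the relocated module. If out-of-tree code still imports from the old ms2query.compound_database path, a thin re-export shim at the old location (a sketch, not part of this commit) would keep it working while flagging the move:

# ms2query/compound_database.py (hypothetical compatibility shim)
import warnings

from ms2query.database.compound_database import (  # noqa: F401
    CompoundDatabase,
    SpecToCompoundMap,
    map_from_spectraldb_metadata,
)

warnings.warn(
    "ms2query.compound_database moved to ms2query.database.compound_database",
    DeprecationWarning,
    stacklevel=2,
)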
