Skip to content

Commit f83048d

Browse files
committed
larger rework of tests
1 parent c7c0a6a commit f83048d

File tree

1 file changed

+89
-60
lines changed

1 file changed

+89
-60
lines changed

tests/test_compound_database.py

Lines changed: 89 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
# tests/test_compounds_and_mapping.py
12
import sqlite3
23
from pathlib import Path
34
import numpy as np
45
import pandas as pd
56
import pytest
67

8+
# >>> adjust to your package/module path
79
from ms2query.compound_database import (
810
CompoundDatabase,
911
SpecToCompoundMap,
1012
map_from_spectraldb_metadata,
1113
get_unique_compounds_from_spectraldb,
12-
compute_fingerprints,
14+
compute_fingerprints, # returns List[Optional[(bits, counts)]]
1315
inchikey14_from_full,
1416
)
1517

@@ -23,13 +25,32 @@ def make_tmp_db(tmp_path: Path, name: str = "test.sqlite") -> str:
2325
p.unlink()
2426
return str(p)
2527

26-
# Some example InChIKeys
28+
# Example InChIKeys
2729
IK_FULL_1 = "BSYNRYMUTXBXSQ-UHFFFAOYSA-N" # glucose
28-
IK_FULL_2 = "BSYNRYMUTXBXSQ-UHFFFAOYSA-O" # same first14, different suffix (stereo/isotope)
30+
IK_FULL_2 = "BSYNRYMUTXBXSQ-UHFFFAOYSA-O" # same first14, different suffix
2931
IK_FULL_3 = "BQJCRHHNABKAKU-KBQPJGBKSA-N" # ethanol
3032
IK14_1 = "BSYNRYMUTXBXSQ"
3133
IK14_3 = "BQJCRHHNABKAKU"
3234

35+
# -------------------------
36+
# Utilities
37+
# -------------------------
38+
39+
def create_min_spectral_table(sqlite_path: str, rows):
40+
"""Create a minimal spectra table (spec_id, inchikey) and insert rows."""
41+
con = sqlite3.connect(sqlite_path)
42+
cur = con.cursor()
43+
cur.executescript("""
44+
PRAGMA journal_mode=WAL;
45+
CREATE TABLE IF NOT EXISTS spectra(
46+
spec_id INTEGER PRIMARY KEY AUTOINCREMENT,
47+
inchikey TEXT
48+
);
49+
""")
50+
cur.executemany("INSERT INTO spectra(inchikey) VALUES (?)", [(r,) for r in rows])
51+
con.commit()
52+
con.close()
53+
3354
# -------------------------
3455
# Tests: low-level utilities
3556
# -------------------------
@@ -40,21 +61,33 @@ def test_inchikey14():
4061
assert inchikey14_from_full("BQJCRHHNABKAKU-KBQPJGBKSA-N") == IK14_3
4162
assert inchikey14_from_full("SHORT") is None # too short
4263

43-
def test_compute_fingerprints_placeholder():
44-
fp = compute_fingerprints("C(CO)O", None)
45-
assert isinstance(fp, np.ndarray)
46-
assert fp.dtype == np.uint8
47-
np.testing.assert_array_equal(fp, np.array([0, 1, 0, 1], dtype=np.uint8))
64+
def test_compute_fingerprints_contract():
65+
# API now expects list input in either smiles=... or inchis=...
66+
smiles = ["CCO", "C1=CC=CC=C1", None] # last one will be ignored by our call below
67+
# Call only with valid smiles strings
68+
fps = compute_fingerprints(smiles=[s for s in smiles if s is not None],
69+
inchis=None, sparse=True, count=True, radius=9, progress_bar=False)
70+
assert isinstance(fps, list)
71+
assert len(fps) == 2
72+
for fp in fps:
73+
# Optional[Tuple[np.ndarray, np.ndarray]]
74+
assert fp is None or (isinstance(fp, tuple) and len(fp) == 2)
75+
if fp is not None:
76+
bits, counts = fp
77+
assert isinstance(bits, np.ndarray) and bits.dtype == np.uint32
78+
assert isinstance(counts, np.ndarray)
79+
# counts are usually integer-like (could be float if you later scale)
80+
assert counts.ndim == 1
4881

4982
# -------------------------
50-
# Tests: CompoundDatabase
83+
# Tests: CompoundDatabase (no FP at upsert, backfill later)
5184
# -------------------------
5285

53-
def test_compound_upsert_and_get(tmp_path):
86+
def test_compound_upsert_and_get_and_backfill(tmp_path):
5487
db_path = make_tmp_db(tmp_path)
5588
cdb = CompoundDatabase(db_path)
5689

57-
# Upsert a compound
90+
# Upsert (no fingerprints written at this step)
5891
cid = cdb.upsert_compound(
5992
smiles="C(CO)O",
6093
inchi="InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3",
@@ -64,68 +97,62 @@ def test_compound_upsert_and_get(tmp_path):
6497
)
6598
assert cid == IK14_3
6699

67-
row = cdb.get_compound(cid)
68-
assert row is not None
69-
assert row["inchikey"] == IK_FULL_3
70-
assert isinstance(row["fingerprint"], np.ndarray)
71-
np.testing.assert_array_equal(row["fingerprint"], np.array([0,1,0,1], dtype=np.uint8))
72-
73-
# Upsert another with the same comp_id (different full IK) -> should overwrite row cleanly
74-
cid2 = cdb.upsert_compound(
75-
smiles="C6H12O6",
76-
inchi=None,
77-
inchikey=IK_FULL_1,
78-
classyfire_class="Carbohydrates",
79-
classyfire_superclass="Organic compounds",
80-
)
81-
assert cid2 == IK14_1
82-
row2 = cdb.get_compound(IK14_1)
83-
assert row2["inchikey"] == IK_FULL_1
100+
# Metadata-only getter
101+
meta = cdb.get_compound(cid)
102+
assert meta is not None
103+
assert set(meta.keys()) == {"comp_id","smiles","inchi","inchikey","classyfire_class","classyfire_superclass"}
104+
assert meta["inchikey"] == IK_FULL_3
105+
106+
# No fingerprint yet
107+
assert cdb.get_fingerprint(cid) is None
108+
109+
# Compute fingerprints for all missing (should fill this one)
110+
stats = cdb.compute_fingerprints_missing(batch_size=100, use_progress_bar=False)
111+
assert stats["attempted"] >= 1
112+
assert stats["updated"] >= 1
113+
114+
# Now fingerprint should be present
115+
fp = cdb.get_fingerprint(cid)
116+
assert fp is not None
117+
bits, counts = fp
118+
assert bits.dtype == np.uint32
119+
assert counts.ndim == 1
84120

85121
cdb.close()
86122

87-
def test_compound_upsert_many(tmp_path):
123+
def test_compound_upsert_many_and_batch_getters(tmp_path):
88124
db_path = make_tmp_db(tmp_path)
89125
cdb = CompoundDatabase(db_path)
90126

91-
# Insert two rows that collapse to the same comp_id (same first14), newest should win
92127
comp_ids = cdb.upsert_many([
93-
{"smiles": "X", "inchi": None, "inchikey": IK_FULL_1, "classyfire_class": "A"},
94-
{"smiles": "Y", "inchi": None, "inchikey": IK_FULL_2, "classyfire_class": "B"},
95-
{"smiles": "Z", "inchi": None, "inchikey": IK_FULL_3, "classyfire_class": "C"},
128+
{"smiles": "CCO", "inchi": None, "inchikey": IK_FULL_1, "classyfire_class": "A"}, # ethanol (valid)
129+
{"smiles": "c1ccccc1", "inchi": None, "inchikey": IK_FULL_2, "classyfire_class": "B"}, # benzene (valid)
130+
{"smiles": None, "inchi": "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3",
131+
"inchikey": IK_FULL_3, "classyfire_class": "C"},
96132
])
97133
assert set(comp_ids) == {IK14_1, IK14_3}
98134

99-
row = cdb.get_compound(IK14_1)
100-
# After ON CONFLICT(comp_id) UPDATE, row reflects last data for that comp_id
101-
assert row["smiles"] in {"X", "Y"} # depends on order; both acceptable here
102-
assert row["inchikey"] in {IK_FULL_1, IK_FULL_2}
103-
assert row["classyfire_class"] in {"A", "B"}
135+
# Batch metadata (order preserved)
136+
df = cdb.get_compounds([IK14_3, IK14_1, "NOPE0000000000"])
137+
assert list(df["comp_id"]) == [IK14_3, IK14_1] # missing omitted
138+
assert set(["smiles","inchi","inchikey","classyfire_class","classyfire_superclass"]).issubset(df.columns)
104139

105-
count = cdb.sql_query("SELECT COUNT(*) as n FROM compounds")["n"].iloc[0]
106-
assert count == 2
140+
# No fingerprints yet
141+
fps = cdb.get_fingerprints([IK14_3, IK14_1, "NOPE0000000000"])
142+
assert fps[0] is None and fps[1] is None and fps[2] is None
143+
144+
# Backfill (will compute for rows with smiles OR inchi)
145+
stats = cdb.compute_fingerprints_missing(batch_size=100, use_progress_bar=False)
146+
assert stats["attempted"] >= 2
147+
fps = cdb.get_fingerprints([IK14_3, IK14_1, "NOPE0000000000"])
148+
assert fps[0] is not None and fps[1] is not None and fps[2] is None
107149

108150
cdb.close()
109151

110152
# -------------------------
111153
# Tests: SpecToCompoundMap + integration
112154
# -------------------------
113155

114-
def create_min_spectral_table(sqlite_path: str, rows):
115-
"""Create a minimal spectra table (spec_id, inchikey) and insert rows."""
116-
con = sqlite3.connect(sqlite_path)
117-
cur = con.cursor()
118-
cur.executescript("""
119-
PRAGMA journal_mode=WAL;
120-
CREATE TABLE IF NOT EXISTS spectra(
121-
spec_id INTEGER PRIMARY KEY AUTOINCREMENT,
122-
inchikey TEXT
123-
);
124-
""")
125-
cur.executemany("INSERT INTO spectra(inchikey) VALUES (?)", [(r,) for r in rows])
126-
con.commit()
127-
con.close()
128-
129156
def test_mapping_and_compound_creation(tmp_path):
130157
db_path = make_tmp_db(tmp_path)
131158

@@ -135,7 +162,7 @@ def test_mapping_and_compound_creation(tmp_path):
135162
# Run mapping (same db hosts compounds + mapping)
136163
n_mapped, n_new = map_from_spectraldb_metadata(db_path)
137164
assert n_mapped == 2 # two spectra had inchikeys
138-
assert n_new == 1 or n_new == 2 # depending on upsert collapsing; at least one unique comp
165+
assert n_new in (1, 2) # at least one unique comp
139166

140167
# Validate mapping contents
141168
mapper = SpecToCompoundMap(db_path)
@@ -147,17 +174,20 @@ def test_mapping_and_compound_creation(tmp_path):
147174
assert all(len(c) == 14 for c in df_map["comp_id"])
148175
mapper.close()
149176

150-
# Validate compounds exist
177+
# Validate compounds exist (metadata only, no FPs yet)
151178
cdb = CompoundDatabase(db_path)
152179
dfc = cdb.sql_query("SELECT comp_id, inchikey FROM compounds")
153180
assert not dfc.empty
154181
assert all(len(cid) == 14 for cid in dfc["comp_id"])
182+
# FPs should be empty prior to backfill
183+
empty = cdb.sql_query("SELECT COUNT(*) AS n FROM compounds WHERE COALESCE(LENGTH(fingerprint_bits),0)=0")
184+
assert empty["n"].iloc[0] >= 1
155185
cdb.close()
156186

157187
def test_mapper_link_and_get(tmp_path):
158188
db_path = make_tmp_db(tmp_path)
159189

160-
# need compounds table for FK-like behavior not enforced; mapping works independently
190+
# need compounds table; mapping works independently
161191
cdb = CompoundDatabase(db_path)
162192
cdb.upsert_compound(inchikey=IK_FULL_1) # ensure a compound exists
163193
cdb.close()
@@ -185,10 +215,9 @@ def test_get_unique_compounds_basic(tmp_path):
185215
create_min_spectral_table(db_path, [IK_FULL_1, IK_FULL_2, IK_FULL_3, None])
186216

187217
uniq = get_unique_compounds_from_spectraldb(db_path)
188-
# Expect 2 unique IK14 values
218+
# Column order follows the current function: inchikey14, inchikey, n_spectra
189219
assert list(uniq.columns[:3]) == ["inchikey14", "n_spectra", "inchikey"]
190220
assert set(uniq["inchikey14"]) == {IK14_1, IK14_3}
191-
# Counts: IK14_1 appears twice, IK14_3 once
192221
counts = dict(zip(uniq["inchikey14"], uniq["n_spectra"]))
193222
assert counts[IK14_1] == 2
194223
assert counts[IK14_3] == 1

0 commit comments

Comments
 (0)