Skip to content

Commit 88ef8a9

Browse files
committed
add compound database and mappers
1 parent 4692529 commit 88ef8a9

File tree

2 files changed

+399
-7
lines changed

2 files changed

+399
-7
lines changed

ms2query/compound_database.py

Lines changed: 392 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
from dataclasses import dataclass, field
2+
from typing import Iterable, Optional, Dict, Any, List, Tuple
3+
from pathlib import Path
4+
import sqlite3
5+
import numpy as np
6+
import pandas as pd
7+
8+
# =========================
9+
# Utilities & placeholders
10+
# =========================
11+
12+
def inchikey14_from_full(inchikey: str) -> Optional[str]:
    """Return the 14-character connectivity block (inchikey14) of an InChIKey.

    Robust to hyphenated full keys, bare keys, and malformed input.

    Args:
        inchikey: Full (hyphenated, 27-char) or bare InChIKey; may be None/empty.

    Returns:
        The uppercased 14-character first block, or None when the input is
        missing or too short to contain a complete inchikey14.
    """
    if not inchikey:
        return None
    s = str(inchikey).strip().upper()
    if "-" in s:
        s = s.split("-", 1)[0]
    # Reject truncated keys consistently: the connectivity block must be 14
    # chars. (Previously a hyphenated key with a short first segment, e.g.
    # "ABC-DEF", slipped through and returned "ABC", which later fails the
    # 14-char checks in link()/the spec_to_comp CHECK constraint.)
    return s[:14] if len(s) >= 14 else None
20+
21+
def encode_fp_blob(fp: Optional[np.ndarray]) -> bytes:
    """Serialize a fingerprint to a raw uint8 BLOB.

    Any numeric dtype is coerced to uint8; None encodes as the empty blob.
    """
    if fp is None:
        return b""
    # astype(copy=False) is a no-op for arrays that are already uint8.
    arr = np.asarray(fp).astype(np.uint8, copy=False)
    return arr.tobytes(order="C")
29+
30+
def decode_fp_blob(blob: bytes) -> np.ndarray:
    """Deserialize a fingerprint BLOB into a writable uint8 array.

    The length is inferred from the blob size; an empty/missing blob yields
    a zero-length array.
    """
    if blob:
        # frombuffer returns a read-only view over `blob`; copy to own memory.
        return np.frombuffer(blob, dtype=np.uint8).copy()
    return np.zeros(0, dtype=np.uint8)
35+
36+
def compute_fingerprints(smiles: Optional[str], inchi: Optional[str]) -> np.ndarray:
    """Placeholder fingerprint computation.

    Returns a fixed dummy uint8 vector regardless of input; swap in a real
    implementation (e.g. RDKit Morgan fingerprints) later. `smiles` and
    `inchi` are accepted but unused for now.
    """
    _ = (smiles, inchi)  # unused until a real implementation lands
    return np.asarray([0, 1, 0, 1], dtype=np.uint8)
42+
43+
# ==================================================
44+
# Compound database (compounds table) in SQLite
45+
# ==================================================
46+
47+
@dataclass
class CompoundDatabase:
    """SQLite-backed compound store keyed by inchikey14 (``comp_id``).

    One row per compound; the full InChIKey (when present) must also be
    unique. Fingerprints are stored as raw uint8 BLOBs (see
    encode_fp_blob/decode_fp_blob). Usable as a context manager.
    """

    sqlite_path: str
    # Extend as needed (add more classyfire-like fields here)
    compound_fields: List[str] = field(default_factory=lambda: [
        "smiles", "inchi", "inchikey", "classyfire_class", "classyfire_superclass"
    ])
    _conn: sqlite3.Connection = field(init=False, repr=False)

    # Single source of truth for the upsert statement, shared by
    # upsert_compound() and upsert_many() so the two paths can never drift
    # apart (they previously duplicated this SQL verbatim).
    _UPSERT_SQL = """
        INSERT INTO compounds (comp_id, smiles, inchi, inchikey, fingerprint, classyfire_class, classyfire_superclass)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(comp_id) DO UPDATE SET
            smiles=excluded.smiles,
            inchi=excluded.inchi,
            inchikey=excluded.inchikey,
            fingerprint=excluded.fingerprint,
            classyfire_class=excluded.classyfire_class,
            classyfire_superclass=excluded.classyfire_superclass
    """

    def __post_init__(self):
        # Ensure the parent directory exists before sqlite creates the file.
        Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(self.sqlite_path)
        self._conn.row_factory = sqlite3.Row
        self._ensure_schema()

    def __enter__(self) -> "CompoundDatabase":
        # Allow `with CompoundDatabase(path) as db:` usage.
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self):
        """Close the underlying connection; never raises."""
        try:
            self._conn.close()
        except Exception:
            pass

    # ---------- schema ----------

    def _ensure_schema(self):
        """Create the compounds table and its lookup indexes if missing."""
        cur = self._conn.cursor()
        # comp_id is inchikey14 (PRIMARY KEY). inchikey (full) must be unique as well if present.
        cur.executescript("""
        PRAGMA journal_mode=WAL;
        CREATE TABLE IF NOT EXISTS compounds(
            comp_id TEXT PRIMARY KEY,          -- inchikey14
            smiles TEXT,
            inchi TEXT,
            inchikey TEXT UNIQUE,              -- full InChIKey (27 chars)
            fingerprint BLOB,                  -- uint8 array
            classyfire_class TEXT,
            classyfire_superclass TEXT
        );
        CREATE INDEX IF NOT EXISTS idx_compounds_smiles ON compounds(smiles);
        CREATE INDEX IF NOT EXISTS idx_compounds_inchi ON compounds(inchi);
        """)
        self._conn.commit()

    # ---------- upsert ----------

    @staticmethod
    def _upsert_params(
        smiles: Optional[str],
        inchi: Optional[str],
        inchikey: str,
        fingerprint: Optional[np.ndarray],
        classyfire_class: Optional[str],
        classyfire_superclass: Optional[str],
    ) -> Tuple[str, Optional[str], Optional[str], str, bytes, Optional[str], Optional[str]]:
        """Validate one compound and build the parameter tuple for _UPSERT_SQL.

        Raises:
            ValueError: if the InChIKey cannot yield a valid inchikey14.
        """
        comp_id = inchikey14_from_full(inchikey)
        if not comp_id:
            raise ValueError(f"Invalid InChIKey: {inchikey}")
        # Fall back to a computed fingerprint when none was supplied.
        fp_blob = encode_fp_blob(
            fingerprint if fingerprint is not None else compute_fingerprints(smiles, inchi)
        )
        return (comp_id, smiles, inchi, inchikey, fp_blob, classyfire_class, classyfire_superclass)

    def upsert_compound(
        self,
        smiles: Optional[str] = None,
        inchi: Optional[str] = None,
        inchikey: Optional[str] = None,
        classyfire_class: Optional[str] = None,
        classyfire_superclass: Optional[str] = None,
        fingerprint: Optional[np.ndarray] = None,
    ) -> str:
        """Upsert a single compound. Returns comp_id (inchikey14).

        Raises:
            ValueError: if `inchikey` is missing or invalid.
        """
        if inchikey is None:
            raise ValueError("inchikey is required to form comp_id (inchikey14).")
        params = self._upsert_params(
            smiles, inchi, inchikey, fingerprint, classyfire_class, classyfire_superclass
        )
        self._conn.execute(self._UPSERT_SQL, params)
        self._conn.commit()
        return params[0]

    def upsert_many(self, rows: Iterable[Dict[str, Any]]) -> List[str]:
        """
        Upsert many compounds atomically (one transaction, rolled back on error).

        Each row may include: smiles, inchi, inchikey (required),
        classyfire_class, classyfire_superclass, fingerprint (np.ndarray, optional).

        Returns:
            List of comp_ids in input order.

        Raises:
            ValueError: if any row lacks an 'inchikey' or has an invalid one.
        """
        comp_ids: List[str] = []
        # `with conn:` commits on success and rolls back on any exception,
        # replacing the previous manual BEGIN/COMMIT/ROLLBACK bookkeeping.
        with self._conn:
            for r in rows:
                inchikey = r.get("inchikey")
                if not inchikey:
                    raise ValueError("Each row must contain 'inchikey'.")
                params = self._upsert_params(
                    r.get("smiles"),
                    r.get("inchi"),
                    inchikey,
                    r.get("fingerprint"),
                    r.get("classyfire_class"),
                    r.get("classyfire_superclass"),
                )
                self._conn.execute(self._UPSERT_SQL, params)
                comp_ids.append(params[0])
        return comp_ids

    # ---------- queries ----------

    def get_compound(self, comp_id: str) -> Optional[Dict[str, Any]]:
        """Fetch one compound as a dict (fingerprint decoded to a uint8 array), or None."""
        row = self._conn.execute(
            "SELECT * FROM compounds WHERE comp_id = ?", (comp_id,)
        ).fetchone()
        if not row:
            return None
        d = dict(row)
        d["fingerprint"] = decode_fp_blob(d["fingerprint"])
        return d

    def sql_query(self, query: str) -> pd.DataFrame:
        """Run an arbitrary SQL query and return the result as a DataFrame.

        NOTE: `query` is executed verbatim — pass trusted SQL only.
        """
        return pd.read_sql_query(query, self._conn)
180+
181+
182+
# ==================================================
183+
# Mapping: spectrum <-> compound (spec_to_comp)
184+
# ==================================================
185+
186+
@dataclass
class SpecToCompoundMap:
    """Stores (spec_id -> comp_id) mappings in SQLite. Use the SAME DB file as SpectralDatabase for simplicity."""

    sqlite_path: str
    _conn: sqlite3.Connection = field(init=False, repr=False)

    # Shared by link() and link_many() so the single and bulk paths cannot
    # drift apart (they previously duplicated this SQL verbatim).
    _LINK_SQL = """
        INSERT INTO spec_to_comp (spec_id, comp_id)
        VALUES (?, ?)
        ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id
    """
    # Stay safely below SQLite's historical default host-parameter limit (999).
    _MAX_SQL_VARS = 900

    def __post_init__(self):
        Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(self.sqlite_path)
        self._conn.row_factory = sqlite3.Row
        self._ensure_schema()

    def __enter__(self) -> "SpecToCompoundMap":
        # Allow `with SpecToCompoundMap(path) as m:` usage.
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self):
        """Close the underlying connection; never raises."""
        try:
            self._conn.close()
        except Exception:
            pass

    def _ensure_schema(self):
        """Create the mapping table (spec_id is the primary key) and comp_id index."""
        cur = self._conn.cursor()
        # No strict FK enforcement (SpectralDatabase may have been created without FK pragma),
        # here: index both sides for fast lookup.
        cur.executescript("""
        CREATE TABLE IF NOT EXISTS spec_to_comp(
            spec_id INTEGER NOT NULL,
            comp_id TEXT NOT NULL,
            PRIMARY KEY (spec_id),
            -- implicit: comp_id should exist in compounds.comp_id (not enforced here)
            -- to enforce FK, you can enable PRAGMA foreign_keys=ON and create a FK to compounds(comp_id)
            -- if both tables are in the same SQLite file.
            CHECK (length(comp_id) = 14)
        );
        CREATE INDEX IF NOT EXISTS idx_spec_to_comp_comp ON spec_to_comp(comp_id);
        """)
        self._conn.commit()

    # ---------- API ----------

    def link(self, spec_id: int, comp_id: str):
        """Insert or replace a single mapping.

        Raises:
            ValueError: if comp_id is not exactly 14 characters (inchikey14).
        """
        if not comp_id or len(comp_id) != 14:
            raise ValueError("comp_id must be inchikey14 (14 characters).")
        self._conn.execute(self._LINK_SQL, (spec_id, comp_id))
        self._conn.commit()

    def link_many(self, pairs: Iterable[Tuple[int, str]]):
        """Bulk link (spec_id, comp_id) in one transaction (rolled back on error)."""
        # `with conn:` commits on success and rolls back on exception,
        # replacing the previous manual BEGIN/COMMIT/ROLLBACK bookkeeping.
        with self._conn:
            self._conn.executemany(self._LINK_SQL, list(pairs))

    def get_comp_id_for_specs(self, spec_ids: List[int]) -> pd.DataFrame:
        """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.

        Queries in chunks so arbitrarily long id lists do not exceed SQLite's
        host-parameter limit (a single `IN (...)` with one placeholder per id
        fails past ~999 ids on default builds).
        """
        if not spec_ids:
            return pd.DataFrame(columns=["spec_id", "comp_id"])
        records: List[Tuple[int, str]] = []
        for start in range(0, len(spec_ids), self._MAX_SQL_VARS):
            chunk = spec_ids[start:start + self._MAX_SQL_VARS]
            placeholders = ",".join("?" * len(chunk))
            rows = self._conn.execute(
                f"SELECT spec_id, comp_id FROM spec_to_comp WHERE spec_id IN ({placeholders})",
                chunk,
            ).fetchall()
            records.extend((r["spec_id"], r["comp_id"]) for r in rows)
        return pd.DataFrame(records, columns=["spec_id", "comp_id"])

    def get_specs_for_comp(self, comp_id: str) -> List[int]:
        """Return list of spec_ids for a given comp_id."""
        rows = self._conn.execute(
            "SELECT spec_id FROM spec_to_comp WHERE comp_id = ?", (comp_id,)
        ).fetchall()
        return [r[0] for r in rows]
265+
266+
267+
# ==================================================
268+
# Integrations with SpectralDatabase
269+
# ==================================================
270+
271+
def map_from_spectraldb_metadata(
    spectral_db_sqlite_path: str,
    mapping_sqlite_path: Optional[str] = None,
    compounds_sqlite_path: Optional[str] = None,
    *,
    create_missing_compounds: bool = True
) -> Tuple[int, int]:
    """
    Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14),
    populate spec_to_comp, and optionally upsert minimal compounds.

    Args:
        spectral_db_sqlite_path: SQLite file holding the `spectra` table.
        mapping_sqlite_path: where spec_to_comp lives; defaults to the spectral DB file.
        compounds_sqlite_path: where compounds live; defaults to the spectral DB file.
        create_missing_compounds: when True, upsert a minimal compound row per key.

    Returns:
        (n_mapped, n_new_compounds)
    """
    # We do not import the class to avoid circular imports; use sqlite directly.
    s_conn = sqlite3.connect(spectral_db_sqlite_path)
    mapper: Optional[SpecToCompoundMap] = None
    compdb: Optional[CompoundDatabase] = None
    try:
        s_conn.row_factory = sqlite3.Row
        mapper = SpecToCompoundMap(mapping_sqlite_path or spectral_db_sqlite_path)
        compdb = CompoundDatabase(compounds_sqlite_path or spectral_db_sqlite_path)

        # Pull spec_id + inchikey from SpectralDatabase.spectra table
        # (metadata is stored as columns; 'inchikey' must exist there).
        rows = s_conn.execute("SELECT spec_id, inchikey FROM spectra").fetchall()

        to_link: List[Tuple[int, str]] = []
        new_comp_rows: List[Dict[str, Any]] = []

        for r in rows:
            ik_full = r["inchikey"]
            if not ik_full:
                continue
            comp_id = inchikey14_from_full(ik_full)
            if not comp_id:
                # Skip malformed keys rather than failing the whole run.
                continue
            to_link.append((int(r["spec_id"]), comp_id))

            if create_missing_compounds:
                new_comp_rows.append({
                    "smiles": None,
                    "inchi": None,
                    "inchikey": ik_full,
                    "classyfire_class": None,
                    "classyfire_superclass": None,
                    "fingerprint": None,  # will be replaced by compute_fingerprints(...)
                })

        # Bulk linking
        if to_link:
            mapper.link_many(to_link)

        # Upsert compounds
        n_new_compounds = 0
        if create_missing_compounds and new_comp_rows:
            # Deduplicate by comp_id to avoid redundant upserts.
            seen: set = set()
            dedup_rows: List[Dict[str, Any]] = []
            for r in new_comp_rows:
                cid = inchikey14_from_full(r["inchikey"])
                if cid and cid not in seen:
                    seen.add(cid)
                    dedup_rows.append(r)
            # Count rows before/after so n_new_compounds reflects real inserts
            # (upserts of existing comp_ids do not change the count).
            before = compdb.sql_query("SELECT COUNT(*) AS n FROM compounds")["n"].iloc[0]
            compdb.upsert_many(dedup_rows)
            after = compdb.sql_query("SELECT COUNT(*) AS n FROM compounds")["n"].iloc[0]
            n_new_compounds = int(after - before)

        return len(to_link), n_new_compounds
    finally:
        # Close everything even when a step above raises (the original leaked
        # all three handles on any error).
        if mapper is not None:
            mapper.close()
        if compdb is not None:
            compdb.close()
        s_conn.close()
349+
350+
351+
def get_unique_compounds_from_spectraldb(
    spectral_db_sqlite_path: str,
    external_meta: Optional[pd.DataFrame] = None,
    external_key_col: str = "inchikey14"
) -> pd.DataFrame:
    """
    Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14.

    Columns: inchikey14, inchikey (full), n_spectra. If `external_meta` is provided,
    it will be left-joined on `external_key_col` (default 'inchikey14').
    Rows are ordered by descending n_spectra.
    """
    conn = sqlite3.connect(spectral_db_sqlite_path)
    try:
        # pull spec_id + inchikey from spectra
        df = pd.read_sql_query("SELECT spec_id, inchikey FROM spectra", conn)
    finally:
        # Close even when the query raises (the original leaked the connection
        # on error, e.g. when the `spectra` table is missing).
        conn.close()

    if df.empty:
        base = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"])
        if external_meta is not None:
            return base.merge(external_meta, how="left",
                              left_on="inchikey14", right_on=external_key_col)
        return base

    # Map each full key to its 14-char block; malformed keys map to None.
    df["inchikey14"] = df["inchikey"].fillna("").map(inchikey14_from_full)

    # Aggregate spectra per compound; keep the first full key seen per inchikey14.
    agg = (
        df.dropna(subset=["inchikey14"])
          .groupby(["inchikey14"], as_index=False)
          .agg(n_spectra=("spec_id", "count"),
               inchikey=("inchikey", "first"))
    )

    # Optional join with external meta
    if external_meta is not None and not external_meta.empty:
        agg = agg.merge(external_meta, how="left",
                        left_on="inchikey14", right_on=external_key_col)

    # Order by prevalence
    return agg.sort_values("n_spectra", ascending=False).reset_index(drop=True)

0 commit comments

Comments
 (0)