|
| 1 | +from dataclasses import dataclass, field |
| 2 | +from typing import Iterable, Optional, Dict, Any, List, Tuple |
| 3 | +from pathlib import Path |
| 4 | +import sqlite3 |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | + |
| 8 | +# ========================= |
| 9 | +# Utilities & placeholders |
| 10 | +# ========================= |
| 11 | + |
| 12 | +def inchikey14_from_full(inchikey: str) -> Optional[str]: |
| 13 | + """Return the first 14 characters (inchikey14). Robust to hyphens/malformed keys.""" |
| 14 | + if not inchikey: |
| 15 | + return None |
| 16 | + s = str(inchikey).strip().upper() |
| 17 | + if "-" in s: |
| 18 | + return s.split("-", 1)[0][:14] |
| 19 | + return s[:14] if len(s) >= 14 else None |
| 20 | + |
| 21 | +def encode_fp_blob(fp: Optional[np.ndarray]) -> bytes: |
| 22 | + """Encode fingerprint as a uint8 BLOB. Accepts any numeric dtype -> coerces to uint8.""" |
| 23 | + if fp is None: |
| 24 | + return b"" |
| 25 | + fp = np.asarray(fp) |
| 26 | + if fp.dtype != np.uint8: |
| 27 | + fp = fp.astype(np.uint8, copy=False) |
| 28 | + return fp.tobytes(order="C") |
| 29 | + |
| 30 | +def decode_fp_blob(blob: bytes) -> np.ndarray: |
| 31 | + """Decode fingerprint BLOB back to uint8 array. Unknown length -> infer from blob size.""" |
| 32 | + if not blob: |
| 33 | + return np.zeros(0, dtype=np.uint8) |
| 34 | + return np.frombuffer(blob, dtype=np.uint8).copy() |
| 35 | + |
| 36 | +def compute_fingerprints(smiles: Optional[str], inchi: Optional[str]) -> np.ndarray: |
| 37 | + """ |
| 38 | + Placeholder: compute a molecular fingerprint from SMILES or InChI. |
| 39 | + For now return a dummy vector (replace with RDKit/Morgan etc. later). |
| 40 | + """ |
| 41 | + return np.array([0, 1, 0, 1], dtype=np.uint8) |
| 42 | + |
| 43 | +# ================================================== |
| 44 | +# Compound database (compounds table) in SQLite |
| 45 | +# ================================================== |
| 46 | + |
| 47 | +@dataclass |
| 48 | +class CompoundDatabase: |
| 49 | + sqlite_path: str |
| 50 | + # Extend as needed (add more classyfire-like fields here) |
| 51 | + compound_fields: List[str] = field(default_factory=lambda: [ |
| 52 | + "smiles", "inchi", "inchikey", "classyfire_class", "classyfire_superclass" |
| 53 | + ]) |
| 54 | + _conn: sqlite3.Connection = field(init=False, repr=False) |
| 55 | + |
| 56 | + def __post_init__(self): |
| 57 | + Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True) |
| 58 | + self._conn = sqlite3.connect(self.sqlite_path) |
| 59 | + self._conn.row_factory = sqlite3.Row |
| 60 | + self._ensure_schema() |
| 61 | + |
| 62 | + def close(self): |
| 63 | + try: |
| 64 | + self._conn.close() |
| 65 | + except Exception: |
| 66 | + pass |
| 67 | + |
| 68 | + # ---------- schema ---------- |
| 69 | + |
| 70 | + def _ensure_schema(self): |
| 71 | + cur = self._conn.cursor() |
| 72 | + # comp_id is inchikey14 (PRIMARY KEY). inchikey (full) must be unique as well if present. |
| 73 | + cur.executescript(""" |
| 74 | + PRAGMA journal_mode=WAL; |
| 75 | + CREATE TABLE IF NOT EXISTS compounds( |
| 76 | + comp_id TEXT PRIMARY KEY, -- inchikey14 |
| 77 | + smiles TEXT, |
| 78 | + inchi TEXT, |
| 79 | + inchikey TEXT UNIQUE, -- full InChIKey (27 chars) |
| 80 | + fingerprint BLOB, -- uint8 array |
| 81 | + classyfire_class TEXT, |
| 82 | + classyfire_superclass TEXT |
| 83 | + ); |
| 84 | + CREATE INDEX IF NOT EXISTS idx_compounds_smiles ON compounds(smiles); |
| 85 | + CREATE INDEX IF NOT EXISTS idx_compounds_inchi ON compounds(inchi); |
| 86 | + """) |
| 87 | + self._conn.commit() |
| 88 | + |
| 89 | + # ---------- upsert ---------- |
| 90 | + |
| 91 | + def upsert_compound( |
| 92 | + self, |
| 93 | + smiles: Optional[str] = None, |
| 94 | + inchi: Optional[str] = None, |
| 95 | + inchikey: Optional[str] = None, |
| 96 | + classyfire_class: Optional[str] = None, |
| 97 | + classyfire_superclass: Optional[str] = None, |
| 98 | + fingerprint: Optional[np.ndarray] = None, |
| 99 | + ) -> str: |
| 100 | + """Upsert a single compound. Returns comp_id (inchikey14).""" |
| 101 | + if inchikey is None: |
| 102 | + raise ValueError("inchikey is required to form comp_id (inchikey14).") |
| 103 | + comp_id = inchikey14_from_full(inchikey) |
| 104 | + if not comp_id: |
| 105 | + raise ValueError(f"Invalid InChIKey: {inchikey}") |
| 106 | + |
| 107 | + fp_blob = encode_fp_blob(fingerprint if fingerprint is not None else compute_fingerprints(smiles, inchi)) |
| 108 | + |
| 109 | + cur = self._conn.cursor() |
| 110 | + # Use INSERT ON CONFLICT for upsert semantics |
| 111 | + cur.execute(""" |
| 112 | + INSERT INTO compounds (comp_id, smiles, inchi, inchikey, fingerprint, classyfire_class, classyfire_superclass) |
| 113 | + VALUES (?, ?, ?, ?, ?, ?, ?) |
| 114 | + ON CONFLICT(comp_id) DO UPDATE SET |
| 115 | + smiles=excluded.smiles, |
| 116 | + inchi=excluded.inchi, |
| 117 | + inchikey=excluded.inchikey, |
| 118 | + fingerprint=excluded.fingerprint, |
| 119 | + classyfire_class=excluded.classyfire_class, |
| 120 | + classyfire_superclass=excluded.classyfire_superclass |
| 121 | + """, (comp_id, smiles, inchi, inchikey, fp_blob, classyfire_class, classyfire_superclass)) |
| 122 | + self._conn.commit() |
| 123 | + return comp_id |
| 124 | + |
| 125 | + def upsert_many(self, rows: Iterable[Dict[str, Any]]) -> List[str]: |
| 126 | + """ |
| 127 | + Upsert many compounds. Each row may include: |
| 128 | + smiles, inchi, inchikey (required), classyfire_class, classyfire_superclass, fingerprint (np.ndarray optional). |
| 129 | + Returns list of comp_ids. |
| 130 | + """ |
| 131 | + comp_ids: List[str] = [] |
| 132 | + cur = self._conn.cursor() |
| 133 | + cur.execute("BEGIN") |
| 134 | + try: |
| 135 | + for r in rows: |
| 136 | + inchikey = r.get("inchikey") |
| 137 | + if not inchikey: |
| 138 | + raise ValueError("Each row must contain 'inchikey'.") |
| 139 | + comp_id = inchikey14_from_full(inchikey) |
| 140 | + if not comp_id: |
| 141 | + raise ValueError(f"Invalid InChIKey: {inchikey}") |
| 142 | + |
| 143 | + smiles = r.get("smiles") |
| 144 | + inchi = r.get("inchi") |
| 145 | + fingerprint = r.get("fingerprint") |
| 146 | + fp_blob = encode_fp_blob(fingerprint if fingerprint is not None else compute_fingerprints(smiles, inchi)) |
| 147 | + classyfire_class = r.get("classyfire_class") |
| 148 | + classyfire_superclass = r.get("classyfire_superclass") |
| 149 | + |
| 150 | + cur.execute(""" |
| 151 | + INSERT INTO compounds (comp_id, smiles, inchi, inchikey, fingerprint, classyfire_class, classyfire_superclass) |
| 152 | + VALUES (?, ?, ?, ?, ?, ?, ?) |
| 153 | + ON CONFLICT(comp_id) DO UPDATE SET |
| 154 | + smiles=excluded.smiles, |
| 155 | + inchi=excluded.inchi, |
| 156 | + inchikey=excluded.inchikey, |
| 157 | + fingerprint=excluded.fingerprint, |
| 158 | + classyfire_class=excluded.classyfire_class, |
| 159 | + classyfire_superclass=excluded.classyfire_superclass |
| 160 | + """, (comp_id, smiles, inchi, inchikey, fp_blob, classyfire_class, classyfire_superclass)) |
| 161 | + comp_ids.append(comp_id) |
| 162 | + cur.execute("COMMIT") |
| 163 | + except Exception: |
| 164 | + cur.execute("ROLLBACK") |
| 165 | + raise |
| 166 | + return comp_ids |
| 167 | + |
| 168 | + # ---------- queries ---------- |
| 169 | + |
| 170 | + def get_compound(self, comp_id: str) -> Optional[Dict[str, Any]]: |
| 171 | + row = self._conn.execute("SELECT * FROM compounds WHERE comp_id = ?", (comp_id,)).fetchone() |
| 172 | + if not row: |
| 173 | + return None |
| 174 | + d = dict(row) |
| 175 | + d["fingerprint"] = decode_fp_blob(d["fingerprint"]) |
| 176 | + return d |
| 177 | + |
| 178 | + def sql_query(self, query: str) -> pd.DataFrame: |
| 179 | + return pd.read_sql_query(query, self._conn) |
| 180 | + |
| 181 | + |
| 182 | +# ================================================== |
| 183 | +# Mapping: spectrum <-> compound (spec_to_comp) |
| 184 | +# ================================================== |
| 185 | + |
| 186 | +@dataclass |
| 187 | +class SpecToCompoundMap: |
| 188 | + """Stores (spec_id -> comp_id) mappings in SQLite. Use the SAME DB file as SpectralDatabase for simplicity.""" |
| 189 | + sqlite_path: str |
| 190 | + _conn: sqlite3.Connection = field(init=False, repr=False) |
| 191 | + |
| 192 | + def __post_init__(self): |
| 193 | + Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True) |
| 194 | + self._conn = sqlite3.connect(self.sqlite_path) |
| 195 | + self._conn.row_factory = sqlite3.Row |
| 196 | + self._ensure_schema() |
| 197 | + |
| 198 | + def close(self): |
| 199 | + try: |
| 200 | + self._conn.close() |
| 201 | + except Exception: |
| 202 | + pass |
| 203 | + |
| 204 | + def _ensure_schema(self): |
| 205 | + cur = self._conn.cursor() |
| 206 | + # No strict FK enforcement (SpectralDatabase may have been created without FK pragma), |
| 207 | + # here: index both sides for fast lookup. |
| 208 | + cur.executescript(""" |
| 209 | + CREATE TABLE IF NOT EXISTS spec_to_comp( |
| 210 | + spec_id INTEGER NOT NULL, |
| 211 | + comp_id TEXT NOT NULL, |
| 212 | + PRIMARY KEY (spec_id), |
| 213 | + -- implicit: comp_id should exist in compounds.comp_id (not enforced here) |
| 214 | + -- to enforce FK, you can enable PRAGMA foreign_keys=ON and create a FK to compounds(comp_id) |
| 215 | + -- if both tables are in the same SQLite file. |
| 216 | + CHECK (length(comp_id) = 14) |
| 217 | + ); |
| 218 | + CREATE INDEX IF NOT EXISTS idx_spec_to_comp_comp ON spec_to_comp(comp_id); |
| 219 | + """) |
| 220 | + self._conn.commit() |
| 221 | + |
| 222 | + # ---------- API ---------- |
| 223 | + |
| 224 | + def link(self, spec_id: int, comp_id: str): |
| 225 | + """Insert or replace a single mapping.""" |
| 226 | + if not comp_id or len(comp_id) != 14: |
| 227 | + raise ValueError("comp_id must be inchikey14 (14 characters).") |
| 228 | + self._conn.execute(""" |
| 229 | + INSERT INTO spec_to_comp (spec_id, comp_id) |
| 230 | + VALUES (?, ?) |
| 231 | + ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id |
| 232 | + """, (spec_id, comp_id)) |
| 233 | + self._conn.commit() |
| 234 | + |
| 235 | + def link_many(self, pairs: Iterable[Tuple[int, str]]): |
| 236 | + """Bulk link (spec_id, comp_id).""" |
| 237 | + cur = self._conn.cursor() |
| 238 | + cur.execute("BEGIN") |
| 239 | + try: |
| 240 | + cur.executemany(""" |
| 241 | + INSERT INTO spec_to_comp (spec_id, comp_id) |
| 242 | + VALUES (?, ?) |
| 243 | + ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id |
| 244 | + """, list(pairs)) |
| 245 | + cur.execute("COMMIT") |
| 246 | + except Exception: |
| 247 | + cur.execute("ROLLBACK") |
| 248 | + raise |
| 249 | + |
| 250 | + def get_comp_id_for_specs(self, spec_ids: List[int]) -> pd.DataFrame: |
| 251 | + """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.""" |
| 252 | + if not spec_ids: |
| 253 | + return pd.DataFrame(columns=["spec_id", "comp_id"]) |
| 254 | + placeholders = ",".join("?" * len(spec_ids)) |
| 255 | + rows = self._conn.execute( |
| 256 | + f"SELECT spec_id, comp_id FROM spec_to_comp WHERE spec_id IN ({placeholders})", |
| 257 | + spec_ids |
| 258 | + ).fetchall() |
| 259 | + return pd.DataFrame(rows, columns=["spec_id", "comp_id"]) |
| 260 | + |
| 261 | + def get_specs_for_comp(self, comp_id: str) -> List[int]: |
| 262 | + """Return list of spec_ids for a given comp_id.""" |
| 263 | + rows = self._conn.execute("SELECT spec_id FROM spec_to_comp WHERE comp_id = ?", (comp_id,)).fetchall() |
| 264 | + return [r[0] for r in rows] |
| 265 | + |
| 266 | + |
| 267 | +# ================================================== |
| 268 | +# Integrations with SpectralDatabase |
| 269 | +# ================================================== |
| 270 | + |
| 271 | +def map_from_spectraldb_metadata( |
| 272 | + spectral_db_sqlite_path: str, |
| 273 | + mapping_sqlite_path: Optional[str] = None, |
| 274 | + compounds_sqlite_path: Optional[str] = None, |
| 275 | + *, |
| 276 | + create_missing_compounds: bool = True |
| 277 | +) -> Tuple[int, int]: |
| 278 | + """ |
| 279 | + Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14), |
| 280 | + populate spec_to_comp, and optionally upsert minimal compounds. |
| 281 | +
|
| 282 | + Returns: (n_mapped, n_new_compounds) |
| 283 | + """ |
| 284 | + # We do not import the class to avoid circular imports; use sqlite directly. |
| 285 | + s_conn = sqlite3.connect(spectral_db_sqlite_path) |
| 286 | + s_conn.row_factory = sqlite3.Row |
| 287 | + |
| 288 | + map_db_path = mapping_sqlite_path or spectral_db_sqlite_path |
| 289 | + c_db_path = compounds_sqlite_path or spectral_db_sqlite_path |
| 290 | + |
| 291 | + mapper = SpecToCompoundMap(map_db_path) |
| 292 | + compdb = CompoundDatabase(c_db_path) |
| 293 | + |
| 294 | + # Pull spec_id + inchikey from SpectralDatabase.spectra table |
| 295 | + # (the earlier SpectralDatabase example stores metadata as columns; ensure 'inchikey' exists there). |
| 296 | + rows = s_conn.execute("SELECT spec_id, inchikey FROM spectra").fetchall() |
| 297 | + |
| 298 | + to_link: List[Tuple[int, str]] = [] |
| 299 | + new_comp_rows: List[Dict[str, Any]] = [] |
| 300 | + |
| 301 | + for r in rows: |
| 302 | + spec_id = int(r["spec_id"]) |
| 303 | + ik_full = r["inchikey"] |
| 304 | + if not ik_full: |
| 305 | + continue |
| 306 | + comp_id = inchikey14_from_full(ik_full) |
| 307 | + if not comp_id: |
| 308 | + continue |
| 309 | + to_link.append((spec_id, comp_id)) |
| 310 | + |
| 311 | + if create_missing_compounds: |
| 312 | + new_comp_rows.append({ |
| 313 | + "smiles": None, |
| 314 | + "inchi": None, |
| 315 | + "inchikey": ik_full, |
| 316 | + "classyfire_class": None, |
| 317 | + "classyfire_superclass": None, |
| 318 | + "fingerprint": None, # will be replaced by compute_fingerprints(...) |
| 319 | + }) |
| 320 | + |
| 321 | + # Bulk linking |
| 322 | + if to_link: |
| 323 | + mapper.link_many(to_link) |
| 324 | + |
| 325 | + # Upsert compounds |
| 326 | + n_new_compounds = 0 |
| 327 | + if create_missing_compounds and new_comp_rows: |
| 328 | + # Deduplicate by comp_id to avoid redundant upserts |
| 329 | + seen: set[str] = set() |
| 330 | + dedup_rows: List[Dict[str, Any]] = [] |
| 331 | + for r in new_comp_rows: |
| 332 | + cid = inchikey14_from_full(r["inchikey"]) |
| 333 | + if cid and cid not in seen: |
| 334 | + seen.add(cid) |
| 335 | + dedup_rows.append(r) |
| 336 | + before = compdb.sql_query("SELECT COUNT(*) AS n FROM compounds")["n"].iloc[0] |
| 337 | + compdb.upsert_many(dedup_rows) |
| 338 | + after = compdb.sql_query("SELECT COUNT(*) AS n FROM compounds")["n"].iloc[0] |
| 339 | + n_new_compounds = int(after - before) |
| 340 | + |
| 341 | + n_mapped = len(to_link) |
| 342 | + |
| 343 | + # tidy |
| 344 | + mapper.close() |
| 345 | + compdb.close() |
| 346 | + s_conn.close() |
| 347 | + |
| 348 | + return n_mapped, n_new_compounds |
| 349 | + |
| 350 | + |
| 351 | +def get_unique_compounds_from_spectraldb( |
| 352 | + spectral_db_sqlite_path: str, |
| 353 | + external_meta: Optional[pd.DataFrame] = None, |
| 354 | + external_key_col: str = "inchikey14" |
| 355 | +) -> pd.DataFrame: |
| 356 | + """ |
| 357 | + Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14. |
| 358 | + Columns: inchikey14, inchikey (full), n_spectra. If `external_meta` is provided, |
| 359 | + it will be left-joined on `external_key_col` (default 'inchikey14'). |
| 360 | + """ |
| 361 | + conn = sqlite3.connect(spectral_db_sqlite_path) |
| 362 | + conn.row_factory = sqlite3.Row |
| 363 | + |
| 364 | + # pull spec_id + inchikey from spectra |
| 365 | + df = pd.read_sql_query("SELECT spec_id, inchikey FROM spectra", conn) |
| 366 | + conn.close() |
| 367 | + |
| 368 | + if df.empty: |
| 369 | + base = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"]) |
| 370 | + if external_meta is not None: |
| 371 | + return base.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) |
| 372 | + return base |
| 373 | + |
| 374 | + # Compute inchikey14 |
| 375 | + ik14 = df["inchikey"].fillna("").map(inchikey14_from_full) |
| 376 | + df["inchikey14"] = ik14 |
| 377 | + |
| 378 | + # Aggregate |
| 379 | + agg = ( |
| 380 | + df.dropna(subset=["inchikey14"]) |
| 381 | + .groupby(["inchikey14"], as_index=False) |
| 382 | + .agg(n_spectra=("spec_id", "count"), |
| 383 | + inchikey=("inchikey", "first")) # first full key seen |
| 384 | + ) |
| 385 | + |
| 386 | + # Optional join with external meta |
| 387 | + if external_meta is not None and not external_meta.empty: |
| 388 | + agg = agg.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) |
| 389 | + |
| 390 | + # Order by prevalence |
| 391 | + agg = agg.sort_values("n_spectra", ascending=False).reset_index(drop=True) |
| 392 | + return agg |
0 commit comments