Skip to content

Commit fa4afdd

Browse files
committed
Add SpectrumDataSet for handling sets of spectra, for method development.
1 parent 9ecd3c1 commit fa4afdd

File tree

1 file changed

+130
-0
lines changed

1 file changed

+130
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import copy
2+
from collections import Counter
3+
from typing import List, Dict, Iterable
4+
5+
import numpy as np
6+
from matchms import Spectrum
7+
from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi
8+
9+
from ms2deepscore.models import compute_embedding_array, SiameseSpectralModel
10+
from tqdm import tqdm
11+
12+
13+
class SpectrumSetBase:
14+
"""Stores a spectrum dataset making it easy and fast to split on molecules"""
15+
16+
def __init__(self, spectra: List[Spectrum], progress_bars=False):
17+
self._spectra = []
18+
self.spectrum_indexes_per_inchikey = {}
19+
self.progress_bars = progress_bars
20+
# init spectra
21+
self._add_spectra_and_group_per_inchikey(spectra)
22+
23+
def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]):
24+
starting_index = len(self._spectra)
25+
updated_inchikeys = set()
26+
for i, spectrum in enumerate(
27+
tqdm(spectra, desc="Adding spectra and grouping per Inchikey", disable=not self.progress_bars)
28+
):
29+
self._spectra.append(spectrum)
30+
spectrum_index = starting_index + i
31+
inchikey = spectrum.get("inchikey")[:14]
32+
updated_inchikeys.add(inchikey)
33+
if inchikey in self.spectrum_indexes_per_inchikey:
34+
self.spectrum_indexes_per_inchikey[inchikey].append(spectrum_index)
35+
else:
36+
self.spectrum_indexes_per_inchikey[inchikey] = [
37+
spectrum_index,
38+
]
39+
return updated_inchikeys
40+
41+
def add_spectra(self, new_spectra: "SpectrumSetBase"):
42+
return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)
43+
44+
def subset_spectra(self, spectrum_indexes) -> "SpectrumSetBase":
45+
"""Returns a new instance of a subset of the spectra"""
46+
new_instance = copy.copy(self)
47+
new_instance._spectra = []
48+
new_instance.spectrum_indexes_per_inchikey = {}
49+
new_instance._add_spectra_and_group_per_inchikey([self._spectra[index] for index in spectrum_indexes])
50+
return new_instance
51+
52+
def spectra_per_inchikey(self, inchikey) -> List[Spectrum]:
53+
matching_spectra = []
54+
for index in self.spectrum_indexes_per_inchikey[inchikey]:
55+
matching_spectra.append(self._spectra[index])
56+
return matching_spectra
57+
58+
@property
59+
def spectra(self):
60+
return self._spectra
61+
62+
def copy(self):
63+
"""This copy method ensures all spectra are"""
64+
new_instance = copy.copy(self)
65+
new_instance._spectra = self._spectra.copy()
66+
new_instance.spectrum_indexes_per_inchikey = copy.deepcopy(self.spectrum_indexes_per_inchikey)
67+
return new_instance
68+
69+
70+
class SpectraWithFingerprints(SpectrumSetBase):
71+
"""Stores a spectrum dataset making it easy and fast to split on molecules"""
72+
73+
def __init__(self, spectra: List[Spectrum], fingerprint_type="daylight", nbits=4096):
74+
super().__init__(spectra)
75+
self.fingerprint_type = fingerprint_type
76+
self.nbits = nbits
77+
self.inchikey_fingerprint_pairs: Dict[str, np.array] = {}
78+
# init spectra
79+
self.update_fingerprint_per_inchikey(self.spectrum_indexes_per_inchikey.keys())
80+
81+
def add_spectra(self, new_spectra: "SpectraWithFingerprints"):
82+
updated_inchikeys = super().add_spectra(new_spectra)
83+
if hasattr(new_spectra, "inchikey_fingerprint_pairs"):
84+
if new_spectra.nbits == self.nbits and new_spectra.fingerprint_type == self.fingerprint_type:
85+
if len(self.inchikey_fingerprint_pairs.keys() & new_spectra.inchikey_fingerprint_pairs.keys()) == 0:
86+
self.inchikey_fingerprint_pairs = (
87+
self.inchikey_fingerprint_pairs | new_spectra.inchikey_fingerprint_pairs
88+
)
89+
return
90+
self.update_fingerprint_per_inchikey(updated_inchikeys)
91+
92+
def update_fingerprint_per_inchikey(self, inchikeys_to_update: Iterable[str]):
93+
for inchikey in tqdm(
94+
inchikeys_to_update, desc="Adding fingerprints to Inchikeys", disable=not self.progress_bars
95+
):
96+
spectra = self.spectra_per_inchikey(inchikey)
97+
most_common_inchi = Counter([spectrum.get("inchi") for spectrum in spectra]).most_common(1)[0][0]
98+
fingerprint = _derive_fingerprint_from_inchi(
99+
most_common_inchi, fingerprint_type=self.fingerprint_type, nbits=self.nbits
100+
)
101+
if not isinstance(fingerprint, np.ndarray):
102+
raise ValueError(f"Fingerprint could not be set for InChI: {most_common_inchi}")
103+
self.inchikey_fingerprint_pairs[inchikey] = fingerprint
104+
105+
def copy(self):
106+
"""This copy method ensures all spectra are"""
107+
new_instance = super().copy()
108+
new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs)
109+
return new_instance
110+
111+
112+
class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints):
113+
def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs):
114+
super().__init__(spectra, **kwargs)
115+
self.ms2deepscore_model = ms2deepscore_model
116+
self.embeddings: np.ndarray = compute_embedding_array(self.ms2deepscore_model, spectra)
117+
118+
def add_spectra(self, new_spectra: "SpectraWithMS2DeepScoreEmbeddings"):
119+
super().add_spectra(new_spectra)
120+
if hasattr(new_spectra, "embeddings"):
121+
new_embeddings = new_spectra.embeddings
122+
else:
123+
new_embeddings = compute_embedding_array(self.ms2deepscore_model, new_spectra.spectra)
124+
self.embeddings = np.vstack([self.embeddings, new_embeddings])
125+
126+
def subset_spectra(self, spectrum_indexes) -> "SpectraWithMS2DeepScoreEmbeddings":
127+
"""Returns a new instance of a subset of the spectra"""
128+
new_instance = super().subset_spectra(spectrum_indexes)
129+
new_instance.embeddings = self.embeddings[spectrum_indexes]
130+
return new_instance

0 commit comments

Comments
 (0)