|
| 1 | +import copy |
| 2 | +from collections import Counter |
| 3 | +from typing import List, Dict, Iterable |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from matchms import Spectrum |
| 7 | +from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi |
| 8 | + |
| 9 | +from ms2deepscore.models import compute_embedding_array, SiameseSpectralModel |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | + |
class SpectrumSetBase:
    """Store a spectrum dataset so that it is easy and fast to split on molecules.

    Spectra are kept in insertion order in ``_spectra``.
    ``spectrum_indexes_per_inchikey`` maps the first 14 characters of each
    spectrum's "inchikey" metadata value to the positions of the matching
    spectra in ``_spectra``.
    """

    def __init__(self, spectra: List[Spectrum], progress_bars=False):
        """Create the set and index the given spectra per inchikey.

        Args:
            spectra: Spectra to store. Each one needs an "inchikey" metadata value.
            progress_bars: When True, tqdm progress bars are shown.
        """
        self._spectra = []
        self.spectrum_indexes_per_inchikey = {}
        self.progress_bars = progress_bars
        # init spectra
        self._add_spectra_and_group_per_inchikey(spectra)

    def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]):
        """Append spectra and register their positions under their inchikey.

        The input is fully read and validated before anything is stored, so a
        failing call leaves the set untouched, and passing the set's own
        spectrum list (e.g. ``s.add_spectra(s)``) does not iterate over a list
        that grows while it is being read.

        Args:
            spectra: Spectra to add.

        Returns:
            The set of (14 character) inchikeys that received new spectra.

        Raises:
            ValueError: If ``spectrum.get("inchikey")`` is None for a spectrum.
        """
        starting_index = len(self._spectra)

        # First pass: read and validate only, no state is modified here.
        validated = []
        for spectrum in tqdm(
            spectra, desc="Adding spectra and grouping per Inchikey", disable=not self.progress_bars
        ):
            inchikey = spectrum.get("inchikey")
            if inchikey is None:
                raise ValueError(
                    "Every spectrum needs an inchikey to be added to a spectrum set, "
                    f"but spectrum number {len(validated)} of the added spectra has none."
                )
            validated.append((spectrum, inchikey[:14]))

        # Second pass: commit the validated spectra.
        updated_inchikeys = set()
        for spectrum_index, (spectrum, inchikey) in enumerate(validated, start=starting_index):
            self._spectra.append(spectrum)
            self.spectrum_indexes_per_inchikey.setdefault(inchikey, []).append(spectrum_index)
            updated_inchikeys.add(inchikey)
        return updated_inchikeys

    def add_spectra(self, new_spectra: "SpectrumSetBase"):
        """Add all spectra of another spectrum set to this one.

        Returns:
            The set of inchikeys that received new spectra.
        """
        return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)

    def subset_spectra(self, spectrum_indexes) -> "SpectrumSetBase":
        """Return a new instance holding only the spectra at the given positions.

        The new instance is a shallow copy of this one with a freshly built
        spectrum list and inchikey index; all other attributes are shared.
        """
        new_instance = copy.copy(self)
        new_instance._spectra = []
        new_instance.spectrum_indexes_per_inchikey = {}
        new_instance._add_spectra_and_group_per_inchikey([self._spectra[index] for index in spectrum_indexes])
        return new_instance

    def spectra_per_inchikey(self, inchikey) -> List[Spectrum]:
        """Return all stored spectra registered under the given inchikey.

        Raises:
            KeyError: If no spectrum is registered under ``inchikey``.
        """
        return [self._spectra[index] for index in self.spectrum_indexes_per_inchikey[inchikey]]

    @property
    def spectra(self):
        """The stored spectra in insertion order (the internal list, not a copy)."""
        return self._spectra

    def copy(self):
        """Return a copy whose spectrum list and inchikey index are independent.

        The list of spectra and the inchikey index are duplicated, so adding
        spectra to the copy does not affect this set. The spectrum objects
        themselves are shared between both sets.
        """
        new_instance = copy.copy(self)
        new_instance._spectra = self._spectra.copy()
        new_instance.spectrum_indexes_per_inchikey = copy.deepcopy(self.spectrum_indexes_per_inchikey)
        return new_instance
| 68 | + |
| 69 | + |
class SpectraWithFingerprints(SpectrumSetBase):
    """Spectrum set that additionally stores one fingerprint per inchikey.

    For every inchikey the fingerprint is derived from the most common "inchi"
    metadata value among the spectra registered under that inchikey and stored
    in ``inchikey_fingerprint_pairs``.
    """

    def __init__(self, spectra: List[Spectrum], fingerprint_type="daylight", nbits=4096, progress_bars=False):
        """Index the spectra and compute a fingerprint for every inchikey.

        Args:
            spectra: Spectra to store.
            fingerprint_type: Passed on to the fingerprint derivation.
            nbits: Passed on to the fingerprint derivation.
            progress_bars: When True, tqdm progress bars are shown.
        """
        super().__init__(spectra, progress_bars=progress_bars)
        self.fingerprint_type = fingerprint_type
        self.nbits = nbits
        self.inchikey_fingerprint_pairs: Dict[str, np.ndarray] = {}
        # init fingerprints
        self.update_fingerprint_per_inchikey(self.spectrum_indexes_per_inchikey.keys())

    def add_spectra(self, new_spectra: "SpectraWithFingerprints"):
        """Add the spectra of another set and bring the fingerprints up to date.

        Fingerprints of ``new_spectra`` are reused when it has them, they were
        created with the same settings, and none of its inchikeys already has a
        fingerprint here. Otherwise the fingerprints of all inchikeys that
        received new spectra are recomputed.

        Returns:
            The set of inchikeys that received new spectra.
        """
        updated_inchikeys = super().add_spectra(new_spectra)
        can_reuse_fingerprints = (
            hasattr(new_spectra, "inchikey_fingerprint_pairs")
            and new_spectra.nbits == self.nbits
            and new_spectra.fingerprint_type == self.fingerprint_type
            and not (self.inchikey_fingerprint_pairs.keys() & new_spectra.inchikey_fingerprint_pairs.keys())
        )
        if can_reuse_fingerprints:
            self.inchikey_fingerprint_pairs = self.inchikey_fingerprint_pairs | new_spectra.inchikey_fingerprint_pairs
        else:
            self.update_fingerprint_per_inchikey(updated_inchikeys)
        return updated_inchikeys

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithFingerprints":
        """Return a new instance holding only the spectra at the given positions.

        Only the fingerprints of inchikeys present in the subset are kept, in a
        new dictionary. Without this the subset would share the dictionary of
        this set, including fingerprints of inchikeys it does not contain.
        """
        new_instance = super().subset_spectra(spectrum_indexes)
        new_instance.inchikey_fingerprint_pairs = {
            inchikey: self.inchikey_fingerprint_pairs[inchikey]
            for inchikey in new_instance.spectrum_indexes_per_inchikey
            if inchikey in self.inchikey_fingerprint_pairs
        }
        return new_instance

    def update_fingerprint_per_inchikey(self, inchikeys_to_update: Iterable[str]):
        """(Re)compute the fingerprint of each given inchikey.

        Args:
            inchikeys_to_update: Inchikeys for which spectra are stored in this set.

        Raises:
            ValueError: If the derived fingerprint is not a numpy array.
        """
        for inchikey in tqdm(
            inchikeys_to_update, desc="Adding fingerprints to Inchikeys", disable=not self.progress_bars
        ):
            spectra = self.spectra_per_inchikey(inchikey)
            # Spectra of one inchikey can carry different inchi strings; use the most frequent one.
            most_common_inchi = Counter(spectrum.get("inchi") for spectrum in spectra).most_common(1)[0][0]
            fingerprint = _derive_fingerprint_from_inchi(
                most_common_inchi, fingerprint_type=self.fingerprint_type, nbits=self.nbits
            )
            if not isinstance(fingerprint, np.ndarray):
                raise ValueError(f"Fingerprint could not be set for InChI: {most_common_inchi}")
            self.inchikey_fingerprint_pairs[inchikey] = fingerprint

    def copy(self):
        """Return a copy whose spectrum list, inchikey index and fingerprint dict are independent.

        The spectrum objects and the fingerprint arrays themselves are shared.
        """
        new_instance = super().copy()
        new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs)
        return new_instance
| 110 | + |
| 111 | + |
class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints):
    """Spectrum set that additionally stores one MS2DeepScore embedding per spectrum.

    Row ``i`` of ``embeddings`` belongs to the spectrum at position ``i`` of ``spectra``.
    """

    def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs):
        """Index the spectra, compute fingerprints and compute the embeddings.

        Args:
            spectra: Spectra to store.
            ms2deepscore_model: Model used to compute the embeddings.
            **kwargs: Passed on to ``SpectraWithFingerprints``.
        """
        super().__init__(spectra, **kwargs)
        self.ms2deepscore_model = ms2deepscore_model
        # Use the stored spectra: ``spectra`` may be a one-shot iterable that the parent already consumed.
        self.embeddings: np.ndarray = compute_embedding_array(self.ms2deepscore_model, self.spectra)

    def add_spectra(self, new_spectra: "SpectraWithMS2DeepScoreEmbeddings"):
        """Add the spectra of another set together with their embeddings.

        Embeddings of ``new_spectra`` are reused when it has them, otherwise
        they are computed with the model of this set.

        Returns:
            The value returned by the parent implementation.
        """
        # NOTE(review): reused embeddings are assumed to come from the same model - confirm at call sites.
        if hasattr(new_spectra, "embeddings"):
            new_embeddings = new_spectra.embeddings
        else:
            new_embeddings = compute_embedding_array(self.ms2deepscore_model, new_spectra.spectra)
        # Stack before any state is changed, so a failure here cannot leave
        # the spectra and the embeddings out of sync.
        combined_embeddings = np.vstack([self.embeddings, new_embeddings])
        updated_inchikeys = super().add_spectra(new_spectra)
        self.embeddings = combined_embeddings
        return updated_inchikeys

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithMS2DeepScoreEmbeddings":
        """Return a new instance holding only the spectra at the given positions.

        The matching rows of ``embeddings`` are selected for the new instance.
        """
        # Materialise once: the parent iterates the indexes, and a tuple passed
        # straight to numpy would be read as a multi-dimensional index.
        indexes = list(spectrum_indexes)
        new_instance = super().subset_spectra(indexes)
        new_instance.embeddings = self.embeddings[np.asarray(indexes, dtype=np.intp)]
        return new_instance
0 commit comments