1- import json
21from collections import defaultdict
32from pathlib import Path
4- from typing import Sequence , Tuple
3+ from typing import Sequence
54import numpy as np
65import pandas as pd
7- from matchms import Spectrum
86from matchms .importing import load_spectra
9- from ms2deepscore .models import load_model
7+ from matchms .Spectrum import Spectrum
8+ from ms2deepscore .models import SiameseSpectralModel , load_model
109from ms2deepscore .vector_operations import cosine_similarity_matrix
1110from tqdm import tqdm
1211from ms2query .benchmarking .AnnotatedSpectrumSet import AnnotatedSpectrumSet
13- from ms2query .benchmarking .Embeddings import Embeddings
12+ from ms2query .benchmarking .Embeddings import Embeddings , _to_json_serializable
1413from ms2query .benchmarking .Fingerprints import Fingerprints
1514from ms2query .benchmarking .TopKTanimotoScores import TopKTanimotoScores
1615
1716
17+ class MS2QueryLibrary :
18+ # Set default file names to enable save and load per library
19+ embedding_file_name = "embeddings.npz"
20+ top_k_tanimoto_scores_file_name = "top_k_tanimoto_scores.parquet"
21+ reference_metadata_file_name = "library_metadata.parquet"
22+ ms2deepscore_model_file_name = "ms2deepscore_model.pt"
23+ metadata_to_store = [
24+ "precursor_mz" ,
25+ "retention_time" ,
26+ "collision_energy" ,
27+ "compound_name" ,
28+ "smiles" ,
29+ "inchikey" ,
30+ ]
31+ fingerprint_type = "daylight"
32+ fingerprint_nbits = 4096
33+ top_k_inchikeys = 8
34+
35+ def __init__ (
36+ self ,
37+ ms2deepscore_model : SiameseSpectralModel ,
38+ reference_embeddings : Embeddings ,
39+ top_k_tanimoto_scores : TopKTanimotoScores ,
40+ reference_metadata : pd .DataFrame ,
41+ ):
42+ self .ms2deepscore_model = ms2deepscore_model
43+ self .reference_embeddings = reference_embeddings
44+ self .top_k_tanimoto_scores = top_k_tanimoto_scores
45+ self .reference_metadata = reference_metadata
46+
47+ # Check that the loaded files match
48+ if _to_json_serializable (ms2deepscore_model .model_settings .get_dict ()) != reference_embeddings .model_settings :
49+ raise ValueError (
50+ "The settings of the ms2deepscore model does not match the model used for creating the library embeddings"
51+ )
52+ if list (self .reference_metadata ["spectrum_hashes" ]) != [
53+ str (spectrum_hash ) for spectrum_hash in reference_embeddings .index_to_spectrum_hash
54+ ]:
55+ raise ValueError ("The loaded metadata does not match the used embeddings" )
56+ if {inchikey [:14 ] for inchikey in reference_metadata ["inchikey" ]} != set (
57+ top_k_tanimoto_scores .top_k_inchikeys_and_scores .index
58+ ):
59+ raise ValueError ("The inchikeys in the metadata and in the top_k_tanimoto_scores do not match" )
60+
61+ # Get the spectrum_indices_per_inchikey
62+ self .spectrum_indices_per_inchikey = defaultdict (list )
63+ for lib_spec_index , inchikey in enumerate (reference_metadata ["inchikey" ]):
64+ self .spectrum_indices_per_inchikey [inchikey [:14 ]].append (lib_spec_index )
65+
66+ @classmethod
67+ def load_from_directory (cls , library_file_directory ) -> "MS2QueryLibrary" :
68+ reference_embeddings_file = library_file_directory / cls .embedding_file_name
69+ top_k_tanimoto_scores_file = library_file_directory / cls .top_k_tanimoto_scores_file_name
70+ reference_metadata_file = library_file_directory / cls .reference_metadata_file_name
71+ ms2deepscore_model_file_name = library_file_directory = cls .ms2deepscore_model_file_name
72+ return cls .load_from_files (
73+ ms2deepscore_model_file_name , reference_embeddings_file , top_k_tanimoto_scores_file , reference_metadata_file
74+ )
75+
76+ @classmethod
77+ def load_from_files (
78+ cls ,
79+ ms2deepscore_model_file_name ,
80+ reference_embeddings_file ,
81+ top_k_tanimoto_scores_file ,
82+ reference_metadata_file ,
83+ ) -> "MS2QueryLibrary" :
84+ return cls (
85+ load_model (ms2deepscore_model_file_name ),
86+ Embeddings .load (reference_embeddings_file ),
87+ TopKTanimotoScores .load (top_k_tanimoto_scores_file ),
88+ pd .read_parquet (reference_metadata_file ),
89+ )
90+
91+ @classmethod
92+ def create_from_spectra (
93+ cls ,
94+ library_spectra : Sequence [Spectrum ],
95+ ms2deepscore_model_file_name : str ,
96+ store_file_directory = None ,
97+ store_files = True ,
98+ ) -> "MS2QueryLibrary" :
99+ """Creates all the files needed for MS2Query and stores them"""
100+ if store_file_directory is None :
101+ store_file_directory = Path (ms2deepscore_model_file_name ).parent
102+ if store_files :
103+ # Check the files don't exist yet
104+ for file in (
105+ store_file_directory / cls .embedding_file_name ,
106+ store_file_directory / cls .top_k_tanimoto_scores_file_name ,
107+ store_file_directory / cls .reference_metadata_file_name ,
108+ ):
109+ if file .exists ():
110+ raise FileExistsError (f"There is already a file stored with the name { file } " )
111+
112+ # library_spectra = list(tqdm(load_spectra(library_spectra_file), "Loading library spectra"))
113+ library_spectrum_set = AnnotatedSpectrumSet .create_spectrum_set (library_spectra )
114+ ms2deepscore_model = load_model (ms2deepscore_model_file_name )
115+ library_spectrum_set .add_embeddings (ms2deepscore_model )
116+
117+ fingerprints = Fingerprints .from_spectrum_set (library_spectrum_set , cls .fingerprint_type , cls .fingerprint_nbits )
118+ top_k_tanimoto_scores = TopKTanimotoScores .calculate_from_fingerprints (
119+ fingerprints , fingerprints , cls .top_k_inchikeys
120+ )
121+ reference_metadata = extract_metadata_from_library (
122+ library_spectrum_set ,
123+ cls .metadata_to_store ,
124+ )
125+
126+ if store_files :
127+ reference_metadata .to_parquet (store_file_directory / cls .reference_metadata_file_name )
128+ top_k_tanimoto_scores .save (store_file_directory / cls .top_k_tanimoto_scores_file_name )
129+ library_spectrum_set .embeddings .save (store_file_directory / cls .embedding_file_name )
130+ return cls (ms2deepscore_model , library_spectrum_set .embeddings , top_k_tanimoto_scores , reference_metadata )
131+
132+ def run_ms2query (
133+ self ,
134+ query_spectra : Sequence [Spectrum ],
135+ batch_size : int = 1000 ,
136+ ) -> pd .DataFrame :
137+
138+ query_embeddings = Embeddings .create_from_spectra (query_spectra , self .ms2deepscore_model )
139+
140+ num_of_query_embeddings = query_embeddings .embeddings .shape [0 ]
141+
142+ library_index_highest_ms2deepscore = np .zeros ((num_of_query_embeddings ), dtype = int )
143+ ms2query_scores = []
144+ for start_idx in tqdm (
145+ range (0 , num_of_query_embeddings , batch_size ),
146+ desc = "Predicting highest ms2deepscore per batch of "
147+ + str (min (batch_size , num_of_query_embeddings ))
148+ + " embeddings" ,
149+ ):
150+ # Do MS2DeepScore predictions for batch
151+ end_idx = min (start_idx + batch_size , num_of_query_embeddings )
152+ selected_query_embeddings = query_embeddings .embeddings [start_idx :end_idx ]
153+ score_matrix = cosine_similarity_matrix (selected_query_embeddings , self .reference_embeddings .embeddings )
154+ highest_score_idx = np .argmax (score_matrix , axis = 1 )
155+ library_index_highest_ms2deepscore [start_idx :end_idx ] = highest_score_idx
156+
157+ # get predicted inchikeys
158+ predicted_inchikeys = self .reference_metadata .iloc [highest_score_idx ]["inchikey" ]
159+ # Compute MS2Query reliability score
160+ ms2query_scores .extend (
161+ get_ms2query_reliability_prediction (
162+ predicted_inchikeys , self .spectrum_indices_per_inchikey , self .top_k_tanimoto_scores , score_matrix
163+ )
164+ )
165+
166+ # construct results df
167+ results = self .reference_metadata .iloc [library_index_highest_ms2deepscore ]
168+ results ["ms2query_reliability_prediction" ] = ms2query_scores
169+ return results
170+
171+
18172def run_ms2query_from_files (
19173 query_spectrum_file ,
20174 ms2deepscore_model_file_name ,
21175 reference_embeddings_file ,
22176 top_k_tanimoto_scores_file ,
23177 reference_metadata_file ,
178+ save_file_location ,
24179):
25- reference_embeddings = Embeddings .load (reference_embeddings_file )
26- top_k_tanimoto_scores = TopKTanimotoScores .load (top_k_tanimoto_scores_file )
27- reference_metadata = pd .read_parquet (reference_metadata_file )
28- # Get the spectrum_indices_per_inchikey
29- spectrum_indices_per_inchikey = defaultdict (list )
30- for lib_spec_index , inchikey in enumerate (reference_metadata ["inchikey" ]):
31- spectrum_indices_per_inchikey [inchikey [:14 ]].append (lib_spec_index )
32-
33- query_spectra = list (tqdm (load_spectra (query_spectrum_file ), desc = "loading_in_query_spectra" ))
34- ms2deepscore_model = load_model (ms2deepscore_model_file_name )
35- query_embeddings = Embeddings .create_from_spectra (query_spectra , ms2deepscore_model )
36- run_ms2query (
37- query_embeddings , reference_embeddings , reference_metadata , spectrum_indices_per_inchikey , top_k_tanimoto_scores
180+ ms2query_library = MS2QueryLibrary .load_from_files (
181+ ms2deepscore_model_file_name ,
182+ reference_embeddings_file ,
183+ top_k_tanimoto_scores_file ,
184+ reference_metadata_file ,
38185 )
39186
40-
41- def run_ms2query (
42- query_embeddings : Embeddings ,
43- library_embeddings : Embeddings ,
44- library_metadata : pd .DataFrame ,
45- spectrum_indices_per_inchikey : defaultdict [str , list [int ]],
46- top_k_tanimoto_scores : TopKTanimotoScores ,
47- batch_size : int = 1000 ,
48- ):
49- num_of_query_embeddings = query_embeddings .embeddings .shape [0 ]
50-
51- library_index_highest_ms2deepscore = np .zeros ((num_of_query_embeddings ), dtype = int )
52- ms2query_scores = []
53- for start_idx in tqdm (
54- range (0 , num_of_query_embeddings , batch_size ),
55- desc = "Predicting highest ms2deepscore per batch of "
56- + str (min (batch_size , num_of_query_embeddings ))
57- + " embeddings" ,
58- ):
59- # Do MS2DeepScore predictions for batch
60- end_idx = min (start_idx + batch_size , num_of_query_embeddings )
61- selected_query_embeddings = query_embeddings .embeddings [start_idx :end_idx ]
62- score_matrix = cosine_similarity_matrix (selected_query_embeddings , library_embeddings .embeddings )
63- highest_score_idx = np .argmax (score_matrix , axis = 1 )
64- library_index_highest_ms2deepscore [start_idx :end_idx ] = highest_score_idx
65-
66- # get predicted inchikeys
67- predicted_inchikeys = library_metadata .iloc [highest_score_idx ]["inchikey" ]
68- # Compute MS2Query reliability score
69- ms2query_scores .extend (
70- get_ms2query_reliability_prediction (
71- predicted_inchikeys , spectrum_indices_per_inchikey , top_k_tanimoto_scores , score_matrix
72- )
73- )
74-
75- # construct results df
76- results = library_metadata .iloc [library_index_highest_ms2deepscore ]
77- results ["ms2query_reliability_prediction" ] = ms2query_scores
78- return results
187+ query_spectra = list (tqdm (load_spectra (query_spectrum_file ), desc = "loading_in_query_spectra" ))
188+ results_df = ms2query_library .run_ms2query (query_spectra )
189+ results_df .to_csv (save_file_location )
79190
80191
81192def get_ms2query_reliability_prediction (
@@ -97,52 +208,11 @@ def get_ms2query_reliability_prediction(
97208 return ms2query_scores
98209
99210
100- def create_ms2query_library (library_spectra_file : str , ms2deepscore_model_file_name : str ):
101- """Loads in a library and saves the embeddings and top_k_tanimoto_scores"""
102- spectrum_file_directory = Path ("/some/dir/file.txt" ).parent
103- embedding_file_location = spectrum_file_directory / "embeddings.npz"
104- top_k_tanimoto_score_file_location = spectrum_file_directory / "top_k_tanimoto_scores.parquet"
105- reference_metadata_file = spectrum_file_directory / "library_metadata.parquet"
106- if embedding_file_location .exists ():
107- raise FileExistsError ("There is already an embedding.npy file in the directory of your library spectra" )
108- if top_k_tanimoto_score_file_location .exists ():
109- raise FileExistsError (
110- "There is already an top_k_tanimoto_scores.parquet file in the directory of your library spectra"
111- )
112-
113- library_spectra = list (tqdm (load_spectra (library_spectra_file ), "Loading library spectra" ))
114- library_spectra = AnnotatedSpectrumSet .create_spectrum_set (library_spectra )
115- ms2deepscore_model = load_model (ms2deepscore_model_file_name )
116- library_spectra .add_embeddings (ms2deepscore_model )
117-
118- library_spectra ._embeddings .save (embedding_file_location )
119-
120- fingerprints = Fingerprints .from_spectrum_set (library_spectra , "daylight" , 4096 )
121- top_k_tanimoto_scores = TopKTanimotoScores .calculate_from_fingerprints (
122- fingerprints ,
123- fingerprints ,
124- k = 8 ,
125- )
126- top_k_tanimoto_scores .save (top_k_tanimoto_score_file_location )
127- reference_metadata = extract_metadata_from_library (
128- library_spectra ,
129- [
130- "precursor_mz" ,
131- "retention_time" ,
132- "collision_energy" ,
133- "compound_name" ,
134- "smiles" ,
135- "inchikey" ,
136- ],
137- )
138- reference_metadata .to_parquet (reference_metadata_file )
139-
140-
141211def extract_metadata_from_library (spectra : AnnotatedSpectrumSet , metadata_to_collect : list ):
142212 collected_metadata = {key : [] for key in metadata_to_collect }
143213 collected_metadata ["spectrum_hashes" ] = []
144214 for spectrum in tqdm (spectra .spectra , desc = "Extracting metadata df from spectra" ):
145215 for metadata_key in metadata_to_collect :
146216 collected_metadata [metadata_key ].append (spectrum .get (metadata_key ))
147- collected_metadata ["spectrum_hashes" ].append (spectrum .__hash__ ())
217+ collected_metadata ["spectrum_hashes" ].append (str ( spectrum .__hash__ () ))
148218 return pd .DataFrame (collected_metadata )
0 commit comments