Skip to content

Commit 0cb5a53

Browse files
committed
Convert run_ms2query to class
1 parent be5e756 commit 0cb5a53

File tree

2 files changed

+208
-110
lines changed

2 files changed

+208
-110
lines changed

ms2query/run_ms2query.py

Lines changed: 169 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,192 @@
1-
import json
21
from collections import defaultdict
32
from pathlib import Path
4-
from typing import Sequence, Tuple
3+
from typing import Sequence
54
import numpy as np
65
import pandas as pd
7-
from matchms import Spectrum
86
from matchms.importing import load_spectra
9-
from ms2deepscore.models import load_model
7+
from matchms.Spectrum import Spectrum
8+
from ms2deepscore.models import SiameseSpectralModel, load_model
109
from ms2deepscore.vector_operations import cosine_similarity_matrix
1110
from tqdm import tqdm
1211
from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet
13-
from ms2query.benchmarking.Embeddings import Embeddings
12+
from ms2query.benchmarking.Embeddings import Embeddings, _to_json_serializable
1413
from ms2query.benchmarking.Fingerprints import Fingerprints
1514
from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores
1615

1716

17+
class MS2QueryLibrary:
18+
# Set default file names to enable save and load per library
19+
embedding_file_name = "embeddings.npz"
20+
top_k_tanimoto_scores_file_name = "top_k_tanimoto_scores.parquet"
21+
reference_metadata_file_name = "library_metadata.parquet"
22+
ms2deepscore_model_file_name = "ms2deepscore_model.pt"
23+
metadata_to_store = [
24+
"precursor_mz",
25+
"retention_time",
26+
"collision_energy",
27+
"compound_name",
28+
"smiles",
29+
"inchikey",
30+
]
31+
fingerprint_type = "daylight"
32+
fingerprint_nbits = 4096
33+
top_k_inchikeys = 8
34+
35+
def __init__(
36+
self,
37+
ms2deepscore_model: SiameseSpectralModel,
38+
reference_embeddings: Embeddings,
39+
top_k_tanimoto_scores: TopKTanimotoScores,
40+
reference_metadata: pd.DataFrame,
41+
):
42+
self.ms2deepscore_model = ms2deepscore_model
43+
self.reference_embeddings = reference_embeddings
44+
self.top_k_tanimoto_scores = top_k_tanimoto_scores
45+
self.reference_metadata = reference_metadata
46+
47+
# Check that the loaded files match
48+
if _to_json_serializable(ms2deepscore_model.model_settings.get_dict()) != reference_embeddings.model_settings:
49+
raise ValueError(
50+
"The settings of the ms2deepscore model does not match the model used for creating the library embeddings"
51+
)
52+
if list(self.reference_metadata["spectrum_hashes"]) != [
53+
str(spectrum_hash) for spectrum_hash in reference_embeddings.index_to_spectrum_hash
54+
]:
55+
raise ValueError("The loaded metadata does not match the used embeddings")
56+
if {inchikey[:14] for inchikey in reference_metadata["inchikey"]} != set(
57+
top_k_tanimoto_scores.top_k_inchikeys_and_scores.index
58+
):
59+
raise ValueError("The inchikeys in the metadata and in the top_k_tanimoto_scores do not match")
60+
61+
# Get the spectrum_indices_per_inchikey
62+
self.spectrum_indices_per_inchikey = defaultdict(list)
63+
for lib_spec_index, inchikey in enumerate(reference_metadata["inchikey"]):
64+
self.spectrum_indices_per_inchikey[inchikey[:14]].append(lib_spec_index)
65+
66+
@classmethod
67+
def load_from_directory(cls, library_file_directory) -> "MS2QueryLibrary":
68+
reference_embeddings_file = library_file_directory / cls.embedding_file_name
69+
top_k_tanimoto_scores_file = library_file_directory / cls.top_k_tanimoto_scores_file_name
70+
reference_metadata_file = library_file_directory / cls.reference_metadata_file_name
71+
ms2deepscore_model_file_name = library_file_directory = cls.ms2deepscore_model_file_name
72+
return cls.load_from_files(
73+
ms2deepscore_model_file_name, reference_embeddings_file, top_k_tanimoto_scores_file, reference_metadata_file
74+
)
75+
76+
@classmethod
77+
def load_from_files(
78+
cls,
79+
ms2deepscore_model_file_name,
80+
reference_embeddings_file,
81+
top_k_tanimoto_scores_file,
82+
reference_metadata_file,
83+
) -> "MS2QueryLibrary":
84+
return cls(
85+
load_model(ms2deepscore_model_file_name),
86+
Embeddings.load(reference_embeddings_file),
87+
TopKTanimotoScores.load(top_k_tanimoto_scores_file),
88+
pd.read_parquet(reference_metadata_file),
89+
)
90+
91+
@classmethod
92+
def create_from_spectra(
93+
cls,
94+
library_spectra: Sequence[Spectrum],
95+
ms2deepscore_model_file_name: str,
96+
store_file_directory=None,
97+
store_files=True,
98+
) -> "MS2QueryLibrary":
99+
"""Creates all the files needed for MS2Query and stores them"""
100+
if store_file_directory is None:
101+
store_file_directory = Path(ms2deepscore_model_file_name).parent
102+
if store_files:
103+
# Check the files don't exist yet
104+
for file in (
105+
store_file_directory / cls.embedding_file_name,
106+
store_file_directory / cls.top_k_tanimoto_scores_file_name,
107+
store_file_directory / cls.reference_metadata_file_name,
108+
):
109+
if file.exists():
110+
raise FileExistsError(f"There is already a file stored with the name {file}")
111+
112+
# library_spectra = list(tqdm(load_spectra(library_spectra_file), "Loading library spectra"))
113+
library_spectrum_set = AnnotatedSpectrumSet.create_spectrum_set(library_spectra)
114+
ms2deepscore_model = load_model(ms2deepscore_model_file_name)
115+
library_spectrum_set.add_embeddings(ms2deepscore_model)
116+
117+
fingerprints = Fingerprints.from_spectrum_set(library_spectrum_set, cls.fingerprint_type, cls.fingerprint_nbits)
118+
top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(
119+
fingerprints, fingerprints, cls.top_k_inchikeys
120+
)
121+
reference_metadata = extract_metadata_from_library(
122+
library_spectrum_set,
123+
cls.metadata_to_store,
124+
)
125+
126+
if store_files:
127+
reference_metadata.to_parquet(store_file_directory / cls.reference_metadata_file_name)
128+
top_k_tanimoto_scores.save(store_file_directory / cls.top_k_tanimoto_scores_file_name)
129+
library_spectrum_set.embeddings.save(store_file_directory / cls.embedding_file_name)
130+
return cls(ms2deepscore_model, library_spectrum_set.embeddings, top_k_tanimoto_scores, reference_metadata)
131+
132+
def run_ms2query(
133+
self,
134+
query_spectra: Sequence[Spectrum],
135+
batch_size: int = 1000,
136+
) -> pd.DataFrame:
137+
138+
query_embeddings = Embeddings.create_from_spectra(query_spectra, self.ms2deepscore_model)
139+
140+
num_of_query_embeddings = query_embeddings.embeddings.shape[0]
141+
142+
library_index_highest_ms2deepscore = np.zeros((num_of_query_embeddings), dtype=int)
143+
ms2query_scores = []
144+
for start_idx in tqdm(
145+
range(0, num_of_query_embeddings, batch_size),
146+
desc="Predicting highest ms2deepscore per batch of "
147+
+ str(min(batch_size, num_of_query_embeddings))
148+
+ " embeddings",
149+
):
150+
# Do MS2DeepScore predictions for batch
151+
end_idx = min(start_idx + batch_size, num_of_query_embeddings)
152+
selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx]
153+
score_matrix = cosine_similarity_matrix(selected_query_embeddings, self.reference_embeddings.embeddings)
154+
highest_score_idx = np.argmax(score_matrix, axis=1)
155+
library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx
156+
157+
# get predicted inchikeys
158+
predicted_inchikeys = self.reference_metadata.iloc[highest_score_idx]["inchikey"]
159+
# Compute MS2Query reliability score
160+
ms2query_scores.extend(
161+
get_ms2query_reliability_prediction(
162+
predicted_inchikeys, self.spectrum_indices_per_inchikey, self.top_k_tanimoto_scores, score_matrix
163+
)
164+
)
165+
166+
# construct results df
167+
results = self.reference_metadata.iloc[library_index_highest_ms2deepscore]
168+
results["ms2query_reliability_prediction"] = ms2query_scores
169+
return results
170+
171+
18172
def run_ms2query_from_files(
19173
query_spectrum_file,
20174
ms2deepscore_model_file_name,
21175
reference_embeddings_file,
22176
top_k_tanimoto_scores_file,
23177
reference_metadata_file,
178+
save_file_location,
24179
):
25-
reference_embeddings = Embeddings.load(reference_embeddings_file)
26-
top_k_tanimoto_scores = TopKTanimotoScores.load(top_k_tanimoto_scores_file)
27-
reference_metadata = pd.read_parquet(reference_metadata_file)
28-
# Get the spectrum_indices_per_inchikey
29-
spectrum_indices_per_inchikey = defaultdict(list)
30-
for lib_spec_index, inchikey in enumerate(reference_metadata["inchikey"]):
31-
spectrum_indices_per_inchikey[inchikey[:14]].append(lib_spec_index)
32-
33-
query_spectra = list(tqdm(load_spectra(query_spectrum_file), desc="loading_in_query_spectra"))
34-
ms2deepscore_model = load_model(ms2deepscore_model_file_name)
35-
query_embeddings = Embeddings.create_from_spectra(query_spectra, ms2deepscore_model)
36-
run_ms2query(
37-
query_embeddings, reference_embeddings, reference_metadata, spectrum_indices_per_inchikey, top_k_tanimoto_scores
180+
ms2query_library = MS2QueryLibrary.load_from_files(
181+
ms2deepscore_model_file_name,
182+
reference_embeddings_file,
183+
top_k_tanimoto_scores_file,
184+
reference_metadata_file,
38185
)
39186

40-
41-
def run_ms2query(
42-
query_embeddings: Embeddings,
43-
library_embeddings: Embeddings,
44-
library_metadata: pd.DataFrame,
45-
spectrum_indices_per_inchikey: defaultdict[str, list[int]],
46-
top_k_tanimoto_scores: TopKTanimotoScores,
47-
batch_size: int = 1000,
48-
):
49-
num_of_query_embeddings = query_embeddings.embeddings.shape[0]
50-
51-
library_index_highest_ms2deepscore = np.zeros((num_of_query_embeddings), dtype=int)
52-
ms2query_scores = []
53-
for start_idx in tqdm(
54-
range(0, num_of_query_embeddings, batch_size),
55-
desc="Predicting highest ms2deepscore per batch of "
56-
+ str(min(batch_size, num_of_query_embeddings))
57-
+ " embeddings",
58-
):
59-
# Do MS2DeepScore predictions for batch
60-
end_idx = min(start_idx + batch_size, num_of_query_embeddings)
61-
selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx]
62-
score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings.embeddings)
63-
highest_score_idx = np.argmax(score_matrix, axis=1)
64-
library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx
65-
66-
# get predicted inchikeys
67-
predicted_inchikeys = library_metadata.iloc[highest_score_idx]["inchikey"]
68-
# Compute MS2Query reliability score
69-
ms2query_scores.extend(
70-
get_ms2query_reliability_prediction(
71-
predicted_inchikeys, spectrum_indices_per_inchikey, top_k_tanimoto_scores, score_matrix
72-
)
73-
)
74-
75-
# construct results df
76-
results = library_metadata.iloc[library_index_highest_ms2deepscore]
77-
results["ms2query_reliability_prediction"] = ms2query_scores
78-
return results
187+
query_spectra = list(tqdm(load_spectra(query_spectrum_file), desc="loading_in_query_spectra"))
188+
results_df = ms2query_library.run_ms2query(query_spectra)
189+
results_df.to_csv(save_file_location)
79190

80191

81192
def get_ms2query_reliability_prediction(
@@ -97,52 +208,11 @@ def get_ms2query_reliability_prediction(
97208
return ms2query_scores
98209

99210

100-
def create_ms2query_library(library_spectra_file: str, ms2deepscore_model_file_name: str):
101-
"""Loads in a library and saves the embeddings and top_k_tanimoto_scores"""
102-
spectrum_file_directory = Path("/some/dir/file.txt").parent
103-
embedding_file_location = spectrum_file_directory / "embeddings.npz"
104-
top_k_tanimoto_score_file_location = spectrum_file_directory / "top_k_tanimoto_scores.parquet"
105-
reference_metadata_file = spectrum_file_directory / "library_metadata.parquet"
106-
if embedding_file_location.exists():
107-
raise FileExistsError("There is already an embedding.npy file in the directory of your library spectra")
108-
if top_k_tanimoto_score_file_location.exists():
109-
raise FileExistsError(
110-
"There is already an top_k_tanimoto_scores.parquet file in the directory of your library spectra"
111-
)
112-
113-
library_spectra = list(tqdm(load_spectra(library_spectra_file), "Loading library spectra"))
114-
library_spectra = AnnotatedSpectrumSet.create_spectrum_set(library_spectra)
115-
ms2deepscore_model = load_model(ms2deepscore_model_file_name)
116-
library_spectra.add_embeddings(ms2deepscore_model)
117-
118-
library_spectra._embeddings.save(embedding_file_location)
119-
120-
fingerprints = Fingerprints.from_spectrum_set(library_spectra, "daylight", 4096)
121-
top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(
122-
fingerprints,
123-
fingerprints,
124-
k=8,
125-
)
126-
top_k_tanimoto_scores.save(top_k_tanimoto_score_file_location)
127-
reference_metadata = extract_metadata_from_library(
128-
library_spectra,
129-
[
130-
"precursor_mz",
131-
"retention_time",
132-
"collision_energy",
133-
"compound_name",
134-
"smiles",
135-
"inchikey",
136-
],
137-
)
138-
reference_metadata.to_parquet(reference_metadata_file)
139-
140-
141211
def extract_metadata_from_library(spectra: AnnotatedSpectrumSet, metadata_to_collect: list):
142212
collected_metadata = {key: [] for key in metadata_to_collect}
143213
collected_metadata["spectrum_hashes"] = []
144214
for spectrum in tqdm(spectra.spectra, desc="Extracting metadata df from spectra"):
145215
for metadata_key in metadata_to_collect:
146216
collected_metadata[metadata_key].append(spectrum.get(metadata_key))
147-
collected_metadata["spectrum_hashes"].append(spectrum.__hash__())
217+
collected_metadata["spectrum_hashes"].append(str(spectrum.__hash__()))
148218
return pd.DataFrame(collected_metadata)

tests/test_run_ms2query.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
import os
2+
3+
import pandas as pd
4+
15
from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet
26
from ms2query.benchmarking.Fingerprints import Fingerprints
37
from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores
4-
from ms2query.run_ms2query import extract_metadata_from_library, run_ms2query
5-
from tests.helper_functions import create_test_spectra, ms2deepscore_model
8+
from ms2query.run_ms2query import extract_metadata_from_library, run_ms2query_from_files, MS2QueryLibrary
9+
from tests.helper_functions import TEST_RESOURCES_PATH, create_test_spectra, ms2deepscore_model
10+
from matchms.exporting import save_as_mgf
611

712

813
def test_run_ms2query():
914
model = ms2deepscore_model()
1015
library_spectra = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(nr_of_inchikeys=7))
11-
test_spectra = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(1, nr_of_inchikeys=3))
16+
test_spectra = create_test_spectra(1, nr_of_inchikeys=3)
1217
library_spectra.add_embeddings(model)
13-
test_spectra.add_embeddings(model)
1418
fingerprints = Fingerprints.from_spectrum_set(library_spectra, "daylight", 100)
1519
top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(fingerprints, fingerprints, 3)
16-
spectrum_indices_per_inchikey = library_spectra.spectrum_indices_per_inchikey
1720
metadata_library = extract_metadata_from_library(
1821
library_spectra,
1922
[
@@ -24,11 +27,36 @@ def test_run_ms2query():
2427
"inchikey",
2528
],
2629
)
27-
results = run_ms2query(
28-
test_spectra.embeddings,
29-
library_spectra.embeddings,
30-
metadata_library,
31-
spectrum_indices_per_inchikey,
32-
top_k_tanimoto_scores,
30+
31+
results = MS2QueryLibrary(model, library_spectra.embeddings, top_k_tanimoto_scores, metadata_library).run_ms2query(
32+
test_spectra
3333
)
3434
print(results)
35+
36+
37+
def test_create_library(tmp_path):
38+
lib_spectra = create_test_spectra(nr_of_inchikeys=10, number_of_spectra_per_inchikey=3)
39+
# save_as_mgf(lib_spectra, os.path.join(tmp_path, "library_spectra.mgf"))
40+
ms2deepscore_model_file = os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt")
41+
MS2QueryLibrary.create_from_spectra(lib_spectra, ms2deepscore_model_file, tmp_path)
42+
assert (tmp_path / MS2QueryLibrary.embedding_file_name).exists()
43+
assert (tmp_path / MS2QueryLibrary.top_k_tanimoto_scores_file_name).exists()
44+
assert (tmp_path / MS2QueryLibrary.reference_metadata_file_name).exists()
45+
46+
47+
def test_create_and_use_library(tmp_path):
48+
lib_spectra = create_test_spectra(nr_of_inchikeys=10, number_of_spectra_per_inchikey=3)
49+
ms2deepscore_model_file = os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt")
50+
ms2query_library = MS2QueryLibrary.create_from_spectra(lib_spectra, ms2deepscore_model_file, tmp_path)
51+
test_spectra = create_test_spectra(1, nr_of_inchikeys=3)
52+
results = ms2query_library.run_ms2query(test_spectra)
53+
54+
ms2query_library_2 = MS2QueryLibrary.load_from_files(
55+
ms2deepscore_model_file,
56+
tmp_path / MS2QueryLibrary.embedding_file_name,
57+
tmp_path / MS2QueryLibrary.top_k_tanimoto_scores_file_name,
58+
tmp_path / MS2QueryLibrary.reference_metadata_file_name,
59+
)
60+
61+
results_2 = ms2query_library_2.run_ms2query(test_spectra)
62+
pd.testing.assert_frame_equal(results, results_2)

0 commit comments

Comments
 (0)