|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | 3 | from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix |
| 4 | + |
| 5 | +from ms2query.benchmarking.Fingerprints import Fingerprints |
4 | 6 | from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match |
5 | 7 | from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine |
6 | 8 | from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore |
|
11 | 13 | from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet |
12 | 14 | from tests.conftest import create_test_spectra, ms2deepscore_model |
13 | 15 |
|
def get_library_and_test_spectra():
    """Return a ``(library, test)`` pair of spectrum sets with embeddings added.

    The library set is built from ``create_test_spectra()`` and the test set
    from ``create_test_spectra(1)``; what the argument selects is defined in
    ``tests.conftest``. Both sets get embeddings from the same model
    returned by ``ms2deepscore_model()``.
    """
    embedding_model = ms2deepscore_model()
    library_set = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra())
    query_set = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(1))
    # Library first, then test set, using one shared model instance.
    for spectrum_set in (library_set, query_set):
        spectrum_set.add_embeddings(embedding_model)
    return library_set, query_set
14 | 23 |
|
@pytest.mark.parametrize(
    "prediction_function",
    [predict_highest_cosine, predict_highest_ms2deepscore],
)
def test_all_methods(prediction_function):
    """Check each reference method against the shared library/test fixture.

    For every test spectrum the prediction at the same position must equal
    the first 14 characters of that spectrum's inchikey, and its score must
    be 1.0 within an absolute tolerance of 1e-5.
    """
    library_set, query_set = get_library_and_test_spectra()
    predicted_inchikeys, scores = prediction_function(library_set, query_set)
    for position, query_spectrum in enumerate(query_set.spectra):
        expected_inchikey = query_spectrum.get("inchikey")[:14]
        assert predicted_inchikeys[position] == expected_inchikey
        assert np.allclose(scores[position], np.array(1.0), atol=1e-5)
35 | 38 |
|
36 | 39 |
|
def test_predict_best_possible_match():
    """Check ``predict_best_possible_match`` on the shared library/test fixture.

    Fingerprints are built from the library and test sets combined, using
    the "daylight" type and 2048 as the third argument. Each test spectrum
    must be predicted as the first 14 characters of its own inchikey with a
    score of 1.0 within an absolute tolerance of 1e-5.
    """
    library_set, query_set = get_library_and_test_spectra()
    combined_set = library_set + query_set
    fingerprints = Fingerprints.from_spectrum_set(combined_set, "daylight", 2048)
    predicted_inchikeys, scores = predict_best_possible_match(library_set, query_set, fingerprints)
    for position, query_spectrum in enumerate(query_set.spectra):
        expected_inchikey = query_spectrum.get("inchikey")[:14]
        assert predicted_inchikeys[position] == expected_inchikey
        assert np.allclose(scores[position], np.array(1.0), atol=1e-5)
| 48 | + |
def test_predict_with_integrated_similarity_flow():
    """Check ``predict_with_integrated_similarity_flow`` against fixed outputs.

    Fingerprints are built from the library set only, using the "daylight"
    type and 4096 as the third argument. The predicted inchikeys and scores
    are compared with hard-coded values (presumably regression values
    recorded from an earlier run; confirm before changing them).
    """
    library_set, query_set = get_library_and_test_spectra()
    fingerprints = Fingerprints.from_spectrum_set(library_set, "daylight", 4096)
    predicted_inchikeys, scores = predict_with_integrated_similarity_flow(library_set, query_set, fingerprints)

    expected_inchikeys = ["RYYVLZVUVIJVGH", "ZPUCINDJVBIVPJ", "ZPUCINDJVBIVPJ"]
    expected_scores = np.array([0.38829751082577607, 0.3919729335980483, 0.38774130710967564])

    assert predicted_inchikeys == expected_inchikeys
    assert np.allclose(expected_scores, np.array(scores))
|
0 commit comments