|
1 | 1 | import os |
| 2 | + |
| 3 | +import numpy as np |
2 | 4 | import pytest |
3 | 5 | import sys |
4 | 6 | import pandas as pd |
5 | 7 | from ms2query.create_new_library.train_ms2query_model import \ |
6 | | - DataCollectorForTraining, calculate_tanimoto_scores_with_library, train_random_forest, train_ms2query_model |
7 | | -from ms2query.utils import load_pickled_file, load_matchms_spectrum_objects_from_file, convert_to_onnx_model |
| 8 | + DataCollectorForTraining, calculate_tanimoto_scores_with_library, train_random_forest, train_ms2query_model, \ |
| 9 | + convert_to_onnx_model |
| 10 | +from ms2query.utils import load_pickled_file, load_matchms_spectrum_objects_from_file, load_ms2query_model, \ |
| 11 | + predict_onnx_model |
8 | 12 | from onnxruntime import InferenceSession |
9 | 13 | from ms2query.utils import predict_onnx_model |
10 | 14 | from ms2query.ms2library import MS2Library |
@@ -76,15 +80,19 @@ def test_calculate_all_tanimoto_scores(tmp_path, ms2library, query_spectra): |
76 | 80 | pd.testing.assert_frame_equal(result, expected_result, check_dtype=False) |
77 | 81 |
|
78 | 82 |
|
79 | | -def test_train_random_forest(): |
| 83 | +def test_train_and_save_random_forest(): |
80 | 84 | training_scores, training_labels = load_pickled_file(os.path.join( |
81 | 85 | os.path.split(os.path.dirname(__file__))[0], |
82 | 86 | "tests/test_files/test_files_train_ms2query_nn", |
83 | 87 | "expected_train_and_val_data.pickle"))[:2] |
84 | 88 | ms2query_model = train_random_forest(training_scores, training_labels) |
85 | 89 | onnx_model = convert_to_onnx_model(ms2query_model) |
86 | 90 | onnx_model_session = InferenceSession(onnx_model.SerializeToString()) |
87 | | - predictions = predict_onnx_model(onnx_model_session, training_scores.values) |
| 91 | + predictions_onnx_model = predict_onnx_model(onnx_model_session, training_scores.values) |
| 92 | + |
| 93 | + # check if saving onnx model works |
| 94 | + predictions_sklearn_model = ms2query_model.predict(training_scores.values.astype(np.float32)) |
| 95 | + assert np.allclose(predictions_onnx_model, predictions_sklearn_model) |
88 | 96 |
|
89 | 97 |
|
90 | 98 | @pytest.mark.integration |
|
0 commit comments