Skip to content

Commit fb0107d

Browse files
committed
Add save and load options for top_k_tanimoto_scores
1 parent 29cc1e1 commit fb0107d

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

ms2query/benchmarking/TopKTanimotoScores.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pathlib import Path
12
import numpy as np
23
import pandas as pd
34
from ms2query.benchmarking.Fingerprints import Fingerprints
@@ -27,13 +28,21 @@ def _create_multi_index(
2728
combined_data = np.empty((len(inchikey_indexes), self.k * 2), dtype=object)
2829
combined_data[:, 0::2] = top_k_inchikeys
2930
combined_data[:, 1::2] = tanimoto_scores_for_top_k
30-
return pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns)
31+
df = pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns)
32+
33+
# Cast score columns to float64
34+
score_cols = [(rank, "score") for rank in [f"Rank_{i + 1}" for i in range(self.k)]]
35+
df[score_cols] = df[score_cols].astype(float)
36+
37+
return df
3138

3239
@classmethod
3340
def calculate_from_fingerprints(cls, query_fingerprints: Fingerprints, target_fingerprints: Fingerprints, k):
3441
"""
3542
Gets the top k highest inchikeys and scores for each inchikey in query_fingerprints from target_fingerprints
3643
"""
44+
if target_fingerprints.fingerprints.shape[0] < k:
45+
raise ValueError("K cannot be larger than the number of fingerprints")
3746
similarity_scores = generalized_tanimoto_similarity_matrix(
3847
query_fingerprints.fingerprints, target_fingerprints.fingerprints
3948
)
@@ -67,3 +76,30 @@ def get_all_average_tanimoto_scores(self) -> dict[str, float]:
6776

6877
average_per_inchikey_df = scores_df.mean(axis=1)
6978
return average_per_inchikey_df.to_dict()
79+
80+
def save(self, path: str | Path) -> None:
    """Write the top-k inchikeys-and-scores table to a parquet file.

    The destination is ``path`` with its suffix set to ``.parquet``
    (``Path.with_suffix`` semantics). Missing parent directories are
    created before writing.

    Args:
        path: File path without extension, e.g. "/data/top_k_scores".
    """
    parquet_path = Path(path).with_suffix(".parquet")
    # Ensure the destination folder exists; an existing folder is not an error.
    parquet_path.parent.mkdir(parents=True, exist_ok=True)
    self.top_k_inchikeys_and_scores.to_parquet(parquet_path)
88+
89+
@classmethod
def load(cls, path: str | Path) -> "TopKTanimotoScores":
    """Rebuild a TopKTanimotoScores from a parquet file on disk.

    The file read is ``path`` with its suffix set to ``.parquet``.

    Args:
        path: File path without extension, e.g. "/data/top_k_scores".

    Returns:
        An instance whose ``top_k_inchikeys_and_scores`` is the stored table
        and whose ``k`` is the number of distinct ``result_rank`` labels.
    """
    parquet_path = Path(path).with_suffix(".parquet")
    table = pd.read_parquet(parquet_path)
    # Set the column level names that the rank lookup below relies on.
    table.columns.names = ["result_rank", "attribute"]

    # Create the object without running __init__; the stored frame is attached as-is.
    restored = cls.__new__(cls)
    rank_labels = table.columns.get_level_values("result_rank").unique()
    restored.k = len(rank_labels)
    restored.top_k_inchikeys_and_scores = table
    return restored

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ ms2deepscore= ">=2.6.0"
3232
rdkit= ">2024.3.4"
3333
nmslib= ">=2.0.0"
3434
umap-learn= ">=0.5.7"
35+
pyarrow= ">=14.0.1"
3536

3637
[tool.poetry.group.dev.dependencies]
3738
decorator = "^5.1.1"

tests/test_benchmarking/test_top_k_tanimoto_scores.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23
import pytest
34
from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores
45
from tests.helper_functions import make_test_fingerprints
@@ -26,3 +27,54 @@ def test_calculate_from_fingerprints():
2627
"AAAAAAAAAAAAAD": 0.75,
2728
"AAAAAAAAAAAAAE": 1.0,
2829
}
30+
31+
32+
@pytest.fixture
def sample_scores():
    """Build a small TopKTanimotoScores: three queries with three ranked hits each."""
    query_ids = np.array(["QUERY_1", "QUERY_2", "QUERY_3"])
    ranked_inchikeys = np.array(
        [
            ["INCHI_A", "INCHI_B", "INCHI_C"],
            ["INCHI_B", "INCHI_C", "INCHI_A"],
            ["INCHI_C", "INCHI_A", "INCHI_B"],
        ]
    )
    ranked_scores = np.array(
        [
            [0.9, 0.7, 0.5],
            [0.8, 0.6, 0.4],
            [0.95, 0.85, 0.75],
        ]
    )
    return TopKTanimotoScores(ranked_scores, ranked_inchikeys, query_ids)
51+
52+
53+
# ----- save and load tests -----
54+
def test_save_creates_parquet_file(sample_scores, tmp_path):
    """Saving to an extension-less path produces a ``.parquet`` file beside it."""
    expected_file = tmp_path / "test_scores.parquet"
    sample_scores.save(tmp_path / "test_scores")
    assert expected_file.exists()
57+
58+
59+
def test_save_creates_parent_directories(sample_scores, tmp_path):
    """Saving into a not-yet-existing nested folder still writes the file."""
    nested_dir = tmp_path / "nested" / "dir"
    sample_scores.save(nested_dir / "test_scores")
    assert (nested_dir / "test_scores.parquet").exists()
62+
63+
64+
def test_roundtrip_produces_identical_object(sample_scores, tmp_path):
    """A save followed by a load yields the same k, table, and query results."""
    base_path = tmp_path / "test_scores"
    sample_scores.save(base_path)
    loaded = TopKTanimotoScores.load(base_path)

    # Stored state must match.
    assert loaded.k == sample_scores.k
    pd.testing.assert_frame_equal(loaded.top_k_inchikeys_and_scores, sample_scores.top_k_inchikeys_and_scores)

    # The selection helpers must agree between the original and the reloaded object.
    original_pairs = sample_scores.select_top_k_inchikeys_and_scores("QUERY_1")
    assert original_pairs == loaded.select_top_k_inchikeys_and_scores("QUERY_1")

    original_inchikeys = sample_scores.select_top_k_inchikeys("QUERY_2")
    assert original_inchikeys == loaded.select_top_k_inchikeys("QUERY_2")

    original_average = sample_scores.select_average_score("QUERY_3")
    assert original_average == pytest.approx(loaded.select_average_score("QUERY_3"))
75+
76+
77+
def test_roundtrip_accepts_string_path(sample_scores, tmp_path):
    """save() and load() work when the path is given as a plain string."""
    path_as_str = str(tmp_path / "test_scores")
    sample_scores.save(path_as_str)
    loaded = TopKTanimotoScores.load(path_as_str)
    pd.testing.assert_frame_equal(loaded.top_k_inchikeys_and_scores, sample_scores.top_k_inchikeys_and_scores)

0 commit comments

Comments
 (0)