|
| 1 | +from pathlib import Path |
1 | 2 | import numpy as np |
2 | 3 | import pandas as pd |
3 | 4 | from ms2query.benchmarking.Fingerprints import Fingerprints |
@@ -27,13 +28,21 @@ def _create_multi_index( |
27 | 28 | combined_data = np.empty((len(inchikey_indexes), self.k * 2), dtype=object) |
28 | 29 | combined_data[:, 0::2] = top_k_inchikeys |
29 | 30 | combined_data[:, 1::2] = tanimoto_scores_for_top_k |
30 | | - return pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns) |
| 31 | + df = pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns) |
| 32 | + |
| 33 | + # Cast score columns to float64 |
| 34 | + score_cols = [(rank, "score") for rank in [f"Rank_{i + 1}" for i in range(self.k)]] |
| 35 | + df[score_cols] = df[score_cols].astype(float) |
| 36 | + |
| 37 | + return df |
31 | 38 |
|
32 | 39 | @classmethod |
33 | 40 | def calculate_from_fingerprints(cls, query_fingerprints: Fingerprints, target_fingerprints: Fingerprints, k): |
34 | 41 | """ |
35 | 42 | Gets the top k highest inchikeys and scores for each inchikey in query_fingerprints from target_fingerprints |
36 | 43 | """ |
| 44 | + if target_fingerprints.fingerprints.shape[0] < k: |
| 45 | + raise ValueError("K cannot be larger than the number of fingerprints") |
37 | 46 | similarity_scores = generalized_tanimoto_similarity_matrix( |
38 | 47 | query_fingerprints.fingerprints, target_fingerprints.fingerprints |
39 | 48 | ) |
@@ -67,3 +76,30 @@ def get_all_average_tanimoto_scores(self) -> dict[str, float]: |
67 | 76 |
|
68 | 77 | average_per_inchikey_df = scores_df.mean(axis=1) |
69 | 78 | return average_per_inchikey_df.to_dict() |
| 79 | + |
| 80 | + def save(self, path: str | Path) -> None: |
| 81 | + """Save the TopKTanimotoScores to disk as a parquet file. |
| 82 | +
|
| 83 | + Args: |
| 84 | + path: File path without extension, e.g. "/data/top_k_scores". |
| 85 | + """ |
| 86 | + Path(path).with_suffix(".parquet").parent.mkdir(parents=True, exist_ok=True) |
| 87 | + self.top_k_inchikeys_and_scores.to_parquet(Path(path).with_suffix(".parquet")) |
| 88 | + |
| 89 | + @classmethod |
| 90 | + def load(cls, path: str | Path) -> "TopKTanimotoScores": |
| 91 | + """Load a previously saved TopKTanimotoScores from disk. |
| 92 | +
|
| 93 | + Args: |
| 94 | + path: File path without extension, e.g. "/data/top_k_scores". |
| 95 | +
|
| 96 | + Returns: |
| 97 | + A fully reconstructed TopKTanimotoScores instance. |
| 98 | + """ |
| 99 | + df = pd.read_parquet(Path(path).with_suffix(".parquet")) |
| 100 | + df.columns.names = ["result_rank", "attribute"] |
| 101 | + |
| 102 | + instance = cls.__new__(cls) |
| 103 | + instance.k = len(df.columns.get_level_values("result_rank").unique()) |
| 104 | + instance.top_k_inchikeys_and_scores = df |
| 105 | + return instance |
0 commit comments