Skip to content

Commit f7b1b8f

Browse files
committed
add progress bar when creating annotatedspectrumset and remove progress_bars parameter
1 parent d6418f4 commit f7b1b8f

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

ms2query/benchmarking/AnnotatedSpectrumSet.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Iterable, List, Mapping, Optional, Sequence
33
from matchms import Spectrum
44
from ms2deepscore.models import SiameseSpectralModel
5+
from tqdm import tqdm
56
from ms2query.benchmarking.Embeddings import Embeddings
67

78

@@ -13,24 +14,22 @@ def __init__(
1314
spectra: Sequence[Spectrum],
1415
spectrum_indices_per_inchikey: Mapping[str, Iterable[int]],
1516
embeddings: Optional[Embeddings] = None,
16-
progress_bars=False,
1717
):
1818
self._spectra = tuple([spectrum.clone() for spectrum in spectra])
1919
self.spectrum_indices_per_inchikey: dict[str, tuple[int, ...]] = {
2020
key: tuple(values) for key, values in spectrum_indices_per_inchikey.items()
2121
}
22-
self.progress_bars = progress_bars
2322
self._embeddings = embeddings
2423

2524
@classmethod
26-
def create_spectrum_set(cls, spectra: Sequence[Spectrum], progress_bars=False) -> "AnnotatedSpectrumSet":
25+
def create_spectrum_set(cls, spectra: Sequence[Spectrum]) -> "AnnotatedSpectrumSet":
2726
spectrum_indices_per_inchikey = defaultdict(list)
28-
for spectrum_index, spectrum in enumerate(spectra):
27+
for spectrum_index, spectrum in enumerate(tqdm(spectra, desc="Create mapping from inchikey to spectrum")):
2928
inchikey = spectrum.get("inchikey")
3029
if inchikey is None:
3130
raise ValueError("Annotated Spectrum set expects spectra that all have an inchikey")
3231
spectrum_indices_per_inchikey[inchikey[:14]].append(spectrum_index)
33-
return cls(spectra, spectrum_indices_per_inchikey, progress_bars=progress_bars)
32+
return cls(spectra, spectrum_indices_per_inchikey)
3433

3534
def __add__(self, other) -> "AnnotatedSpectrumSet":
3635
"""Adds two spectrum sets together"""
@@ -52,14 +51,12 @@ def __add__(self, other) -> "AnnotatedSpectrumSet":
5251
embeddings = None
5352
if self._embeddings and other._embeddings:
5453
embeddings = Embeddings.combine_embeddings(self.embeddings, other.embeddings)
55-
return AnnotatedSpectrumSet(
56-
spectra, spectrum_indices_per_inchikey, embeddings=embeddings, progress_bars=self.progress_bars
57-
)
54+
return AnnotatedSpectrumSet(spectra, spectrum_indices_per_inchikey, embeddings=embeddings)
5855

5956
def subset_spectra(self, spectrum_indices) -> "AnnotatedSpectrumSet":
6057
"""Returns a new instance of a subset of the spectra"""
6158
spectra = [self._spectra[index] for index in spectrum_indices]
62-
new_instance = AnnotatedSpectrumSet.create_spectrum_set(spectra, progress_bars=self.progress_bars)
59+
new_instance = AnnotatedSpectrumSet.create_spectrum_set(spectra)
6360
if self._embeddings is not None:
6461
new_instance._embeddings = self.embeddings.subset_embeddings(spectra)
6562
return new_instance
@@ -88,9 +85,7 @@ def inchikeys(self):
8885
return tuple(self.spectrum_indices_per_inchikey.keys())
8986

9087
def __copy__(self):
91-
return AnnotatedSpectrumSet(
92-
self.spectra, self.spectrum_indices_per_inchikey, self.embeddings, progress_bars=self.progress_bars
93-
)
88+
return AnnotatedSpectrumSet(self.spectra, self.spectrum_indices_per_inchikey, self.embeddings)
9489

9590
def __eq__(self, other: object):
9691
if not isinstance(other, AnnotatedSpectrumSet):

0 commit comments

Comments
 (0)