22from typing import Iterable , List , Mapping , Optional , Sequence
33from matchms import Spectrum
44from ms2deepscore .models import SiameseSpectralModel
5+ from tqdm import tqdm
56from ms2query .benchmarking .Embeddings import Embeddings
67
78
@@ -13,24 +14,22 @@ def __init__(
1314 spectra : Sequence [Spectrum ],
1415 spectrum_indices_per_inchikey : Mapping [str , Iterable [int ]],
1516 embeddings : Optional [Embeddings ] = None ,
16- progress_bars = False ,
1717 ):
1818 self ._spectra = tuple ([spectrum .clone () for spectrum in spectra ])
1919 self .spectrum_indices_per_inchikey : dict [str , tuple [int , ...]] = {
2020 key : tuple (values ) for key , values in spectrum_indices_per_inchikey .items ()
2121 }
22- self .progress_bars = progress_bars
2322 self ._embeddings = embeddings
2423
2524 @classmethod
26- def create_spectrum_set (cls , spectra : Sequence [Spectrum ], progress_bars = False ) -> "AnnotatedSpectrumSet" :
25+ def create_spectrum_set (cls , spectra : Sequence [Spectrum ]) -> "AnnotatedSpectrumSet" :
2726 spectrum_indices_per_inchikey = defaultdict (list )
28- for spectrum_index , spectrum in enumerate (spectra ):
27+ for spectrum_index , spectrum in enumerate (tqdm ( spectra , desc = "Create mapping from inchikey to spectrum" ) ):
2928 inchikey = spectrum .get ("inchikey" )
3029 if inchikey is None :
3130 raise ValueError ("Annotated Spectrum set expects spectra that all have an inchikey" )
3231 spectrum_indices_per_inchikey [inchikey [:14 ]].append (spectrum_index )
33- return cls (spectra , spectrum_indices_per_inchikey , progress_bars = progress_bars )
32+ return cls (spectra , spectrum_indices_per_inchikey )
3433
3534 def __add__ (self , other ) -> "AnnotatedSpectrumSet" :
3635 """Adds two spectrum sets together"""
@@ -52,14 +51,12 @@ def __add__(self, other) -> "AnnotatedSpectrumSet":
5251 embeddings = None
5352 if self ._embeddings and other ._embeddings :
5453 embeddings = Embeddings .combine_embeddings (self .embeddings , other .embeddings )
55- return AnnotatedSpectrumSet (
56- spectra , spectrum_indices_per_inchikey , embeddings = embeddings , progress_bars = self .progress_bars
57- )
54+ return AnnotatedSpectrumSet (spectra , spectrum_indices_per_inchikey , embeddings = embeddings )
5855
5956 def subset_spectra (self , spectrum_indices ) -> "AnnotatedSpectrumSet" :
6057 """Returns a new instance of a subset of the spectra"""
6158 spectra = [self ._spectra [index ] for index in spectrum_indices ]
62- new_instance = AnnotatedSpectrumSet .create_spectrum_set (spectra , progress_bars = self . progress_bars )
59+ new_instance = AnnotatedSpectrumSet .create_spectrum_set (spectra )
6360 if self ._embeddings is not None :
6461 new_instance ._embeddings = self .embeddings .subset_embeddings (spectra )
6562 return new_instance
@@ -88,9 +85,7 @@ def inchikeys(self):
8885 return tuple (self .spectrum_indices_per_inchikey .keys ())
8986
9087 def __copy__ (self ):
91- return AnnotatedSpectrumSet (
92- self .spectra , self .spectrum_indices_per_inchikey , self .embeddings , progress_bars = self .progress_bars
93- )
88+ return AnnotatedSpectrumSet (self .spectra , self .spectrum_indices_per_inchikey , self .embeddings )
9489
9590 def __eq__ (self , other : object ):
9691 if not isinstance (other , AnnotatedSpectrumSet ):
0 commit comments