 import os
+import warnings
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Callable, List, Optional, Union
+from functools import partial
+from typing import Any, Dict, List, Optional, Union

 import matplotlib
 import matplotlib.pyplot as plt

 import scib_metrics

+Kwargs = Dict[str, Any]
+MetricType = Union[bool, Kwargs]
+
 _LABELS = "labels"
 _BATCH = "batch"
 _X_PRE = "X_pre"
 _METRIC_TYPE = "Metric Type"
 _AGGREGATE_SCORE = "Aggregate score"

+# Mapping of metric fn names to clean DataFrame column names
 metric_name_cleaner = {
     "silhouette_label": "Silhouette label",
     "silhouette_batch": "Silhouette batch",
 }


-@dataclass
+@dataclass(frozen=True)
 class BioConservation:
     """Specification of bio conservation metrics to run in the pipeline.

     Metrics can be included using a boolean flag. Custom keyword args can be
-    used by passing a partial callable of that metric here.
+    used by passing a dictionary here. Keyword args should not set data-related
+    parameters, such as `X` or `labels`.
     """

-    isolated_labels: Union[bool, Callable] = True
-    nmi_ari_cluster_labels_leiden: Union[bool, Callable] = True
-    nmi_ari_cluster_labels_kmeans: Union[bool, Callable] = False
-    silhouette_label: Union[bool, Callable] = True
-    clisi_knn: Union[bool, Callable] = True
+    isolated_labels: MetricType = True
+    nmi_ari_cluster_labels_leiden: MetricType = True
+    nmi_ari_cluster_labels_kmeans: MetricType = False
+    silhouette_label: MetricType = True
+    clisi_knn: MetricType = True


-@dataclass
+@dataclass(frozen=True)
 class BatchCorrection:
     """Specification of which batch correction metrics to run in the pipeline.

     Metrics can be included using a boolean flag. Custom keyword args can be
-    used by passing a partial callable of that metric here.
+    used by passing a dictionary here. Keyword args should not set data-related
+    parameters, such as `X` or `labels`.
     """

-    silhouette_batch: Union[bool, Callable] = True
-    ilisi_knn: Union[bool, Callable] = True
-    kbet_per_label: Union[bool, Callable] = True
-    graph_connectivity: Union[bool, Callable] = True
-    pcr_comparison: Union[bool, Callable] = True
+    silhouette_batch: MetricType = True
+    ilisi_knn: MetricType = True
+    kbet_per_label: MetricType = True
+    graph_connectivity: MetricType = True
+    pcr_comparison: MetricType = True
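
The switch from `Union[bool, Callable]` to `MetricType` changes how callers customize metrics: instead of pre-binding a partial callable, they pass `True`/`False` or a plain kwargs dict. A minimal usage sketch, assuming these classes are exported from `scib_metrics.benchmark`; the `resolution` kwarg name is illustrative, so check the actual signature of `scib_metrics.nmi_ari_cluster_labels_leiden` before relying on it:

    from scib_metrics.benchmark import BatchCorrection, BioConservation

    # True -> run with library defaults; False -> skip; dict -> run with these kwargs
    bio = BioConservation(
        nmi_ari_cluster_labels_leiden={"resolution": 1.0},  # kwarg name is illustrative
        nmi_ari_cluster_labels_kmeans=False,
    )
    batch = BatchCorrection(pcr_comparison=False)

    # frozen=True also makes the specs immutable (and hashable):
    # bio.silhouette_label = False  # would raise dataclasses.FrozenInstanceError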


 class MetricAnnDataAPI(Enum):
@@ -138,6 +146,7 @@ def __init__(
         self._emb_adatas = {}
         self._neighbor_values = (15, 50, 90)
         self._prepared = False
+        self._benchmarked = False
         self._batch_key = batch_key
         self._label_key = label_key
         self._n_jobs = n_jobs
@@ -183,6 +192,12 @@ def prepare(self) -> None:

     def benchmark(self) -> None:
         """Run the pipeline."""
+        if self._benchmarked:
+            warnings.warn(
+                "The benchmark has already been run. Running it again will overwrite the previous results.",
+                UserWarning,
+            )
+
         if not self._prepared:
             self.prepare()
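
Because a repeat run only warns rather than raises, a caller who wants a hard failure can escalate the `UserWarning` through the standard warnings filters. A sketch, where `bm` stands in for an already-run benchmarker instance:

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        bm.benchmark()  # a second call now raises UserWarning instead of printing it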
@@ -193,13 +208,12 @@ def benchmark(self) -> None:
         for emb_key, ad in tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green"):
             pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue")
             for metric_type, metric_collection in self._metric_collection_dict.items():
-                for metric_name, use_metric in asdict(metric_collection).items():
-                    if use_metric:
-                        if isinstance(metric_name, str):
-                            metric_fn = getattr(scib_metrics, metric_name)
-                        else:
-                            # Callable in this case
-                            metric_fn = use_metric
+                for metric_name, use_metric_or_kwargs in asdict(metric_collection).items():
+                    if use_metric_or_kwargs:
+                        metric_fn = getattr(scib_metrics, metric_name)
+                        if isinstance(use_metric_or_kwargs, dict):
+                            # Kwargs in this case
+                            metric_fn = partial(metric_fn, **use_metric_or_kwargs)
                         metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
                         # nmi/ari metrics return a dict
                         if isinstance(metric_value, dict):
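
Stripped of the progress-bar plumbing, the new dispatch reduces to the standalone sketch below (a paraphrase for clarity, not library code). One subtlety: `if use_metric_or_kwargs:` treats an empty dict as falsy, so passing `{}` skips the metric entirely; pass `True` to run it with defaults.

    from functools import partial
    from typing import Any, Callable, Dict, Optional, Union

    def resolve_metric(metric_fn: Callable, spec: Union[bool, Dict[str, Any]]) -> Optional[Callable]:
        """Mirror benchmark()'s dispatch; None means the metric is skipped."""
        if not spec:  # False or {} -> skipped
            return None
        if isinstance(spec, dict):
            return partial(metric_fn, **spec)  # bake custom kwargs into the metric
        return metric_fn  # True -> library defaults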
@@ -211,6 +225,8 @@ def benchmark(self) -> None:
                             self._results.loc[metric_name, _METRIC_TYPE] = metric_type
                         pbar.update(1)

+        self._benchmarked = True
+
     def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame:
         """Return the benchmarking results.
@@ -242,6 +258,7 @@ def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame:

         # Compute scores
         per_class_score = df.groupby(_METRIC_TYPE).mean().transpose()
+        # This is the default scIB weighting from the manuscript
         per_class_score["Total"] = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
         df = pd.concat([df.transpose(), per_class_score], axis=1)
         df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
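
As a quick sanity check of the fixed weighting: a batch-correction score of 0.8 and a bio-conservation score of 0.9 (illustrative numbers) yield a total of 0.4 * 0.8 + 0.6 * 0.9 = 0.86.

    # Worked example of the scIB total with illustrative scores:
    batch_correction, bio_conservation = 0.8, 0.9
    total = 0.4 * batch_correction + 0.6 * bio_conservation  # 0.86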