Skip to content

Commit 2a1f392

Browse files
authored
bump version and benchmarking changes (#69)
* bump version and benchmarking changes * docs fix * changelog * fix diffusion distance (#70)
1 parent d8960c4 commit 2a1f392

File tree

6 files changed

+58
-30
lines changed

6 files changed

+58
-30
lines changed

.readthedocs.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ version: 2
33
build:
44
os: ubuntu-20.04
55
tools:
6-
python: "3.8"
6+
python: "3.10"
77
sphinx:
88
configuration: docs/conf.py
99
fail_on_warning: false

CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11+
## 0.1.0 (2023-01-03)
12+
13+
- Add benchmarking pipeline with plotting ([#52][] and [#69][])
14+
- Fix diffusion distance computation, affecting kbet ([#70][])
15+
16+
[#52]: https://github.com/YosefLab/scib-metrics/pull/52
17+
[#69]: https://github.com/YosefLab/scib-metrics/pull/69
18+
[#70]: https://github.com/YosefLab/scib-metrics/pull/70
19+
1120
## 0.0.9 (2022-12-16)
1221

1322
- Add kbet ([#60][])

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = ["hatchling"]
55

66
[project]
77
name = "scib-metrics"
8-
version = "0.0.9"
8+
version = "0.1.0"
99
description = "Accelerated and Python-only scIB metrics"
1010
readme = "README.md"
1111
requires-python = ">=3.8"
@@ -44,7 +44,7 @@ dev = [
4444
"pre-commit"
4545
]
4646
doc = [
47-
"sphinx>=4",
47+
"sphinx>=4,<5.3",
4848
"sphinx-book-theme",
4949
"myst-nb",
5050
"sphinxcontrib-bibtex>=1.0.0",

src/scib_metrics/benchmark/_core.py

+39-22
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
import warnings
23
from dataclasses import asdict, dataclass
34
from enum import Enum
4-
from typing import Callable, List, Optional, Union
5+
from functools import partial
6+
from typing import Any, Dict, List, Optional, Union
57

68
import matplotlib
79
import matplotlib.pyplot as plt
@@ -18,12 +20,16 @@
1820

1921
import scib_metrics
2022

23+
Kwargs = Dict[str, Any]
24+
MetricType = Union[bool, Kwargs]
25+
2126
_LABELS = "labels"
2227
_BATCH = "batch"
2328
_X_PRE = "X_pre"
2429
_METRIC_TYPE = "Metric Type"
2530
_AGGREGATE_SCORE = "Aggregate score"
2631

32+
# Mapping of metric fn names to clean DataFrame column names
2733
metric_name_cleaner = {
2834
"silhouette_label": "Silhouette label",
2935
"silhouette_batch": "Silhouette batch",
@@ -40,34 +46,36 @@
4046
}
4147

4248

43-
@dataclass
49+
@dataclass(frozen=True)
4450
class BioConservation:
4551
"""Specification of bio conservation metrics to run in the pipeline.
4652
4753
Metrics can be included using a boolean flag. Custom keyword args can be
48-
used by passing a partial callable of that metric here.
54+
used by passing a dictionary here. Keyword args should not set data-related
55+
parameters, such as `X` or `labels`.
4956
"""
5057

51-
isolated_labels: Union[bool, Callable] = True
52-
nmi_ari_cluster_labels_leiden: Union[bool, Callable] = True
53-
nmi_ari_cluster_labels_kmeans: Union[bool, Callable] = False
54-
silhouette_label: Union[bool, Callable] = True
55-
clisi_knn: Union[bool, Callable] = True
58+
isolated_labels: MetricType = True
59+
nmi_ari_cluster_labels_leiden: MetricType = True
60+
nmi_ari_cluster_labels_kmeans: MetricType = False
61+
silhouette_label: MetricType = True
62+
clisi_knn: MetricType = True
5663

5764

58-
@dataclass
65+
@dataclass(frozen=True)
5966
class BatchCorrection:
6067
"""Specification of which batch correction metrics to run in the pipeline.
6168
6269
Metrics can be included using a boolean flag. Custom keyword args can be
63-
used by passing a partial callable of that metric here.
70+
used by passing a dictionary here. Keyword args should not set data-related
71+
parameters, such as `X` or `labels`.
6472
"""
6573

66-
silhouette_batch: Union[bool, Callable] = True
67-
ilisi_knn: Union[bool, Callable] = True
68-
kbet_per_label: Union[bool, Callable] = True
69-
graph_connectivity: Union[bool, Callable] = True
70-
pcr_comparison: Union[bool, Callable] = True
74+
silhouette_batch: MetricType = True
75+
ilisi_knn: MetricType = True
76+
kbet_per_label: MetricType = True
77+
graph_connectivity: MetricType = True
78+
pcr_comparison: MetricType = True
7179

7280

7381
class MetricAnnDataAPI(Enum):
@@ -138,6 +146,7 @@ def __init__(
138146
self._emb_adatas = {}
139147
self._neighbor_values = (15, 50, 90)
140148
self._prepared = False
149+
self._benchmarked = False
141150
self._batch_key = batch_key
142151
self._label_key = label_key
143152
self._n_jobs = n_jobs
@@ -183,6 +192,12 @@ def prepare(self) -> None:
183192

184193
def benchmark(self) -> None:
185194
"""Run the pipeline."""
195+
if self._benchmarked:
196+
warnings.warn(
197+
"The benchmark has already been run. Running it again will overwrite the previous results.",
198+
UserWarning,
199+
)
200+
186201
if not self._prepared:
187202
self.prepare()
188203

@@ -193,13 +208,12 @@ def benchmark(self) -> None:
193208
for emb_key, ad in tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green"):
194209
pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue")
195210
for metric_type, metric_collection in self._metric_collection_dict.items():
196-
for metric_name, use_metric in asdict(metric_collection).items():
197-
if use_metric:
198-
if isinstance(metric_name, str):
199-
metric_fn = getattr(scib_metrics, metric_name)
200-
else:
201-
# Callable in this case
202-
metric_fn = use_metric
211+
for metric_name, use_metric_or_kwargs in asdict(metric_collection).items():
212+
if use_metric_or_kwargs:
213+
metric_fn = getattr(scib_metrics, metric_name)
214+
if isinstance(use_metric_or_kwargs, dict):
215+
# Kwargs in this case
216+
metric_fn = partial(metric_fn, **use_metric_or_kwargs)
203217
metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
204218
# nmi/ari metrics return a dict
205219
if isinstance(metric_value, dict):
@@ -211,6 +225,8 @@ def benchmark(self) -> None:
211225
self._results.loc[metric_name, _METRIC_TYPE] = metric_type
212226
pbar.update(1)
213227

228+
self._benchmarked = True
229+
214230
def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame:
215231
"""Return the benchmarking results.
216232
@@ -242,6 +258,7 @@ def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> p
242258

243259
# Compute scores
244260
per_class_score = df.groupby(_METRIC_TYPE).mean().transpose()
261+
# This is the default scIB weighting from the manuscript
245262
per_class_score["Total"] = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
246263
df = pd.concat([df.transpose(), per_class_score], axis=1)
247264
df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE

src/scib_metrics/utils/_diffusion_nn.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,12 @@ def diffusion_nn(X: csr_matrix, k: int, n_comps: int = 100):
113113
Neighbors graph
114114
"""
115115
transitions = _compute_transitions(X)
116-
_, embedding = _compute_eigen(transitions, n_comps=n_comps)
116+
evals, evecs = _compute_eigen(transitions, n_comps=n_comps)
117+
evals += 1e-8 # Avoid division by zero
118+
# Multiscale weighting: summing evals**t over all steps t >= 1 gives evals / (1 - evals), so the number of steps t gets "integrated out"
119+
# The first eigenvalue is 1, for which the 1 - evals denominator degenerates, so the rescaling starts at the second one
120+
embedding = evecs
121+
embedding[:, 1:] = (evals[1:] / (1 - evals[1:])) * embedding[:, 1:]
117122
nn_obj = pynndescent.NNDescent(embedding, n_neighbors=k + 1)
118123
neigh_inds, neigh_distances = nn_obj.neighbor_graph
119124
# We purposely ignore the first neighbor as it is the cell itself

tests/test_benchmarker.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from functools import partial
2-
31
import pandas as pd
42

5-
from scib_metrics import clisi_knn
63
from scib_metrics.benchmark import BatchCorrection, Benchmarker, BioConservation
74
from tests.utils.data import dummy_benchmarker_adata
85

@@ -36,7 +33,7 @@ def test_benchmarker_custom_metric_booleans():
3633

3734

3835
def test_benchmarker_custom_metric_callable():
39-
bioc = BioConservation(clisi_knn=partial(clisi_knn, perplexity=10))
36+
bioc = BioConservation(clisi_knn={"perplexity": 10})
4037
ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
4138
bm = Benchmarker(ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc)
4239
bm.benchmark()

0 commit comments

Comments
 (0)