Skip to content

Commit 39c2a7e

Browse files
authored
custom near neigh algorithms in benchmarker (#78)
* custom near neigh * fix * fix * make jax array * update tutorial * Add changelog * bump version
1 parent 913e102 commit 39c2a7e

15 files changed

+340
-77
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.1.1
2+
current_version = 0.2.0
33
tag = True
44
commit = True
55

.pre-commit-config.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,31 @@ repos:
7777
Fix the merge conflicts manually and remove the .rej files.
7878
language: fail
7979
files: '.*\.rej$'
80+
- repo: https://github.com/nbQA-dev/nbQA
81+
rev: 1.6.1
82+
hooks:
83+
- id: nbqa-pyupgrade
84+
args: [--py38-plus]
85+
- id: nbqa-black
86+
- id: nbqa-isort
87+
- id: nbqa-ruff
88+
args: [--fix]
89+
- id: nbqa
90+
entry: nbqa blacken-docs
91+
name: nbqa-blacken-docs
92+
alias: nbqa-blacken-docs
93+
additional_dependencies: [blacken-docs]
94+
args: [--nbqa-md]
95+
- id: nbqa
96+
entry: nbqa mdformat
97+
name: nbqa-mdformat
98+
alias: nbqa-mdformat
99+
additional_dependencies:
100+
[
101+
mdformat,
102+
mdformat-black,
103+
mdformat-frontmatter,
104+
mdformat-web,
105+
mdformat-myst,
106+
]
107+
args: [--nbqa-md]

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11+
## 0.2.0 (2022-02-02)
12+
13+
- Allow custom nearest neighbors methods in Benchmarker ([#78][])
14+
1115
## 0.1.1 (2022-01-04)
1216

1317
- Add new tutorial and fix scalability of lisi ([#71][])

docs/api.md

+14
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ scib_metrics.ilisi_knn(...)
7171
utils.diffusion_nn
7272
```
7373

74+
### Nearest neighbors
75+
76+
```{eval-rst}
77+
.. module:: scib_metrics.nearest_neighbors
78+
.. currentmodule:: scib_metrics
79+
80+
.. autosummary::
81+
:toctree: generated
82+
83+
nearest_neighbors.pynndescent
84+
nearest_neighbors.jax_approx_min_k
85+
nearest_neighbors.NeighborsOutput
86+
```
87+
7488
## Settings
7589

7690
An instance of the {class}`~scib_metrics._settings.ScibConfig` is available as `scib_metrics.settings` and allows configuring scib_metrics.

docs/notebooks/large_scale.ipynb

+118-45
Large diffs are not rendered by default.

docs/notebooks/lung_example.ipynb

+10-7
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@
2323
"outputs": [],
2424
"source": [
2525
"import numpy as np\n",
26-
"from anndata import AnnData\n",
27-
"import matplotlib.pyplot as plt\n",
2826
"import scanpy as sc\n",
29-
"from plottable import Table\n",
3027
"\n",
3128
"from scib_metrics.benchmark import Benchmarker\n",
29+
"\n",
3230
"%matplotlib inline"
3331
]
3432
},
@@ -120,7 +118,7 @@
120118
],
121119
"source": [
122120
"sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor=\"cell_ranger\", batch_key=\"batch\")\n",
123-
"sc.tl.pca(adata, n_comps=30, use_highly_variable=True)\n"
121+
"sc.tl.pca(adata, n_comps=30, use_highly_variable=True)"
124122
]
125123
},
126124
{
@@ -177,6 +175,7 @@
177175
"source": [
178176
"%%capture\n",
179177
"import scanorama\n",
178+
"\n",
180179
"# List of adata per batch\n",
181180
"batch_cats = adata.obs.batch.cat.categories\n",
182181
"adata_list = [adata[adata.obs.batch == b].copy() for b in batch_cats]\n",
@@ -211,6 +210,7 @@
211210
],
212211
"source": [
213212
"import pyliger\n",
213+
"\n",
214214
"bdata = adata.copy()\n",
215215
"# Pyliger normalizes by library size with a size factor of 1\n",
216216
"# So here we give it the count data\n",
@@ -220,7 +220,7 @@
220220
"for i, ad in enumerate(adata_list):\n",
221221
" ad.uns[\"sample_name\"] = batch_cats[i]\n",
222222
" # Hack to make sure each method uses the same genes\n",
223-
" ad.uns['var_gene_idx'] = np.arange(bdata.n_vars)\n",
223+
" ad.uns[\"var_gene_idx\"] = np.arange(bdata.n_vars)\n",
224224
"\n",
225225
"\n",
226226
"liger_data = pyliger.create_liger(adata_list, remove_missing=False, make_sparse=False)\n",
@@ -270,7 +270,8 @@
270270
],
271271
"source": [
272272
"from harmony import harmonize\n",
273-
"adata.obsm[\"Harmony\"] = harmonize(adata.obsm[\"X_pca\"], adata.obs, batch_key = \"batch\")"
273+
"\n",
274+
"adata.obsm[\"Harmony\"] = harmonize(adata.obsm[\"X_pca\"], adata.obs, batch_key=\"batch\")"
274275
]
275276
},
276277
{
@@ -303,6 +304,7 @@
303304
"source": [
304305
"%%capture\n",
305306
"import scvi\n",
307+
"\n",
306308
"scvi.model.SCVI.setup_anndata(adata, layer=\"counts\", batch_key=\"batch\")\n",
307309
"vae = scvi.model.SCVI(adata, gene_likelihood=\"nb\", n_layers=2, n_latent=30)\n",
308310
"vae.train()\n",
@@ -389,7 +391,7 @@
389391
" embedding_obsm_keys=[\"Unintegrated\", \"Scanorama\", \"LIGER\", \"Harmony\", \"scVI\", \"scANVI\"],\n",
390392
" n_jobs=6,\n",
391393
")\n",
392-
"bm.benchmark()\n"
394+
"bm.benchmark()"
393395
]
394396
},
395397
{
@@ -585,6 +587,7 @@
585587
],
586588
"source": [
587589
"from rich import print\n",
590+
"\n",
588591
"df = bm.get_results(min_max_scale=False)\n",
589592
"print(df)"
590593
]

pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = ["hatchling"]
55

66
[project]
77
name = "scib-metrics"
8-
version = "0.1.1"
8+
version = "0.2.0"
99
description = "Accelerated and Python-only scIB metrics"
1010
readme = "README.md"
1111
requires-python = ">=3.8"
@@ -99,6 +99,11 @@ multi_line_output = 3
9999
profile = "black"
100100
skip_glob = ["docs/*"]
101101

102+
[tool.ruff]
103+
line-length = 88
104+
exclude = [".git","__pycache__","build","docs/","_build","dist"]
105+
ignore = ["E402","E501", "F821", "E741"]
106+
102107
[tool.black]
103108
line-length = 120
104109
target-version = ['py38']

src/scib_metrics/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from importlib.metadata import version
33

4-
from . import utils
4+
from . import nearest_neighbors, utils
55
from ._graph_connectivity import graph_connectivity
66
from ._isolated_labels import isolated_labels
77
from ._kbet import kbet, kbet_per_label
@@ -13,6 +13,7 @@
1313

1414
__all__ = [
1515
"utils",
16+
"nearest_neighbors",
1617
"isolated_labels",
1718
"pcr_comparison",
1819
"silhouette_label",

src/scib_metrics/benchmark/_core.py

+20-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import asdict, dataclass
44
from enum import Enum
55
from functools import partial
6-
from typing import Any, Dict, List, Optional, Union
6+
from typing import Any, Callable, Dict, List, Optional, Union
77

88
import matplotlib
99
import matplotlib.pyplot as plt
@@ -14,11 +14,11 @@
1414
from plottable import ColumnDefinition, Table
1515
from plottable.cmap import normed_cmap
1616
from plottable.plots import bar
17-
from pynndescent import NNDescent
1817
from sklearn.preprocessing import MinMaxScaler
1918
from tqdm import tqdm
2019

2120
import scib_metrics
21+
from scib_metrics.nearest_neighbors import NeighborsOutput, pynndescent
2222

2323
Kwargs = Dict[str, Any]
2424
MetricType = Union[bool, Kwargs]
@@ -156,8 +156,17 @@ def __init__(
156156
"Batch correction": self._batch_correction_metrics,
157157
}
158158

159-
def prepare(self) -> None:
160-
"""Prepare the data for benchmarking."""
159+
def prepare(self, neighbor_computer: Optional[Callable[[np.ndarray, int], NeighborsOutput]] = None) -> None:
160+
"""Prepare the data for benchmarking.
161+
162+
Parameters
163+
----------
164+
neighbor_computer
165+
Function that computes the neighbors of the data. If `None`, the neighbors will be computed
166+
with :func:`~scib_metrics.nearest_neighbors.pynndescent`. The function should take as input
167+
the data and the number of neighbors to compute and return a :class:`~scib_metrics.nearest_neighbors.NeighborsOutput`
168+
object.
169+
"""
161170
# Compute PCA
162171
if self._pre_integrated_embedding_obsm_key is None:
163172
# This is how scib does it
@@ -173,24 +182,13 @@ def prepare(self) -> None:
173182

174183
# Compute neighbors
175184
for ad in tqdm(self._emb_adatas.values(), desc="Computing neighbors"):
176-
# Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
177-
# which is used by scanpy under the hood
178-
n_trees = min(64, 5 + int(round((ad.X.shape[0]) ** 0.5 / 20.0)))
179-
n_iters = max(5, int(round(np.log2(ad.X.shape[0]))))
180-
max_candidates = 60
181-
182-
knn_search_index = NNDescent(
183-
ad.X,
184-
n_neighbors=max(self._neighbor_values),
185-
random_state=0,
186-
low_memory=True,
187-
n_jobs=self._n_jobs,
188-
compressed=False,
189-
n_trees=n_trees,
190-
n_iters=n_iters,
191-
max_candidates=max_candidates,
192-
)
193-
indices, distances = knn_search_index.neighbor_graph
185+
if neighbor_computer is not None:
186+
neigh_output = neighbor_computer(ad.X, max(self._neighbor_values))
187+
else:
188+
neigh_output = pynndescent(
189+
ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
190+
)
191+
indices, distances = neigh_output.indices, neigh_output.distances
194192
for n in self._neighbor_values:
195193
sp_distances, sp_conns = sc.neighbors._compute_connectivities_umap(
196194
indices[:, :n], distances[:, :n], ad.n_obs, n_neighbors=n
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from ._dataclass import NeighborsOutput
2+
from ._jax import jax_approx_min_k
3+
from ._pynndescent import pynndescent
4+
5+
__all__ = [
6+
"pynndescent",
7+
"jax_approx_min_k",
8+
"NeighborsOutput",
9+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass
2+
3+
import numpy as np
4+
5+
6+
@dataclass
class NeighborsOutput:
    """Output of the nearest neighbors function.

    Attributes are documented in declaration order, which is also the
    positional order accepted by the generated ``__init__``.

    Attributes
    ----------
    indices : np.ndarray
        Array of indices of the nearest neighbors.
    distances : np.ndarray
        Array of distances to the nearest neighbors.
    """

    indices: np.ndarray
    distances: np.ndarray
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import functools
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import numpy as np
6+
7+
from scib_metrics.utils import cdist, get_ndarray
8+
9+
from ._dataclass import NeighborsOutput
10+
11+
12+
@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def _euclidean_ann(qy: jnp.ndarray, db: jnp.ndarray, k: int, recall_target: float = 0.95):
    """Select the ``k`` smallest query-to-database distances for each query point.

    Parameters
    ----------
    qy
        Query points.
    db
        Database points that are searched for neighbors.
    k
        Number of smallest entries to keep per query point. Declared static for
        ``jax.jit``, so each distinct value triggers a new compilation.
    recall_target
        Recall target forwarded to :func:`jax.lax.approx_min_k`. Also static.

    Returns
    -------
    The output of :func:`jax.lax.approx_min_k` applied to the pairwise distance
    matrix: a pair of arrays holding the ``k`` smallest values and their indices.
    """
    # NOTE(review): `cdist` is imported from scib_metrics.utils; presumably it returns the
    # pairwise Euclidean distances with one row per query point -- confirm against that helper.
    dists = cdist(qy, db)
    return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
17+
18+
19+
def jax_approx_min_k(
    X: np.ndarray, n_neighbors: int, recall_target: float = 0.95, chunk_size: int = 2048
) -> NeighborsOutput:
    """Run approximate nearest neighbor search using jax.

    On TPU backends, this is approximate nearest neighbor search. On other backends, this is exact nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    recall_target
        Target recall for approximate nearest neighbor search.
    chunk_size
        Number of query points to search for at once. Must be a positive integer.

    Returns
    -------
    A :class:`NeighborsOutput` whose ``indices`` and ``distances`` hold, for every row of
    ``X``, the ``n_neighbors`` entries selected among all rows of ``X``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        # Without this guard a zero/negative chunk size surfaces as an opaque error
        # from ``range`` or from concatenating an empty list.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    db = jnp.asarray(X)
    n_obs = db.shape[0]

    # Query the database against itself in chunks so that only a
    # (chunk_size x n_obs) distance matrix is materialized at a time.
    neighbors = []
    dists = []
    for start in range(0, n_obs, chunk_size):
        end = min(start + chunk_size, n_obs)
        dist, neighbor = _euclidean_ann(db[start:end], db, k=n_neighbors, recall_target=recall_target)
        neighbors.append(neighbor)
        dists.append(dist)

    neighbors = jnp.concatenate(neighbors, axis=0)
    dists = jnp.concatenate(dists, axis=0)
    return NeighborsOutput(indices=get_ndarray(neighbors), distances=get_ndarray(dists))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
from pynndescent import NNDescent
3+
4+
from ._dataclass import NeighborsOutput
5+
6+
7+
def pynndescent(X: np.ndarray, n_neighbors: int, random_state: int = 0, n_jobs: int = 1) -> NeighborsOutput:
    """Run pynndescent approximate nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    random_state
        Random state.
    n_jobs
        Number of jobs to use.

    Returns
    -------
    A :class:`NeighborsOutput` built from the ``neighbor_graph`` of the fitted
    :class:`~pynndescent.NNDescent` index.
    """
    n_obs = X.shape[0]

    # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
    # which is used by scanpy under the hood
    n_trees = min(64, 5 + int(round(n_obs**0.5 / 20.0)))
    n_iters = max(5, int(round(np.log2(n_obs))))
    max_candidates = 60

    knn_search_index = NNDescent(
        X,
        n_neighbors=n_neighbors,
        random_state=random_state,
        low_memory=True,
        n_jobs=n_jobs,
        compressed=False,
        n_trees=n_trees,
        n_iters=n_iters,
        max_candidates=max_candidates,
    )
    indices, distances = knn_search_index.neighbor_graph

    return NeighborsOutput(indices=indices, distances=distances)

0 commit comments

Comments
 (0)