Skip to content

Commit 20bfaa7

Browse files
authored
new tutorial and scalability fixes (#71)
* new tutorial and scalability fixes * notebook
1 parent d7513d5 commit 20bfaa7

File tree

7 files changed

+441
-25
lines changed

7 files changed

+441
-25
lines changed

docs/notebooks/large_scale.ipynb

+403
Large diffs are not rendered by default.

docs/tutorials.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
:maxdepth: 1
77
88
notebooks/lung_example
9+
notebooks/large_scale
910
```

src/scib_metrics/_kbet.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def kbet_per_label(
178178
batches=batches_sub,
179179
alpha=alpha,
180180
)
181-
except RuntimeError:
182-
logger.info("Not enough neighbors")
181+
except ValueError:
182+
logger.info("Diffusion distance failed. Skip.")
183183
score = 0 # i.e. 100% rejection
184184

185185
else:
@@ -204,8 +204,8 @@ def kbet_per_label(
204204
batches=batches_sub[idx_nonan],
205205
alpha=alpha,
206206
)
207-
except RuntimeError:
208-
logger.info("Not enough neighbors")
207+
except ValueError:
208+
logger.info("Diffusion distance failed. Skip.")
209209
score = 0 # i.e. 100% rejection
210210
else: # if there are too many too small connected components, set kBET score to 0
211211
score = 0 # i.e. 100% rejection

src/scib_metrics/benchmark/_core.py

+10
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,22 @@ def prepare(self) -> None:
173173

174174
# Compute neighbors
175175
for ad in tqdm(self._emb_adatas.values(), desc="Computing neighbors"):
176+
# Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
177+
# which is used by scanpy under the hood
178+
n_trees = min(64, 5 + int(round((ad.X.shape[0]) ** 0.5 / 20.0)))
179+
n_iters = max(5, int(round(np.log2(ad.X.shape[0]))))
180+
max_candidates = 60
181+
176182
knn_search_index = NNDescent(
177183
ad.X,
178184
n_neighbors=max(self._neighbor_values),
179185
random_state=0,
180186
low_memory=True,
181187
n_jobs=self._n_jobs,
188+
compressed=False,
189+
n_trees=n_trees,
190+
n_iters=n_iters,
191+
max_candidates=max_candidates,
182192
)
183193
indices, distances = knn_search_index.neighbor_graph
184194
for n in self._neighbor_values:

src/scib_metrics/utils/_diffusion_nn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def diffusion_nn(X: csr_matrix, k: int, n_comps: int = 100):
116116
evals, evecs = _compute_eigen(transitions, n_comps=n_comps)
117117
evals += 1e-8 # Avoid division by zero
118118
# Multiscale such that the number of steps t gets "integrated out"
119-
# First eigenvalue is 1, so we start at the second one
120119
embedding = evecs
121-
embedding[:, 1:] = (evals[1:] / (1 - evals[1:])) * embedding[:, 1:]
120+
scaled_evals = np.array([e if e == 1 else e / (1 - e) for e in evals])
121+
embedding *= scaled_evals
122122
nn_obj = pynndescent.NNDescent(embedding, n_neighbors=k + 1)
123123
neigh_inds, neigh_distances = nn_obj.neighbor_graph
124124
# We purposely ignore the first neighbor as it is the cell itself

src/scib_metrics/utils/_lisi.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from functools import partial
12
from typing import Tuple, Union
23

34
import chex
45
import jax
56
import jax.numpy as jnp
67
import numpy as np
78

9+
from ._utils import get_ndarray
10+
811
NdArray = Union[np.ndarray, jnp.ndarray]
912

1013

@@ -68,15 +71,13 @@ def _get_neighbor_probability_convergence(state):
6871

6972

7073
def _compute_simpson_index_cell(
71-
knn_dists_row: jnp.ndarray, knn_row: jnp.ndarray, labels: jnp.ndarray, n_batches: int, perplexity: float, tol: float
74+
knn_dists_row: jnp.ndarray, knn_labels_row: jnp.ndarray, n_batches: int, perplexity: float, tol: float
7275
) -> jnp.ndarray:
7376
H, P = _get_neighbor_probability(knn_dists_row, perplexity, tol)
7477

7578
def _non_zero_H_simpson():
76-
knn_labels = jnp.take(labels, knn_row)
77-
L = jax.nn.one_hot(knn_labels, n_batches)
78-
sumP = P @ L
79-
return jnp.where(knn_labels.shape[0] == P.shape[0], jnp.dot(sumP, sumP), 1)
79+
sumP = jnp.bincount(knn_labels_row, weights=P, length=n_batches)
80+
return jnp.where(knn_labels_row.shape[0] == P.shape[0], jnp.dot(sumP, sumP), 1)
8081

8182
return jnp.where(H == 0, -1, _non_zero_H_simpson())
8283

@@ -114,9 +115,7 @@ def compute_simpson_index(
114115
knn_dists = jnp.array(knn_dists)
115116
knn_idx = jnp.array(knn_idx)
116117
labels = jnp.array(labels)
117-
n = knn_dists.shape[0]
118-
return jax.device_get(
119-
jax.vmap(
120-
lambda i: _compute_simpson_index_cell(knn_dists[i, :], knn_idx[i, :], labels, n_labels, perplexity, tol)
121-
)(jnp.arange(n))
122-
)
118+
knn_labels = labels[knn_idx]
119+
simpson_fn = partial(_compute_simpson_index_cell, n_batches=n_labels, perplexity=perplexity, tol=tol)
120+
out = jax.vmap(simpson_fn)(knn_dists, knn_labels)
121+
return get_ndarray(out)

src/scib_metrics/utils/_silhouette.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,19 @@ def _silhouette_reduce(
3333
"""
3434
# accumulate distances from each sample to each cluster
3535
D_chunk_len = D_chunk.shape[0]
36-
clust_dists = jnp.zeros((D_chunk_len, len(label_freqs)), dtype=D_chunk.dtype)
3736

38-
def _bincount(i, _data):
39-
clust_dists, D_chunk, labels, label_freqs = _data
40-
clust_dists = clust_dists.at[i].set(jnp.bincount(labels, weights=D_chunk[i], length=label_freqs.shape[0]))
41-
return clust_dists, D_chunk, labels, label_freqs
37+
# If running into memory issues, use fori_loop instead of vmap
38+
# clust_dists = jnp.zeros((D_chunk_len, len(label_freqs)), dtype=D_chunk.dtype)
39+
# def _bincount(i, _data):
40+
# clust_dists, D_chunk, labels, label_freqs = _data
41+
# clust_dists = clust_dists.at[i].set(jnp.bincount(labels, weights=D_chunk[i], length=label_freqs.shape[0]))
42+
# return clust_dists, D_chunk, labels, label_freqs
4243

43-
clust_dists = jax.lax.fori_loop(
44-
0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs)
45-
)[0]
44+
# clust_dists = jax.lax.fori_loop(
45+
# 0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs)
46+
# )[0]
47+
48+
clust_dists = jax.vmap(partial(jnp.bincount, length=label_freqs.shape[0]), in_axes=(None, 0))(labels, D_chunk)
4649

4750
# intra_index selects intra-cluster distances within clust_dists
4851
intra_index = (jnp.arange(D_chunk_len), jax.lax.dynamic_slice(labels, (start,), (D_chunk_len,)))

0 commit comments

Comments
 (0)