Skip to content

Commit 0940e93

Browse files
justjhongpre-commit-ci[bot]adamgayoso
authored
LISI implementation (#20)
* simpson index compute * running lisi * fix bug with lisi and expose bug in silhouette test * add docs * flake8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adjust atol for tiny silhouette val differences * flake * address flax comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use nneighbors to convert graph * Update pyproject.toml * move dep to test * use chex * add ilisi and clisi * add to docs * add references in docs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adam Gayoso <[email protected]>
1 parent f88d530 commit 0940e93

File tree

8 files changed

+308
-5
lines changed

8 files changed

+308
-5
lines changed

docs/api.md

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
nmi_ari_cluster_labels_leiden
1515
silhouette_label
1616
silhouette_batch
17+
ilisi_knn
18+
clisi_knn
1719
```
1820

1921
## Utils

docs/references.bib

+15
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,18 @@ @article{luecken2022benchmarking
88
year = {2022},
99
publisher = {Nature Publishing Group}
1010
}
11+
12+
13+
@article{korsunsky2019harmony,
14+
title = {Fast, sensitive and accurate integration of single-cell data with
15+
Harmony},
16+
author = {Korsunsky, Ilya and Millard, Nghia and Fan, Jean and Slowikowski,
17+
Kamil and Zhang, Fan and Wei, Kevin and Baglaenko, Yuriy and
18+
Brenner, Michael and Loh, Po-Ru and Raychaudhuri, Soumya},
19+
journal = {Nat. Methods},
20+
volume = {16},
21+
number = {12},
22+
pages = {1289--1296},
23+
month = {dec},
24+
year = {2019},
25+
}

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ urls.Source = "https://github.com/yoseflab/scib-metrics"
2121
urls.Home-page = "https://github.com/yoseflab/scib-metrics"
2222
dependencies = [
2323
"anndata",
24+
"chex",
2425
"jax",
2526
"jaxlib",
2627
"numpy",
@@ -51,6 +52,7 @@ doc = [
5152
test = [
5253
"pytest",
5354
"pytest-cov",
55+
"harmonypy",
5456
"joblib",
5557
]
5658
parallel = [

src/scib_metrics/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from . import utils
55
from ._ari_nmi import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
66
from ._isolated_labels import isolated_labels
7+
from ._lisi import clisi_knn, ilisi_knn, lisi_knn
78
from ._settings import settings
89
from ._silhouette import silhouette_batch, silhouette_label
910

@@ -12,6 +13,8 @@
1213
"isolated_labels",
1314
"silhouette_label",
1415
"silhouette_batch",
16+
"ilisi_knn",
17+
"clisi_knn",
1518
"nmi_ari_cluster_labels_kmeans",
1619
"nmi_ari_cluster_labels_leiden",
1720
]

src/scib_metrics/_lisi.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import Tuple
2+
3+
import numpy as np
4+
from scipy.sparse import csr_matrix
5+
from sklearn.neighbors import NearestNeighbors
6+
from sklearn.utils import check_array
7+
8+
from scib_metrics.utils import compute_simpson_index
9+
10+
11+
def _convert_knn_graph_to_idx(knn_graph: csr_matrix) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a sparse kNN graph into per-cell neighbor distances and indices.

    Parameters
    ----------
    knn_graph
        Sparse array of shape (n_cells, n_cells) with non-zero values for
        exactly each cell's k nearest neighbors.

    Returns
    -------
    The pair returned by ``NearestNeighbors.kneighbors`` on the precomputed
    graph; callers unpack it as ``(knn_dists, knn_idx)``.

    Raises
    ------
    ValueError
        If the cells do not all have the same number of neighbors, or if no
        cell has any neighbor.
    """
    check_array(knn_graph, accept_sparse="csr")

    # Count non-zero entries per row. ``minlength`` makes rows without any
    # neighbor show up with a count of 0; counting unique row indices alone
    # would silently leave such rows out of the uniformity check.
    row_idx = knn_graph.nonzero()[0]
    neighbor_counts = np.bincount(row_idx, minlength=knn_graph.shape[0])
    if neighbor_counts.min() != neighbor_counts.max():
        raise ValueError("Each cell must have the same number of neighbors.")

    n_neighbors = int(neighbor_counts[0])
    if n_neighbors == 0:
        raise ValueError("The kNN graph must contain at least one neighbor per cell.")

    nn_obj = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed").fit(knn_graph)
    return nn_obj.kneighbors(knn_graph)
22+
23+
24+
def lisi_knn(knn_graph: csr_matrix, labels: np.ndarray, perplexity: float = None) -> np.ndarray:
    """Compute the local inverse simpson index (LISI) for each cell :cite:p:`korsunsky2019harmony`.

    Parameters
    ----------
    knn_graph
        Sparse array of shape (n_cells, n_cells) with non-zero values for
        exactly each cell's k nearest neighbors.
    labels
        Array of shape (n_cells,) representing label values
        for each cell. Any values accepted by :func:`numpy.unique` may be
        used; they are re-encoded internally as integer codes.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.

    Returns
    -------
    lisi
        Array of shape (n_cells,) with the LISI score for each cell.
    """
    knn_dists, knn_idx = _convert_knn_graph_to_idx(knn_graph)

    if perplexity is None:
        perplexity = np.floor(knn_idx.shape[1] / 3)

    # The Simpson index computation one-hot encodes the labels, which only
    # works for integer codes in [0, n_labels). Re-encode so that string labels
    # or non-contiguous integers are handled; labels that already are
    # 0..n_labels-1 map to themselves.
    unique_labels, label_codes = np.unique(np.asarray(labels), return_inverse=True)
    label_codes = label_codes.reshape(-1)
    n_labels = len(unique_labels)

    simpson = compute_simpson_index(knn_dists, knn_idx, label_codes, n_labels, perplexity=perplexity)
    return 1 / simpson
53+
54+
55+
def ilisi_knn(knn_graph: csr_matrix, batches: np.ndarray, perplexity: float = None, scale: bool = True) -> float:
    """Compute the median integration local inverse simpson index (iLISI) :cite:p:`korsunsky2019harmony`.

    Returns a scaled version of the median iLISI score, by default :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    knn_graph
        Sparse array of shape (n_cells, n_cells) with non-zero values for
        exactly each cell's k nearest neighbors.
    batches
        Array of shape (n_cells,) representing batch values
        for each cell.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.
    scale
        Scale lisi into the range [0, 1]. If True, higher values are better.

    Returns
    -------
    ilisi
        Median of the per-cell LISI scores computed on ``batches`` (NaN scores
        are ignored), rescaled as ``(median - 1) / (n_batches - 1)`` when
        ``scale`` is True. NaN if ``scale`` is True and there are fewer than
        two distinct batches.
    """
    lisi = lisi_knn(knn_graph, batches, perplexity=perplexity)
    ilisi = np.nanmedian(lisi)
    if scale:
        nbatches = len(np.unique(batches))
        if nbatches < 2:
            # The scaling denominator is zero: the scaled score is undefined.
            return np.nan
        ilisi = (ilisi - 1) / (nbatches - 1)
    return ilisi
85+
86+
87+
def clisi_knn(knn_graph: csr_matrix, labels: np.ndarray, perplexity: float = None, scale: bool = True) -> float:
    """Compute the median cell-type local inverse simpson index (cLISI) :cite:p:`korsunsky2019harmony`.

    Returns a scaled version of the median cLISI score, by default :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    knn_graph
        Sparse array of shape (n_cells, n_cells) with non-zero values for
        exactly each cell's k nearest neighbors.
    labels
        Array of shape (n_cells,) representing cell type label values
        for each cell.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.
    scale
        Scale lisi into the range [0, 1]. If True, higher values are better.

    Returns
    -------
    clisi
        Median of the per-cell LISI scores computed on ``labels`` (NaN scores
        are ignored), rescaled as ``(n_labels - median) / (n_labels - 1)`` when
        ``scale`` is True. NaN if ``scale`` is True and there are fewer than
        two distinct labels.
    """
    lisi = lisi_knn(knn_graph, labels, perplexity=perplexity)
    clisi = np.nanmedian(lisi)
    if scale:
        nlabels = len(np.unique(labels))
        if nlabels < 2:
            # The scaling denominator is zero: the scaled score is undefined.
            return np.nan
        clisi = (nlabels - clisi) / (nlabels - 1)
    return clisi

src/scib_metrics/utils/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._dist import cdist
22
from ._kmeans import KMeansJax
3+
from ._lisi import compute_simpson_index
34
from ._silhouette import silhouette_samples
45

5-
__all__ = ["silhouette_samples", "cdist", "KMeansJax"]
6+
__all__ = ["silhouette_samples", "cdist", "KMeansJax", "compute_simpson_index"]

src/scib_metrics/utils/_lisi.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from typing import Tuple, Union
2+
3+
import chex
4+
import jax
5+
import jax.numpy as jnp
6+
import numpy as np
7+
8+
NdArray = Union[np.ndarray, jnp.ndarray]
9+
10+
11+
@chex.dataclass
class _NeighborProbabilityState:
    """Loop-carried state of the beta search in ``_get_neighbor_probability``."""

    # Entropy term returned by ``_Hbeta`` for the current ``beta``.
    H: float
    # Normalized neighbor weights returned by ``_Hbeta`` for the current ``beta``.
    P: chex.ArrayDevice
    # ``H`` minus ``log(perplexity)``; the search stops once its magnitude is within tolerance.
    Hdiff: float
    # Current value of the searched parameter.
    beta: float
    # Current lower bound of the search interval (``-inf`` until one is found).
    betamin: float
    # Current upper bound of the search interval (``inf`` until one is found).
    betamax: float
    # Number of search steps performed so far.
    tries: int
20+
21+
22+
@jax.jit
def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the entropy term and normalized neighbor weights for one cell.

    Parameters
    ----------
    knn_dists_row
        Distances from one cell to its neighbors.
    beta
        Scale applied to the distances inside ``exp(-dist * beta)``.

    Returns
    -------
    H
        ``log(sum(P)) + beta * sum(dist * P) / sum(P)`` computed on the
        unnormalized weights, or 0 when the weights sum to zero.
    P
        Weights normalized to sum to one, or all zeros when the unnormalized
        weights sum to zero.
    """
    P = jnp.exp(-knn_dists_row * beta)
    sumP = jnp.nansum(P)
    is_degenerate = sumP == 0
    # ``jnp.where`` evaluates both branches, so substitute a harmless
    # denominator in the degenerate case instead of computing log(0) and 0/0
    # in the branch that is then discarded.
    safe_sumP = jnp.where(is_degenerate, 1.0, sumP)
    H = jnp.where(is_degenerate, 0, jnp.log(safe_sumP) + beta * jnp.nansum(knn_dists_row * P) / safe_sumP)
    P = jnp.where(is_degenerate, jnp.zeros_like(knn_dists_row), P / safe_sumP)
    return H, P
29+
30+
31+
@jax.jit
def _get_neighbor_probability(
    knn_dists_row: jnp.ndarray, perplexity: float, tol: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Search for the ``beta`` whose ``_Hbeta`` entropy term matches ``log(perplexity)``.

    Starting from ``beta = 1``, ``beta`` is repeatedly raised (when the entropy
    term is above the target) or lowered (otherwise) until the difference is
    within ``tol`` or 50 updates have been made.

    Parameters
    ----------
    knn_dists_row
        Distances from one cell to its neighbors.
    perplexity
        Target perplexity; the search target is its logarithm.
    tol
        Absolute tolerance on the difference to the target.

    Returns
    -------
    The ``(H, P)`` pair produced by ``_Hbeta`` for the final ``beta``.
    """
    log_perplexity = jnp.log(perplexity)

    def _keep_searching(state):
        # Continue while the target is missed by more than ``tol`` and the
        # update budget is not exhausted.
        not_converged = jnp.abs(state.Hdiff) > tol
        return jnp.logical_and(not_converged, state.tries < 50)

    def _update_beta(state):
        beta = state.beta
        entropy_too_high = state.Hdiff > 0

        # Entropy too high: beta becomes the lower bound and is doubled until
        # an upper bound exists, after which the interval is bisected.
        raised = jnp.where(state.betamax == jnp.inf, beta * 2, (beta + state.betamax) / 2)
        # Otherwise: beta becomes the upper bound and is halved until a lower
        # bound exists, after which the interval is bisected.
        lowered = jnp.where(state.betamin == -jnp.inf, beta / 2, (beta + state.betamin) / 2)
        next_beta = jnp.where(entropy_too_high, raised, lowered)

        next_H, next_P = _Hbeta(knn_dists_row, next_beta)
        return _NeighborProbabilityState(
            H=next_H,
            P=next_P,
            Hdiff=next_H - log_perplexity,
            beta=next_beta,
            betamin=jnp.where(entropy_too_high, beta, state.betamin),
            betamax=jnp.where(entropy_too_high, state.betamax, beta),
            tries=state.tries + 1,
        )

    start_H, start_P = _Hbeta(knn_dists_row, 1)
    start = _NeighborProbabilityState(
        H=start_H,
        P=start_P,
        Hdiff=start_H - log_perplexity,
        beta=1,
        betamin=-jnp.inf,
        betamax=jnp.inf,
        tries=0,
    )
    end = jax.lax.while_loop(_keep_searching, _update_beta, start)
    return end.H, end.P
68+
69+
70+
def _compute_simpson_index_cell(
    knn_dists_row: jnp.ndarray, knn_row: jnp.ndarray, labels: jnp.ndarray, n_batches: int, perplexity: float, tol: float
) -> jnp.ndarray:
    """Compute the Simpson index of one cell from its neighbors' labels.

    Parameters
    ----------
    knn_dists_row
        Distances from the cell to its neighbors.
    knn_row
        Indices of the cell's neighbors, used to look up ``labels``.
    labels
        Label of every cell; the values are one-hot encoded into ``n_batches`` classes.
    n_batches
        Number of classes used for the one-hot encoding.
    perplexity
        Target perplexity forwarded to ``_get_neighbor_probability``.
    tol
        Tolerance forwarded to ``_get_neighbor_probability``.

    Returns
    -------
    Sum of squared per-label neighbor weights, or -1 when the entropy term
    returned by ``_get_neighbor_probability`` is 0.
    """
    H, P = _get_neighbor_probability(knn_dists_row, perplexity, tol)

    neighbor_labels = jnp.take(labels, knn_row)
    one_hot_labels = jax.nn.one_hot(neighbor_labels, n_batches)
    # Total neighbor weight assigned to each label.
    label_weight = P @ one_hot_labels
    lengths_match = neighbor_labels.shape[0] == P.shape[0]
    simpson = jnp.where(lengths_match, jnp.dot(label_weight, label_weight), 1)

    return jnp.where(H == 0, -1, simpson)
82+
83+
84+
def compute_simpson_index(
    knn_dists: NdArray,
    knn_idx: NdArray,
    labels: NdArray,
    n_labels: int,
    perplexity: float = 30,
    tol: float = 1e-5,
) -> np.ndarray:
    """Compute the Simpson index for each cell.

    Parameters
    ----------
    knn_dists
        KNN distances of size (n_cells, n_neighbors).
    knn_idx
        KNN indices of size (n_cells, n_neighbors) corresponding to distances.
    labels
        Cell labels of size (n_cells,). The labels are one-hot encoded into
        ``n_labels`` classes, so they must be integer codes in ``[0, n_labels)``.
    n_labels
        Number of labels.
    perplexity
        Measure of the effective number of neighbors.
    tol
        Tolerance for binary search.

    Returns
    -------
    simpson_index
        Simpson index of size (n_cells,). Cells whose entropy term is 0 are
        reported as -1.
    """
    knn_dists = jnp.array(knn_dists)
    knn_idx = jnp.array(knn_idx)
    labels = jnp.array(labels)

    def _simpson_for_cell(knn_dists_row, knn_row):
        return _compute_simpson_index_cell(knn_dists_row, knn_row, labels, n_labels, perplexity, tol)

    # Map over the rows of both arrays directly rather than gathering each row
    # through an index vector.
    return jax.device_get(jax.vmap(_simpson_for_cell)(knn_dists, knn_idx))

tests/test_basic.py

+46-4
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
1+
import sys
2+
13
import jax.numpy as jnp
24
import numpy as np
5+
import pandas as pd
6+
from harmonypy import compute_lisi as harmonypy_lisi
7+
from scipy.sparse import csr_matrix
38
from scipy.spatial.distance import cdist as sp_cdist
49
from sklearn.metrics import silhouette_samples as sk_silhouette_samples
10+
from sklearn.neighbors import NearestNeighbors
511

612
import scib_metrics
713

14+
sys.path.append("../src/")
15+
816

917
def dummy_x_labels(return_symmetric_positive=False):
10-
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
18+
np.random.seed(1)
19+
X = np.random.normal(size=(100, 10))
20+
labels = np.random.randint(0, 2, size=(100,))
1121
if return_symmetric_positive:
1222
X = np.abs(X @ X.T)
13-
labels = np.array([0, 0, 1, 1, 0, 1])
1423
return X, labels
1524

1625

1726
def dummy_x_labels_batch():
1827
X, labels = dummy_x_labels()
19-
batch = np.array([0, 1, 0, 1, 0, 1])
28+
batch = np.random.randint(0, 2, size=(100,))
2029
return X, labels, batch
2130

2231

@@ -32,7 +41,7 @@ def test_cdist():
3241

3342
def test_silhouette_samples():
3443
X, labels = dummy_x_labels()
35-
assert np.allclose(scib_metrics.utils.silhouette_samples(X, labels), sk_silhouette_samples(X, labels))
44+
assert np.allclose(scib_metrics.utils.silhouette_samples(X, labels), sk_silhouette_samples(X, labels), atol=1e-5)
3645

3746

3847
def test_silhouette_label():
@@ -49,6 +58,39 @@ def test_silhouette_batch():
4958
scib_metrics.silhouette_batch(X, labels, batch)
5059

5160

61+
def test_compute_simpson_index():
    """Check that ``compute_simpson_index`` runs and returns one value per cell."""
    X, labels = dummy_x_labels()
    nbrs = NearestNeighbors(n_neighbors=30, algorithm="kd_tree").fit(X)
    D, knn_idx = nbrs.kneighbors(X)
    simpson = scib_metrics.utils.compute_simpson_index(
        jnp.array(D), jnp.array(knn_idx), jnp.array(labels), len(np.unique(labels))
    )
    assert simpson.shape == (X.shape[0],)
69+
70+
71+
def test_lisi_knn():
    """Compare ``lisi_knn`` against the harmonypy reference implementation."""
    X, labels = dummy_x_labels()

    # Build a kNN graph whose non-zero entries hold the pairwise distances.
    pairwise = csr_matrix(scib_metrics.utils.cdist(X, X))
    neighbors = NearestNeighbors(n_neighbors=30, algorithm="kd_tree").fit(X)
    knn_graph = neighbors.kneighbors_graph(X).multiply(pairwise)

    observed = scib_metrics.lisi_knn(knn_graph, labels, perplexity=10)

    label_frame = pd.DataFrame(labels, columns=["labels"])
    expected = harmonypy_lisi(X, label_frame, label_colnames=["labels"], perplexity=10)[:, 0]

    assert np.allclose(observed, expected)
82+
83+
84+
def test_ilisi_clisi_knn():
    """Check that ``ilisi_knn`` and ``clisi_knn`` run and return finite values."""
    X, labels, batches = dummy_x_labels_batch()
    dist_mat = csr_matrix(scib_metrics.utils.cdist(X, X))
    nbrs = NearestNeighbors(n_neighbors=30, algorithm="kd_tree").fit(X)
    knn_graph = nbrs.kneighbors_graph(X)
    knn_graph = knn_graph.multiply(dist_mat)
    ilisi = scib_metrics.ilisi_knn(knn_graph, batches, perplexity=10)
    clisi = scib_metrics.clisi_knn(knn_graph, labels, perplexity=10)
    # Previously the results were discarded, so the test could only catch crashes.
    assert np.isfinite(ilisi)
    assert np.isfinite(clisi)
92+
93+
5294
def test_nmi_ari_cluster_labels_kmeans():
5395
X, labels = dummy_x_labels()
5496
nmi, ari = scib_metrics.nmi_ari_cluster_labels_kmeans(X, labels)

0 commit comments

Comments
 (0)