-
Notifications
You must be signed in to change notification settings - Fork 742
Expand file tree
/
Copy pathtest_neighbors_key_added.py
More file actions
91 lines (68 loc) · 2.95 KB
/
test_neighbors_key_added.py
File metadata and controls
91 lines (68 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import annotations
import numpy as np
import pytest
import scanpy as sc
from testing.scanpy._helpers.data import pbmc68k_reduced
from testing.scanpy._pytest.marks import needs
# Shared settings: the neighbor count used by every ``sc.pp.neighbors`` call,
# and the ``key_added`` name under which the second (keyed) graph is stored.
n_neighbors, key = 5, "test"
@pytest.fixture
def adata():
    """Build a new AnnData holding only the ``.X`` matrix of ``pbmc68k_reduced``.

    Only the expression matrix is wrapped, so the object carries no
    annotations from the source dataset. Each test therefore starts from a
    bare matrix.
    """
    matrix = pbmc68k_reduced().X
    return sc.AnnData(matrix)
def test_neighbors_key_added(adata):
    """Check that ``key_added`` stores an identical graph under its own keys.

    ``sc.pp.neighbors`` runs twice with the same settings: once with the
    default keys, and once with ``key_added=key``. The stored parameters and
    both sparse matrices (connectivities and distances) must agree.
    """
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0)
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0, key_added=key)

    conns_key = adata.uns[key]["connectivities_key"]
    dists_key = adata.uns[key]["distances_key"]
    # Guard against a vacuous pass. If the keyed run had written into the
    # default obsp slots, the comparisons below would compare a matrix with
    # itself and always succeed.
    assert conns_key != "connectivities"
    assert dists_key != "distances"

    assert adata.uns["neighbors"]["params"] == adata.uns[key]["params"]
    assert np.allclose(
        adata.obsp["connectivities"].toarray(), adata.obsp[conns_key].toarray()
    )
    assert np.allclose(
        adata.obsp["distances"].toarray(), adata.obsp[dists_key].toarray()
    )
def test_neighbors_pca_keys_added_without_previous_pca_run(adata):
    """Check the PCA fallback when ``sc.pp.neighbors`` runs with no prior PCA.

    The call must emit the "Falling back to preprocessing with `sc.pp.pca`"
    warning. It must then leave the PCA results on ``adata``: both
    ``uns["pca"]`` and the ``obsm["X_pca"]`` embedding.
    """
    # Precondition: no PCA results exist yet.
    assert "pca" not in adata.uns
    assert "X_pca" not in adata.obsm
    with pytest.warns(
        UserWarning,
        match=r".*Falling back to preprocessing with `sc.pp.pca` and default params",
    ):
        sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0)
    # Postcondition mirrors the precondition: both PCA keys are now present.
    assert "pca" in adata.uns
    assert "X_pca" in adata.obsm
# Downstream tools must give the same results when the neighbor graph is
# selected via `neighbors_key` or `obsp` instead of the default keys.
@needs.igraph
@needs.leidenalg
@pytest.mark.parametrize("field", ["neighbors_key", "obsp"])
def test_neighbors_key_obsp(adata, field):
    """Compare graph-based tools on the default graph and on a keyed copy.

    ``adata`` gets a default-key neighbor graph. A copy gets the same graph
    stored under ``key_added=key``. Each tool is then pointed at the copy's
    graph through ``field``, with one of two values:

    - ``neighbors_key``: the ``uns`` key.
    - ``obsp``: the name of the connectivities matrix.

    Results and stored params are compared with the default run. UMAP and
    PAGA are only exercised for ``neighbors_key``.
    """
    keyed = adata.copy()
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0)
    sc.pp.neighbors(keyed, n_neighbors=n_neighbors, random_state=0, key_added=key)

    use_neighbors_key = field == "neighbors_key"
    target = key if use_neighbors_key else keyed.uns[key]["connectivities_key"]
    graph_kwargs = {field: target}

    # Force-directed layout.
    sc.tl.draw_graph(adata, layout="fr", random_state=1)
    sc.tl.draw_graph(keyed, layout="fr", random_state=1, **graph_kwargs)
    assert adata.uns["draw_graph"]["params"] == keyed.uns["draw_graph"]["params"]
    assert np.allclose(adata.obsm["X_draw_graph_fr"], keyed.obsm["X_draw_graph_fr"])

    # Leiden clustering.
    sc.tl.leiden(adata, random_state=0)
    sc.tl.leiden(keyed, random_state=0, **graph_kwargs)
    assert adata.uns["leiden"]["params"] == keyed.uns["leiden"]["params"]
    assert np.all(adata.obs["leiden"] == keyed.obs["leiden"])

    # umap and paga take no obsp argument; only the neighbors_key path continues.
    if not use_neighbors_key:
        return

    sc.tl.umap(adata, random_state=0)
    sc.tl.umap(keyed, random_state=0, neighbors_key=key)
    assert adata.uns["umap"]["params"] == keyed.uns["umap"]["params"]
    assert np.allclose(adata.obsm["X_umap"], keyed.obsm["X_umap"])

    sc.tl.paga(adata, groups="leiden")
    sc.tl.paga(keyed, groups="leiden", neighbors_key=key)
    for matrix_name in ("connectivities", "connectivities_tree"):
        assert np.allclose(
            adata.uns["paga"][matrix_name].toarray(),
            keyed.uns["paga"][matrix_name].toarray(),
        )