Skip to content

Commit 22845a1

Browse files
authored
feat: allow use_graph and neighbors_key for metrics (#3898)
1 parent 20675cf commit 22845a1

10 files changed

Lines changed: 113 additions & 47 deletions

File tree

docs/conf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import os
6+
import shutil
67
import sys
78
from datetime import datetime
89
from functools import partial
@@ -13,6 +14,7 @@
1314
import matplotlib # noqa
1415
from docutils import nodes
1516
from packaging.version import Version
17+
from sphinxcontrib.katex import NODEJS_BINARY
1618

1719
# Don’t use tkinter agg when importing scanpy → … → matplotlib
1820
matplotlib.use("agg")
@@ -52,7 +54,6 @@
5254
bibtex_bibfiles = ["references.bib"]
5355
bibtex_reference_style = "author_year"
5456

55-
5657
# default settings
5758
templates_path = ["_templates"]
5859
master_doc = "index"
@@ -73,10 +74,10 @@
7374
"sphinx.ext.intersphinx",
7475
"sphinx.ext.doctest",
7576
"sphinx.ext.coverage",
76-
"sphinx.ext.mathjax",
7777
"sphinx.ext.napoleon",
7878
"sphinx.ext.autosummary",
7979
"sphinxcontrib.bibtex",
80+
"sphinxcontrib.katex",
8081
"matplotlib.sphinxext.plot_directive",
8182
"sphinx_autodoc_typehints", # needs to be after napoleon
8283
"git_ref", # needs to be before scanpydoc.rtd_github_links
@@ -129,6 +130,8 @@
129130
pygments_style = "default"
130131
pygments_dark_style = "native"
131132

133+
katex_prerender = shutil.which(NODEJS_BINARY) is not None
134+
132135
intersphinx_mapping = dict(
133136
anndata=("https://anndata.readthedocs.io/en/stable/", None),
134137
bbknn=("https://bbknn.readthedocs.io/en/latest/", None),

docs/release-notes/3898.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow specifying graphs in {mod}`scanpy.metrics` functions {smaller}`P Angerer`

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ doc = [
124124
"nbsphinx>=0.9",
125125
"ipython>=7.20", # for nbsphinx code highlighting
126126
"sphinxcontrib-bibtex",
127+
"sphinxcontrib-katex",
127128
# TODO: remove necessity for being able to import doc-linked classes
128129
"scanpy[paga,dask-ml,leiden]",
129130
"sam-algorithm",

src/scanpy/metrics/_common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010

1111
from .._compat import CSRBase, DaskArray, SpBase, fullname, warn
12+
from .._utils import NeighborsView
1213

1314
if TYPE_CHECKING:
1415
from typing import NoReturn
@@ -65,17 +66,22 @@ def __call__(self) -> np.ndarray:
6566
raise NotImplementedError(msg)
6667

6768

68-
def _get_graph(adata: AnnData, *, use_graph: str | None = None) -> CSRBase:
69+
def _get_graph(
70+
adata: AnnData,
71+
*,
72+
use_graph: str | None = None,
73+
neighbors_key: str | None = None,
74+
) -> CSRBase:
6975
if use_graph is not None:
70-
raise NotImplementedError()
71-
# Fix for anndata<0.7
72-
if hasattr(adata, "obsp") and "connectivities" in adata.obsp:
73-
return adata.obsp["connectivities"]
74-
elif "neighbors" in adata.uns:
75-
return adata.uns["neighbors"]["connectivities"]
76-
else:
76+
if neighbors_key is not None:
77+
msg = "Cannot specify both `use_graph` and `neighbors_key`."
78+
raise TypeError(msg)
79+
return adata.obsp[use_graph]
80+
nv = NeighborsView(adata, neighbors_key)
81+
if "connectivities" not in nv:
7782
msg = "Must run neighbors first."
7883
raise ValueError(msg)
84+
return nv["connectivities"]
7985

8086

8187
@overload

src/scanpy/metrics/_gearys_c.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import numpy as np
1010

1111
from .._compat import CSRBase, njit
12+
from .._utils import _doc_params
1213
from ..get import _get_obs_rep
14+
from ..neighbors._doc import doc_neighbors_key
1315
from ._common import _get_graph, _SparseMetric
1416

1517
if TYPE_CHECKING:
@@ -20,12 +22,14 @@
2022

2123

2224
@singledispatch
25+
@_doc_params(neighbors_key=doc_neighbors_key)
2326
def gearys_c(
2427
adata_or_graph: AnnData | CSRBase,
2528
/,
2629
vals: _Vals | None = None,
2730
*,
2831
use_graph: str | None = None,
32+
neighbors_key: str | None = None,
2933
layer: str | None = None,
3034
obsm: str | None = None,
3135
obsp: str | None = None,
@@ -42,11 +46,11 @@ def gearys_c(
4246
.. math::
4347
4448
C =
45-
\frac{
46-
(N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2
47-
}{
48-
2W \sum_i (x_i - \bar{x})^2
49-
}
49+
\frac{{
50+
(N - 1)\sum_{{i,j}} w_{{i,j}} (x_i - x_j)^2
51+
}}{{
52+
2W \sum_i (x_i - \bar{{x}})^2
53+
}}
5054
5155
Params
5256
------
@@ -60,8 +64,10 @@ def gearys_c(
6064
object by using key word arguments: `layer`, `obsm`, `obsp`, or
6165
`use_raw`.
6266
use_graph
63-
Key to use for graph in anndata object. If not provided, default
64-
neighbors connectivities will be used instead.
67+
Key to use for graph in anndata object.
68+
If not provided, default neighbors connectivities will be used instead.
69+
(See ``neighbors_key`` below.)
70+
{neighbors_key}
6571
layer
6672
Key for `adata.layers` to choose `vals`.
6773
obsm
@@ -96,7 +102,7 @@ def gearys_c(
96102
97103
"""
98104
adata = cast("AnnData", adata_or_graph)
99-
g = _get_graph(adata, use_graph=use_graph)
105+
g = _get_graph(adata, use_graph=use_graph, neighbors_key=neighbors_key)
100106
if vals is None:
101107
vals = _get_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T
102108
return gearys_c(g, vals)

src/scanpy/metrics/_morans_i.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import numpy as np
1010

1111
from .._compat import CSRBase, njit
12+
from .._utils import _doc_params
1213
from ..get import _get_obs_rep
14+
from ..neighbors._doc import doc_neighbors_key
1315
from ._common import _get_graph, _SparseMetric
1416

1517
if TYPE_CHECKING:
@@ -20,12 +22,14 @@
2022

2123

2224
@singledispatch
25+
@_doc_params(neighbors_key=doc_neighbors_key)
2326
def morans_i(
2427
adata_or_graph: AnnData | CSRBase,
2528
/,
2629
vals: _Vals | None = None,
2730
*,
2831
use_graph: str | None = None,
32+
neighbors_key: str | None = None,
2933
layer: str | None = None,
3034
obsm: str | None = None,
3135
obsp: str | None = None,
@@ -40,11 +44,11 @@ def morans_i(
4044
.. math::
4145
4246
I =
43-
\frac{
44-
N \sum_{i, j} w_{i, j} z_{i} z_{j}
45-
}{
46-
S_{0} \sum_{i} z_{i}^{2}
47-
}
47+
\frac{{
48+
N \sum_{{i,j}} w_{{i,j}} z_{{i}} z_{{j}}
49+
}}{{
50+
S_{{0}} \sum_{{i}} z_{{i}}^{{2}}
51+
}}
4852
4953
Params
5054
------
@@ -58,8 +62,10 @@ def morans_i(
5862
object by using key word arguments: `layer`, `obsm`, `obsp`, or
5963
`use_raw`.
6064
use_graph
61-
Key to use for graph in anndata object. If not provided, default
62-
neighbors connectivities will be used instead.
65+
Key to use for graph in anndata object.
66+
If not provided, default neighbors connectivities will be used instead.
67+
(See ``neighbors_key`` below.)
68+
{neighbors_key}
6369
layer
6470
Key for `adata.layers` to choose `vals`.
6571
obsm
@@ -94,7 +100,7 @@ def morans_i(
94100
95101
"""
96102
adata = cast("AnnData", adata_or_graph)
97-
g = _get_graph(adata, use_graph=use_graph)
103+
g = _get_graph(adata, use_graph=use_graph, neighbors_key=neighbors_key)
98104
if vals is None:
99105
vals = _get_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T
100106
return morans_i(g, vals)

src/scanpy/neighbors/_doc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from __future__ import annotations
22

3+
doc_neighbors_key = """\
4+
neighbors_key
5+
Where to look for neighbors connectivities.
6+
If not specified, this retrieves ``.obsp['connectivities']`` for connectivities
7+
(default storage place for :func:`~scanpy.pp.neighbors`).
8+
If specified, this retrieves
9+
``.obsp[.uns[neighbors_key]['connectivities_key']]`` for connectivities.
10+
"""
11+
312
doc_use_rep = """\
413
use_rep
514
Use the indicated representation. `'X'` or any key for `.obsm` is valid.

src/scanpy/plotting/_docs.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from ..neighbors._doc import doc_neighbors_key
6+
57
doc_adata_color_etc = """\
68
adata
79
Annotated data matrix.
@@ -22,19 +24,14 @@
2224
takes precedence over `use_raw`.\
2325
"""
2426

25-
doc_edges_arrows = """\
27+
doc_edges_arrows = f"""\
2628
edges
2729
Show edges.
2830
edges_width
2931
Width of edges.
3032
edges_color
3133
Color of edges. See :func:`~networkx.drawing.nx_pylab.draw_networkx_edges`.
32-
neighbors_key
33-
Where to look for neighbors connectivities.
34-
If not specified, this looks .obsp['connectivities'] for connectivities
35-
(default storage place for pp.neighbors).
36-
If specified, this looks
37-
.obsp[.uns[neighbors_key]['connectivities_key']] for connectivities.
34+
{doc_neighbors_key}
3835
arrows
3936
Show arrows (deprecated in favour of `scvelo.pl.velocity_embedding`).
4037
arrows_kwds

src/scanpy/preprocessing/_highly_variable_genes.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,6 @@ def _highly_variable_genes_single_batch(
363363
if n_removed:
364364
x = x[:, filt].copy()
365365

366-
if hasattr(x, "_view_args"): # AnnData array view
367-
# For compatibility with anndata<0.9
368-
x = x.copy() # Doesn't actually copy memory, just removes View class wrapper
369-
370366
if flavor == "seurat":
371367
x = x.copy()
372368
if (base := adata.uns.get("log1p", {}).get("base")) is not None:

tests/test_metrics.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010
import pytest
1111
import threadpoolctl
12+
from anndata import AnnData
1213
from scipy import sparse
1314

1415
import scanpy as sc
@@ -79,10 +80,12 @@ def test_consistency(metric) -> None:
7980
pytest.param(sc.metrics.morans_i, 50, 1.0, id="morans_i"),
8081
],
8182
)
82-
def test_correctness(metric, size, expected):
83+
def test_correctness(metric, size, expected) -> None:
84+
rng = np.random.default_rng()
85+
8386
# Test case with perfectly seperated groups
8487
connected = np.zeros(100)
85-
connected[np.random.choice(100, size=size, replace=False)] = 1
88+
connected[rng.choice(100, size=size, replace=False)] = 1
8689
graph = np.zeros((100, 100))
8790
graph[np.ix_(connected.astype(bool), connected.astype(bool))] = 1
8891
graph[np.ix_(~connected.astype(bool), ~connected.astype(bool))] = 1
@@ -93,9 +96,6 @@ def test_correctness(metric, size, expected):
9396
metric(graph, connected),
9497
metric(graph, sparse.csr_matrix(connected)), # noqa: TID251
9598
)
96-
# Checking that obsp works
97-
adata = sc.AnnData(sparse.csr_matrix((100, 100)), obsp={"connectivities": graph}) # noqa: TID251
98-
np.testing.assert_equal(metric(adata, vals=connected), expected)
9999

100100

101101
@pytest.mark.usefixtures("_threading")
@@ -104,18 +104,20 @@ def test_correctness(metric, size, expected):
104104
)
105105
def test_graph_metrics_w_constant_values(
106106
request: pytest.FixtureRequest, metric, array_type
107-
):
107+
) -> None:
108108
if "dask" in array_type.__name__:
109109
reason = "DaskArray not yet supported"
110110
request.applymarker(pytest.mark.xfail(reason=reason))
111111

112+
rng = np.random.default_rng()
113+
112114
# https://github.com/scverse/scanpy/issues/1806
113115
pbmc = pbmc68k_reduced()
114116
x_t = pbmc.raw.X.T.copy()
115117
g = pbmc.obsp["connectivities"].copy()
116118
equality_check = partial(np.testing.assert_allclose, atol=1e-11)
117119

118-
const_inds = np.random.choice(x_t.shape[0], 10, replace=False)
120+
const_inds = rng.choice(x_t.shape[0], 10, replace=False)
119121
with warnings.catch_warnings():
120122
warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
121123
x_t_zero_vals = x_t.copy()
@@ -145,6 +147,43 @@ def test_graph_metrics_w_constant_values(
145147
equality_check(results_full[non_const_mask], results_const_zeros[non_const_mask])
146148

147149

150+
@pytest.mark.parametrize(
151+
("neigh_params", "metric_params"),
152+
[
153+
pytest.param(
154+
dict(key_added="foo"), dict(use_graph="foo_connectivities"), id="use_graph"
155+
),
156+
pytest.param(
157+
dict(key_added="bar"), dict(neighbors_key="bar"), id="neighbors_key"
158+
),
159+
],
160+
)
161+
def test_metrics_graph_params(metric, neigh_params, metric_params) -> None:
162+
rng = np.random.default_rng()
163+
adata = AnnData(rng.normal(size=(10, 20)))
164+
sc.pp.neighbors(adata, **neigh_params)
165+
if "use_graph" in metric_params: # make sure no extra stuff is there
166+
adata = AnnData(adata.X, obsp=adata.obsp)
167+
metric(adata, **metric_params)
168+
169+
170+
@pytest.mark.parametrize(
171+
("params", "err_cls", "pattern"),
172+
[
173+
pytest.param(
174+
dict(use_graph="foo", neighbors_key="bar"), TypeError, r"both", id="both"
175+
),
176+
pytest.param(dict(use_graph="foo"), KeyError, r"foo", id="no_graph"),
177+
pytest.param(dict(neighbors_key="bar"), KeyError, r"bar", id="no_key"),
178+
pytest.param({}, KeyError, r"neighbors.*uns", id="nothing"),
179+
],
180+
)
181+
def test_metrics_graph_params_errors(metric, params, err_cls, pattern) -> None:
182+
adata = AnnData(shape=(10, 20))
183+
with pytest.raises(err_cls, match=pattern):
184+
metric(adata, **params)
185+
186+
148187
def test_confusion_matrix():
149188
mtx = sc.metrics.confusion_matrix(["a", "b"], ["c", "d"], normalize=False)
150189
assert mtx.loc["a", "c"] == 1
@@ -184,10 +223,12 @@ def test_confusion_matrix_randomized() -> None:
184223
)
185224

186225

187-
def test_confusion_matrix_api():
226+
def test_confusion_matrix_api() -> None:
227+
rng = np.random.default_rng()
228+
188229
data = pd.DataFrame({
189-
"a": np.random.randint(5, size=100),
190-
"b": np.random.randint(5, size=100),
230+
"a": rng.integers(5, size=100),
231+
"b": rng.integers(5, size=100),
191232
})
192233
expected = sc.metrics.confusion_matrix(data["a"], data["b"])
193234

0 commit comments

Comments
 (0)