99import pandas as pd
1010import pytest
1111import threadpoolctl
12+ from anndata import AnnData
1213from scipy import sparse
1314
1415import scanpy as sc
@@ -79,10 +80,12 @@ def test_consistency(metric) -> None:
7980 pytest .param (sc .metrics .morans_i , 50 , 1.0 , id = "morans_i" ),
8081 ],
8182)
82- def test_correctness (metric , size , expected ):
83+ def test_correctness (metric , size , expected ) -> None :
84+ rng = np .random .default_rng ()
85+
8386 # Test case with perfectly seperated groups
8487 connected = np .zeros (100 )
85- connected [np . random .choice (100 , size = size , replace = False )] = 1
88+ connected [rng .choice (100 , size = size , replace = False )] = 1
8689 graph = np .zeros ((100 , 100 ))
8790 graph [np .ix_ (connected .astype (bool ), connected .astype (bool ))] = 1
8891 graph [np .ix_ (~ connected .astype (bool ), ~ connected .astype (bool ))] = 1
@@ -93,9 +96,6 @@ def test_correctness(metric, size, expected):
9396 metric (graph , connected ),
9497 metric (graph , sparse .csr_matrix (connected )), # noqa: TID251
9598 )
96- # Checking that obsp works
97- adata = sc .AnnData (sparse .csr_matrix ((100 , 100 )), obsp = {"connectivities" : graph }) # noqa: TID251
98- np .testing .assert_equal (metric (adata , vals = connected ), expected )
9999
100100
101101@pytest .mark .usefixtures ("_threading" )
@@ -104,18 +104,20 @@ def test_correctness(metric, size, expected):
104104)
105105def test_graph_metrics_w_constant_values (
106106 request : pytest .FixtureRequest , metric , array_type
107- ):
107+ ) -> None :
108108 if "dask" in array_type .__name__ :
109109 reason = "DaskArray not yet supported"
110110 request .applymarker (pytest .mark .xfail (reason = reason ))
111111
112+ rng = np .random .default_rng ()
113+
112114 # https://github.com/scverse/scanpy/issues/1806
113115 pbmc = pbmc68k_reduced ()
114116 x_t = pbmc .raw .X .T .copy ()
115117 g = pbmc .obsp ["connectivities" ].copy ()
116118 equality_check = partial (np .testing .assert_allclose , atol = 1e-11 )
117119
118- const_inds = np . random .choice (x_t .shape [0 ], 10 , replace = False )
120+ const_inds = rng .choice (x_t .shape [0 ], 10 , replace = False )
119121 with warnings .catch_warnings ():
120122 warnings .simplefilter ("ignore" , sparse .SparseEfficiencyWarning )
121123 x_t_zero_vals = x_t .copy ()
@@ -145,6 +147,43 @@ def test_graph_metrics_w_constant_values(
145147 equality_check (results_full [non_const_mask ], results_const_zeros [non_const_mask ])
146148
147149
150+ @pytest .mark .parametrize (
151+ ("neigh_params" , "metric_params" ),
152+ [
153+ pytest .param (
154+ dict (key_added = "foo" ), dict (use_graph = "foo_connectivities" ), id = "use_graph"
155+ ),
156+ pytest .param (
157+ dict (key_added = "bar" ), dict (neighbors_key = "bar" ), id = "neighbors_key"
158+ ),
159+ ],
160+ )
161+ def test_metrics_graph_params (metric , neigh_params , metric_params ) -> None :
162+ rng = np .random .default_rng ()
163+ adata = AnnData (rng .normal (size = (10 , 20 )))
164+ sc .pp .neighbors (adata , ** neigh_params )
165+ if "use_graph" in metric_params : # make sure no extra stuff is there
166+ adata = AnnData (adata .X , obsp = adata .obsp )
167+ metric (adata , ** metric_params )
168+
169+
170+ @pytest .mark .parametrize (
171+ ("params" , "err_cls" , "pattern" ),
172+ [
173+ pytest .param (
174+ dict (use_graph = "foo" , neighbors_key = "bar" ), TypeError , r"both" , id = "both"
175+ ),
176+ pytest .param (dict (use_graph = "foo" ), KeyError , r"foo" , id = "no_graph" ),
177+ pytest .param (dict (neighbors_key = "bar" ), KeyError , r"bar" , id = "no_key" ),
178+ pytest .param ({}, KeyError , r"neighbors.*uns" , id = "nothing" ),
179+ ],
180+ )
181+ def test_metrics_graph_params_errors (metric , params , err_cls , pattern ) -> None :
182+ adata = AnnData (shape = (10 , 20 ))
183+ with pytest .raises (err_cls , match = pattern ):
184+ metric (adata , ** params )
185+
186+
148187def test_confusion_matrix ():
149188 mtx = sc .metrics .confusion_matrix (["a" , "b" ], ["c" , "d" ], normalize = False )
150189 assert mtx .loc ["a" , "c" ] == 1
@@ -184,10 +223,12 @@ def test_confusion_matrix_randomized() -> None:
184223 )
185224
186225
187- def test_confusion_matrix_api ():
226+ def test_confusion_matrix_api () -> None :
227+ rng = np .random .default_rng ()
228+
188229 data = pd .DataFrame ({
189- "a" : np . random . randint (5 , size = 100 ),
190- "b" : np . random . randint (5 , size = 100 ),
230+ "a" : rng . integers (5 , size = 100 ),
231+ "b" : rng . integers (5 , size = 100 ),
191232 })
192233 expected = sc .metrics .confusion_matrix (data ["a" ], data ["b" ])
193234
0 commit comments