Skip to content

Commit 76b625e

Browse files
authored
ci: fix and deduplicate rank_gene_groups tests (#3966)
1 parent 4c7c5c0 commit 76b625e

7 files changed

Lines changed: 64 additions & 105 deletions

File tree

.github/workflows/benchmark.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,8 @@ jobs:
4646
key: benchmark-state-${{ hashFiles('benchmarks/**') }}
4747

4848
- name: Install dependencies
49-
run: pip install 'asv>=0.6.4' py-rattler
49+
# https://github.com/airspeed-velocity/asv/issues/1577
50+
run: pip install 'asv>=0.6.4' 'py-rattler<0.22'
5051

5152
- name: Configure ASV
5253
working-directory: ${{ env.ASV_DIR }}

src/scanpy/metrics/_metrics.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,8 @@ def confusion_matrix(
7979
mtx = _confusion_matrix(orig, new, labels=unique_labels)
8080
if normalize:
8181
sums = mtx.sum(axis=1)[:, np.newaxis]
82-
mtx = np.divide(mtx, sums, where=sums != 0)
82+
mtx = mtx.astype(np.float64)
83+
np.divide(mtx, sums, where=sums != 0, out=mtx)
8384

8485
# Label
8586
orig_name = "Original labels" if orig.name is None else orig.name

tests/_data/objs-t-test.npz

708 Bytes
Binary file not shown.

tests/_data/objs-wilcoxon.npz

708 Bytes
Binary file not shown.

tests/_data/objs_t_test.pkl

-8.61 KB
Binary file not shown.

tests/_data/objs_wilcoxon.pkl

-8.61 KB
Binary file not shown.

tests/test_rank_genes_groups.py

Lines changed: 60 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
from __future__ import annotations
22

3-
import pickle
43
from functools import partial
54
from pathlib import Path
6-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, TypedDict, cast
76

87
import numpy as np
98
import pandas as pd
@@ -24,8 +23,9 @@
2423

2524
if TYPE_CHECKING:
2625
from collections.abc import Callable
27-
from typing import Any
26+
from typing import Any, Literal
2827

28+
from numpy.lib.npyio import NpzFile
2929
from numpy.typing import NDArray
3030

3131
HERE = Path(__file__).parent
@@ -59,126 +59,83 @@ def get_example_data(array_type: Callable[[np.ndarray], Any]) -> AnnData:
5959
return adata
6060

6161

62-
def get_true_scores() -> tuple[
63-
NDArray[np.object_],
64-
NDArray[np.object_],
65-
NDArray[np.floating],
66-
NDArray[np.floating],
67-
]:
68-
with (DATA_PATH / "objs_t_test.pkl").open("rb") as f:
69-
true_scores_t_test, true_names_t_test = pickle.load(f)
70-
with (DATA_PATH / "objs_wilcoxon.pkl").open("rb") as f:
71-
true_scores_wilcoxon, true_names_wilcoxon = pickle.load(f)
72-
73-
return (
74-
true_names_t_test,
75-
true_names_wilcoxon,
76-
true_scores_t_test,
77-
true_scores_wilcoxon,
78-
)
62+
class Expected(TypedDict):
63+
names: NDArray[np.str_]
64+
scores: NDArray[np.floating]
65+
66+
67+
def get_true_scores(method: Literal["t-test", "wilcoxon"]) -> Expected:
68+
path = DATA_PATH / f"objs-{method}.npz"
69+
with (
70+
path.open("rb") as f,
71+
cast("NpzFile", np.load(f, allow_pickle=False)) as z,
72+
):
73+
expected = dict(z)
74+
return Expected(names=expected["names"].astype("T"), scores=expected["scores"])
7975

8076

8177
# TODO: Make dask compatible
78+
@pytest.mark.parametrize("method", ["t-test", "wilcoxon"])
8279
@pytest.mark.parametrize("array_type", ARRAY_TYPES_MEM)
83-
def test_results(array_type):
80+
def test_results(
81+
subtests: pytest.Subtests, array_type, method: Literal["t-test", "wilcoxon"]
82+
) -> None:
8483
seed(1234)
85-
8684
adata = get_example_data(array_type)
8785
assert adata.raw is None # Assumption for later checks
86+
expected = get_true_scores(method)
87+
# no clue why we did this: https://github.com/scverse/scanpy/commit/7f10fa3138374bbc664776c6aae1c0e05cf2c5cf
88+
n = 7 if method == "wilcoxon" else None
8889

89-
(
90-
true_names_t_test,
91-
true_names_wilcoxon,
92-
true_scores_t_test,
93-
true_scores_wilcoxon,
94-
) = get_true_scores()
95-
96-
rank_genes_groups(adata, "true_groups", n_genes=20, method="t-test")
97-
98-
adata.uns["rank_genes_groups"]["names"] = adata.uns["rank_genes_groups"][
99-
"names"
100-
].astype(true_names_t_test.dtype)
90+
rank_genes_groups(adata, "true_groups", n_genes=20, method=method)
91+
results = adata.uns["rank_genes_groups"]
10192

102-
for name in true_scores_t_test.dtype.names:
103-
assert np.allclose(
104-
true_scores_t_test[name], adata.uns["rank_genes_groups"]["scores"][name]
105-
)
106-
assert np.array_equal(true_names_t_test, adata.uns["rank_genes_groups"]["names"])
107-
assert adata.uns["rank_genes_groups"]["params"]["use_raw"] is False
108-
109-
rank_genes_groups(adata, "true_groups", n_genes=20, method="wilcoxon")
110-
111-
adata.uns["rank_genes_groups"]["names"] = adata.uns["rank_genes_groups"][
112-
"names"
113-
].astype(true_names_wilcoxon.dtype)
114-
115-
for name in true_scores_t_test.dtype.names:
116-
assert np.allclose(
117-
true_scores_wilcoxon[name][:7],
118-
adata.uns["rank_genes_groups"]["scores"][name][:7],
119-
)
120-
assert np.array_equal(
121-
true_names_wilcoxon[:7], adata.uns["rank_genes_groups"]["names"][:7]
122-
)
123-
assert adata.uns["rank_genes_groups"]["params"]["use_raw"] is False
93+
for g in range(expected["names"].shape[0]):
94+
with subtests.test(group=g):
95+
assert np.allclose(expected["scores"][g, :n], results["scores"][str(g)][:n])
96+
assert np.array_equal(
97+
expected["names"][g, :n], results["names"][str(g)][:n]
98+
)
99+
assert results["params"]["use_raw"] is False
124100

125101

102+
@pytest.mark.parametrize("method", ["t-test", "wilcoxon"])
126103
@pytest.mark.parametrize("array_type", ARRAY_TYPES_MEM)
127-
def test_results_layers(array_type):
104+
def test_results_layers(
105+
subtests: pytest.Subtests, array_type, method: Literal["t-test", "wilcoxon"]
106+
) -> None:
128107
seed(1234)
129-
130108
adata = get_example_data(array_type)
131109
adata.layers["to_test"] = adata.X.copy()
132110
x = adata.X.tolil() if isinstance(adata.X, CSBase) else adata.X
133111
mask = np.random.randint(0, 2, adata.shape, dtype=bool)
134112
x[mask] = 0
135113
adata.X = array_type(x)
136-
137-
_, _, true_scores_t_test, true_scores_wilcoxon = get_true_scores()
138-
139-
# Wilcoxon
140-
rank_genes_groups(
141-
adata,
142-
"true_groups",
143-
method="wilcoxon",
144-
layer="to_test",
145-
n_genes=20,
146-
)
147-
assert adata.uns["rank_genes_groups"]["params"]["use_raw"] is False
148-
for name in true_scores_t_test.dtype.names:
149-
assert np.allclose(
150-
true_scores_wilcoxon[name][:7],
151-
adata.uns["rank_genes_groups"]["scores"][name][:7],
152-
)
153-
154-
rank_genes_groups(adata, "true_groups", method="wilcoxon", n_genes=20)
155-
for name in true_scores_t_test.dtype.names:
156-
assert not np.allclose(
157-
true_scores_wilcoxon[name][:7],
158-
adata.uns["rank_genes_groups"]["scores"][name][:7],
159-
)
160-
161-
# t-test
162-
rank_genes_groups(
163-
adata,
164-
"true_groups",
165-
method="t-test",
166-
layer="to_test",
167-
use_raw=False,
168-
n_genes=20,
169-
)
170-
for name in true_scores_t_test.dtype.names:
171-
assert np.allclose(
172-
true_scores_t_test[name][:7],
173-
adata.uns["rank_genes_groups"]["scores"][name][:7],
174-
)
175-
176-
rank_genes_groups(adata, "true_groups", method="t-test", n_genes=20)
177-
for name in true_scores_t_test.dtype.names:
178-
assert not np.allclose(
179-
true_scores_t_test[name][:7],
180-
adata.uns["rank_genes_groups"]["scores"][name][:7],
114+
scores = get_true_scores(method)["scores"]
115+
116+
with subtests.test("layer"):
117+
rank_genes_groups(
118+
adata,
119+
"true_groups",
120+
method=method,
121+
layer="to_test",
122+
use_raw=None if method == "wilcoxon" else False,
123+
n_genes=20,
181124
)
125+
assert adata.uns["rank_genes_groups"]["params"]["use_raw"] is False
126+
for g in range(scores.shape[0]):
127+
np.testing.assert_allclose(
128+
scores[g, :7],
129+
adata.uns["rank_genes_groups"]["scores"][str(g)][:7],
130+
rtol=1e-5, # default of np.allclose
131+
)
132+
133+
with subtests.test("X"):
134+
rank_genes_groups(adata, "true_groups", method=method, n_genes=20)
135+
for g in range(scores.shape[0]):
136+
assert not np.allclose(
137+
scores[g, :7], adata.uns["rank_genes_groups"]["scores"][str(g)][:7]
138+
)
182139

183140

184141
def test_rank_genes_groups_use_raw():

0 commit comments

Comments
 (0)