Skip to content

Commit 6be0d40

Browse files
Merge pull request #1 from scverse/fix-rgg-n_jobs
move to utils, slimmer tests
2 parents 9a1672c + f9ce049 commit 6be0d40

5 files changed

Lines changed: 81 additions & 80 deletions

File tree

src/scanpy/_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from .. import logging as logg
3636
from .._compat import CSBase, DaskArray, _CSArray, pkg_version, warn
37+
from ._numba import _numba_thread_limit
3738

3839
if TYPE_CHECKING:
3940
from collections.abc import Callable, Iterable, KeysView, Mapping
@@ -57,6 +58,7 @@
5758
"NeighborsView",
5859
"_choose_graph",
5960
"_doc_params",
61+
"_numba_thread_limit",
6062
"_resolve_axis",
6163
"annotate_doc_types",
6264
"axis_mul_or_truediv",

src/scanpy/_utils/_numba.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING
5+
6+
import numba
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Generator
10+
11+
12+
@contextmanager
13+
def _numba_thread_limit(n_threads: int | None) -> Generator[None]:
14+
"""Temporarily set Numba's thread count and restore it on exit."""
15+
if n_threads is None:
16+
yield
17+
return
18+
19+
previous = numba.get_num_threads()
20+
n_threads = max(1, min(n_threads, numba.config.NUMBA_NUM_THREADS))
21+
numba.set_num_threads(n_threads)
22+
try:
23+
yield
24+
finally:
25+
numba.set_num_threads(previous)

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from contextlib import contextmanager
65
from typing import TYPE_CHECKING
76

87
import numba
@@ -18,6 +17,7 @@
1817
from .._settings import Default
1918
from .._settings.presets import DETest
2019
from .._utils import (
20+
_numba_thread_limit,
2121
check_nonnegative_integers,
2222
get_literal_vals,
2323
raise_not_implemented_error_if_backed_type,
@@ -47,22 +47,6 @@ def _select_top_n(scores: NDArray, n_top: int):
4747
return global_indices
4848

4949

50-
@contextmanager
51-
def _numba_thread_limit(n_threads: int | None) -> Generator[None, None, None]:
52-
"""Temporarily set Numba's thread count and restore it on exit."""
53-
if n_threads is None:
54-
yield
55-
return
56-
57-
previous = numba.get_num_threads()
58-
n_threads = max(1, min(n_threads, numba.config.NUMBA_NUM_THREADS))
59-
numba.set_num_threads(n_threads)
60-
try:
61-
yield
62-
finally:
63-
numba.set_num_threads(previous)
64-
65-
6650
@njit
6751
def rankdata(data: NDArray[np.number]) -> NDArray[np.float64]:
6852
"""Parallelized version of scipy.stats.rankdata."""

tests/test_rank_genes_groups.py

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING, TypedDict, cast
66

7+
import numba
78
import numpy as np
89
import pandas as pd
910
import pytest
@@ -16,7 +17,7 @@
1617
from scanpy._utils import select_groups
1718
from scanpy.get import rank_genes_groups_df
1819
from scanpy.tools import rank_genes_groups
19-
from scanpy.tools._rank_genes_groups import _numba_thread_limit, _RankGenes
20+
from scanpy.tools._rank_genes_groups import _RankGenes
2021
from testing.scanpy._helpers import random_mask
2122
from testing.scanpy._helpers.data import pbmc68k_reduced
2223
from testing.scanpy._pytest.params import ARRAY_TYPES, ARRAY_TYPES_MEM
@@ -254,82 +255,39 @@ def test_wilcoxon_tie_correction(*, reference: bool) -> None:
254255
np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals, atol=1e-5)
255256

256257

257-
def test_wilcoxon_huge_data(monkeypatch):
258+
def test_wilcoxon_huge_data(monkeypatch: pytest.MonkeyPatch) -> None:
258259
max_size = 300
259260
adata = pbmc68k_reduced()
260261
monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", max_size)
261262
rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon")
262263

263264

264-
def test_numba_thread_limit_restores_previous_value(monkeypatch):
265-
calls = []
266-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8)
267-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append)
268-
269-
with _numba_thread_limit(2):
270-
pass
271-
272-
assert calls == [2, 8]
273-
274-
275-
def test_numba_thread_limit_restores_previous_value_on_exception(monkeypatch):
276-
calls = []
277-
msg = "synthetic failure"
278-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8)
279-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append)
280-
281-
with (
282-
pytest.raises(RuntimeError, match=msg),
283-
_numba_thread_limit(2),
284-
):
285-
raise RuntimeError(msg)
286-
287-
assert calls == [2, 8]
288-
289-
290-
def test_numba_thread_limit_clamps_to_configured_maximum(monkeypatch):
291-
calls = []
292-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 3)
293-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append)
294-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba.config, "NUMBA_NUM_THREADS", 4)
295-
296-
with _numba_thread_limit(99):
297-
pass
298-
299-
assert calls[0] == 4
300-
assert calls[-1] == 3
301-
302-
303-
def test_wilcoxon_sets_numba_threads_from_settings(monkeypatch):
304-
calls = []
305-
old_n_jobs = sc.settings.n_jobs
306-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8)
307-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append)
308-
309-
try:
310-
sc.settings.n_jobs = 2
311-
adata = get_example_data(np.asarray)
312-
rank_genes_groups(adata, "true_groups", n_genes=5, method="wilcoxon")
313-
finally:
314-
sc.settings.n_jobs = old_n_jobs
315-
316-
assert 2 in calls, "Wilcoxon path did not use scanpy.settings.n_jobs."
317-
assert calls[-1] == 8
318-
319-
320-
def test_t_test_does_not_set_numba_threads_from_settings(monkeypatch):
321-
calls = []
265+
@pytest.mark.parametrize(
266+
"method",
267+
[
268+
pytest.param(
269+
"t-test", marks=pytest.mark.xfail(reason="t-test doesn’t use numba (yet)")
270+
),
271+
"wilcoxon",
272+
],
273+
)
274+
def test_set_numba_threads_from_settings(
275+
monkeypatch: pytest.MonkeyPatch, method: Literal["t-test", "wilcoxon"]
276+
) -> None:
277+
was_set_to = []
322278
old_n_jobs = sc.settings.n_jobs
323-
monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append)
279+
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
280+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
324281

325282
try:
326283
sc.settings.n_jobs = 2
327284
adata = get_example_data(np.asarray)
328-
rank_genes_groups(adata, "true_groups", n_genes=5, method="t-test")
285+
rank_genes_groups(adata, "true_groups", n_genes=5, method=method)
329286
finally:
330287
sc.settings.n_jobs = old_n_jobs
331288

332-
assert calls == []
289+
assert 2 in was_set_to, "Wilcoxon path did not use scanpy.settings.n_jobs."
290+
assert was_set_to[-1] == 8
333291

334292

335293
@pytest.mark.parametrize(

tests/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22

33
import itertools
44
import string
5+
from contextlib import suppress
56
from operator import mul, truediv
67
from types import ModuleType
78
from typing import TYPE_CHECKING
89

10+
import numba
911
import numpy as np
1012
import pytest
1113
from anndata.tests.helpers import asarray
1214
from scipy import sparse
1315

1416
from scanpy._compat import CSBase, DaskArray
1517
from scanpy._utils import (
18+
_numba_thread_limit,
1619
axis_mul_or_truediv,
1720
check_nonnegative_integers,
1821
descend_classes_and_funcs,
@@ -240,3 +243,32 @@ def test_random_str() -> None:
240243
assert strings.dtype == np.dtype("U2")
241244
unique = np.unique(strings, axis=0)
242245
assert len(unique) == len(strings)
246+
247+
248+
@pytest.mark.parametrize("success", [True, False], ids=["success", "exception"])
249+
def test_numba_thread_limit_restores_previous_value(
250+
*, monkeypatch: pytest.MonkeyPatch, success: bool
251+
) -> None:
252+
was_set_to = []
253+
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
254+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
255+
256+
with suppress(RuntimeError), _numba_thread_limit(2):
257+
if not success:
258+
raise RuntimeError
259+
260+
assert was_set_to == [2, 8]
261+
262+
263+
def test_numba_thread_limit_clamps_to_configured_maximum(
264+
monkeypatch: pytest.MonkeyPatch,
265+
) -> None:
266+
was_set_to = []
267+
monkeypatch.setattr(numba, "get_num_threads", lambda: 3)
268+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
269+
monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 4)
270+
271+
with _numba_thread_limit(99):
272+
pass
273+
274+
assert was_set_to == [4, 3]

0 commit comments

Comments
 (0)