Skip to content

Commit c1e67ce

Browse files
fix: limit Numba threads in Wilcoxon path of rank_genes_groups (#4082)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 6204f16 commit c1e67ce

5 files changed

Lines changed: 99 additions & 9 deletions

File tree

src/scanpy/_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from .. import logging as logg
3636
from .._compat import CSBase, DaskArray, _CSArray, pkg_version, warn
37+
from ._numba import _numba_thread_limit
3738

3839
if TYPE_CHECKING:
3940
from collections.abc import Callable, Iterable, KeysView, Mapping
@@ -57,6 +58,7 @@
5758
"NeighborsView",
5859
"_choose_graph",
5960
"_doc_params",
61+
"_numba_thread_limit",
6062
"_resolve_axis",
6163
"annotate_doc_types",
6264
"axis_mul_or_truediv",

src/scanpy/_utils/_numba.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING
5+
6+
import numba
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Generator
10+
11+
12+
@contextmanager
def _numba_thread_limit(n_threads: int | None) -> Generator[None]:
    """Run the managed block with Numba's thread count temporarily changed.

    Parameters
    ----------
    n_threads
        Requested number of Numba threads. The request is clamped to the
        range ``[1, numba.config.NUMBA_NUM_THREADS]`` before it is applied.
        If ``None``, Numba's thread count is left untouched.

    Yields
    ------
    ``None``. The thread count that was active on entry is reinstated when
    the block exits, whether normally or through an exception.
    """
    if n_threads is not None:
        # Remember the current setting first so it can be put back afterwards.
        thread_count_on_entry = numba.get_num_threads()
        ceiling = numba.config.NUMBA_NUM_THREADS
        clamped = max(1, min(n_threads, ceiling))
        numba.set_num_threads(clamped)
        try:
            yield
        finally:
            numba.set_num_threads(thread_count_on_entry)
    else:
        # Nothing requested: behave as a no-op context manager.
        yield

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .._settings import Default
1818
from .._settings.presets import DETest
1919
from .._utils import (
20+
_numba_thread_limit,
2021
check_nonnegative_integers,
2122
get_literal_vals,
2223
raise_not_implemented_error_if_backed_type,
@@ -708,14 +709,15 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915
708709
logg.debug(f"consider {groupby!r} groups:")
709710
logg.debug(f"with sizes: {np.count_nonzero(test_obj.groups_masks_obs, axis=1)}")
710711

711-
test_obj.compute_statistics(
712-
method,
713-
corr_method=corr_method,
714-
n_genes_user=n_genes_user,
715-
rankby_abs=rankby_abs,
716-
tie_correct=tie_correct,
717-
**kwds,
718-
)
712+
with _numba_thread_limit(settings.n_jobs if method == "wilcoxon" else None):
713+
test_obj.compute_statistics(
714+
method,
715+
corr_method=corr_method,
716+
n_genes_user=n_genes_user,
717+
rankby_abs=rankby_abs,
718+
tie_correct=tie_correct,
719+
**kwds,
720+
)
719721

720722
if test_obj.pts is not None:
721723
groups_names = [str(name) for name in test_obj.groups_order]

tests/test_rank_genes_groups.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING, TypedDict, cast
66

7+
import numba
78
import numpy as np
89
import pandas as pd
910
import pytest
@@ -254,13 +255,41 @@ def test_wilcoxon_tie_correction(*, reference: bool) -> None:
254255
np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals, atol=1e-5)
255256

256257

257-
def test_wilcoxon_huge_data(monkeypatch):
258+
def test_wilcoxon_huge_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Smoke-test the Wilcoxon method with ``_CONST_MAX_SIZE`` lowered to 300.

    The test only checks that the call completes without raising.
    """
    adata = pbmc68k_reduced()
    # NOTE(review): presumably a small limit forces the "huge data" code path
    # on this small dataset; confirm against `_rank_genes_groups`.
    monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", 300)
    rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon")
262263

263264

265+
@pytest.mark.parametrize(
    "method",
    [
        pytest.param(
            "t-test", marks=pytest.mark.xfail(reason="t-test doesn’t use numba (yet)")
        ),
        "wilcoxon",
    ],
)
def test_set_numba_threads_from_settings(
    monkeypatch: pytest.MonkeyPatch, method: Literal["t-test", "wilcoxon"]
) -> None:
    """Check that ``rank_genes_groups`` applies ``settings.n_jobs`` to Numba.

    ``numba.get_num_threads`` and ``numba.set_num_threads`` are replaced by
    stubs so that every thread-count change is recorded instead of applied.
    The test then asserts that the requested limit was set at some point and
    that the last recorded call restores the stubbed "previous" value.
    """
    n_jobs = 2
    previous_threads = 8
    # `_numba_thread_limit` clamps the request to NUMBA_NUM_THREADS, so on a
    # single-core machine (or with NUMBA_NUM_THREADS=1) the applied limit is 1,
    # not 2. Derive the expectation with the same clamp instead of hard-coding.
    expected_limit = max(1, min(n_jobs, numba.config.NUMBA_NUM_THREADS))

    was_set_to: list[int] = []
    old_n_jobs = sc.settings.n_jobs
    monkeypatch.setattr(numba, "get_num_threads", lambda: previous_threads)
    monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)

    try:
        sc.settings.n_jobs = n_jobs
        adata = get_example_data(np.asarray)
        rank_genes_groups(adata, "true_groups", n_genes=5, method=method)
    finally:
        # Restore the global setting even if the call above raises.
        sc.settings.n_jobs = old_n_jobs

    assert expected_limit in was_set_to, (
        "Wilcoxon path did not use scanpy.settings.n_jobs."
    )
    assert was_set_to[-1] == previous_threads
291+
292+
264293
@pytest.mark.parametrize(
265294
("n_genes_add", "n_genes_out_add"),
266295
[pytest.param(0, 0, id="equal"), pytest.param(2, 1, id="more")],

tests/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22

33
import itertools
44
import string
5+
from contextlib import suppress
56
from operator import mul, truediv
67
from types import ModuleType
78
from typing import TYPE_CHECKING
89

10+
import numba
911
import numpy as np
1012
import pytest
1113
from anndata.tests.helpers import asarray
1214
from scipy import sparse
1315

1416
from scanpy._compat import CSBase, DaskArray
1517
from scanpy._utils import (
18+
_numba_thread_limit,
1619
axis_mul_or_truediv,
1720
check_nonnegative_integers,
1821
descend_classes_and_funcs,
@@ -240,3 +243,32 @@ def test_random_str() -> None:
240243
assert strings.dtype == np.dtype("U2")
241244
unique = np.unique(strings, axis=0)
242245
assert len(unique) == len(strings)
246+
247+
248+
@pytest.mark.parametrize("success", [True, False], ids=["success", "exception"])
def test_numba_thread_limit_restores_previous_value(
    *, monkeypatch: pytest.MonkeyPatch, success: bool
) -> None:
    """Check that the previous thread count is restored on exit.

    Covers both a normal exit and an exit through an exception raised inside
    the managed block.
    """
    was_set_to: list[int] = []
    monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
    monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
    # Pin the configured maximum: the requested limit is clamped to it, so
    # without this the expected `2` would become `1` on a single-core machine
    # or when NUMBA_NUM_THREADS=1 is set in the environment.
    monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 8)

    with suppress(RuntimeError), _numba_thread_limit(2):
        if not success:
            raise RuntimeError

    # First call applies the limit, second call restores the stubbed previous value.
    assert was_set_to == [2, 8]
261+
262+
263+
def test_numba_thread_limit_clamps_to_configured_maximum(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that an oversized request is capped at ``NUMBA_NUM_THREADS``.

    With the maximum stubbed to 4 and the current count stubbed to 3,
    requesting 99 threads must record a set to 4 followed by a restore to 3.
    """
    recorded_calls = []
    monkeypatch.setattr(numba, "get_num_threads", lambda: 3)
    monkeypatch.setattr(numba, "set_num_threads", recorded_calls.append)
    monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 4)

    limiter = _numba_thread_limit(99)
    with limiter:
        pass

    assert recorded_calls == [4, 3]

0 commit comments

Comments
 (0)