Skip to content

Commit 8f30be6

Browse files
Backport PR #4082 on branch 1.12.x (fix: limit Numba threads in Wilcoxon path of rank_genes_groups) (#4088)
Co-authored-by: Jhonatan Felix <108437587+JhonatanFelix@users.noreply.github.com>
1 parent 746a05b commit 8f30be6

5 files changed

Lines changed: 100 additions & 10 deletions

File tree

src/scanpy/_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .. import logging as logg
3737
from .._compat import CSBase, DaskArray, _CSArray, pkg_version, warn
3838
from .._settings import settings
39+
from ._numba import _numba_thread_limit
3940

4041
if TYPE_CHECKING:
4142
from collections.abc import Callable, Iterable, KeysView, Mapping
@@ -61,6 +62,7 @@
6162
"_choose_graph",
6263
"_doc_params",
6364
"_empty",
65+
"_numba_thread_limit",
6466
"_resolve_axis",
6567
"annotate_doc_types",
6668
"axis_mul_or_truediv",

src/scanpy/_utils/_numba.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING
5+
6+
import numba
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Generator
10+
11+
12+
@contextmanager
13+
def _numba_thread_limit(n_threads: int | None) -> Generator[None]:
14+
"""Temporarily set Numba's thread count and restore it on exit."""
15+
if n_threads is None:
16+
yield
17+
return
18+
19+
previous = numba.get_num_threads()
20+
n_threads = max(1, min(n_threads, numba.config.NUMBA_NUM_THREADS))
21+
numba.set_num_threads(n_threads)
22+
try:
23+
yield
24+
finally:
25+
numba.set_num_threads(previous)

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from fast_array_utils.stats import mean_var
1212
from scipy import sparse
1313

14-
from .. import _utils
14+
from .. import _utils, settings
1515
from .. import logging as logg
1616
from .._compat import CSBase, old_positionals
1717
from .._utils import (
18+
_numba_thread_limit,
1819
check_nonnegative_integers,
1920
get_literal_vals,
2021
raise_not_implemented_error_if_backed_type,
@@ -714,14 +715,15 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915
714715
logg.debug(f"consider {groupby!r} groups:")
715716
logg.debug(f"with sizes: {np.count_nonzero(test_obj.groups_masks_obs, axis=1)}")
716717

717-
test_obj.compute_statistics(
718-
method,
719-
corr_method=corr_method,
720-
n_genes_user=n_genes_user,
721-
rankby_abs=rankby_abs,
722-
tie_correct=tie_correct,
723-
**kwds,
724-
)
718+
with _numba_thread_limit(settings.n_jobs if method == "wilcoxon" else None):
719+
test_obj.compute_statistics(
720+
method,
721+
corr_method=corr_method,
722+
n_genes_user=n_genes_user,
723+
rankby_abs=rankby_abs,
724+
tie_correct=tie_correct,
725+
**kwds,
726+
)
725727

726728
if test_obj.pts is not None:
727729
groups_names = [str(name) for name in test_obj.groups_order]

tests/test_rank_genes_groups.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING, TypedDict, cast
66

7+
import numba
78
import numpy as np
89
import pandas as pd
910
import pytest
@@ -254,13 +255,41 @@ def test_wilcoxon_tie_correction(*, reference: bool) -> None:
254255
np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals, atol=1e-5)
255256

256257

257-
def test_wilcoxon_huge_data(monkeypatch):
258+
def test_wilcoxon_huge_data(monkeypatch: pytest.MonkeyPatch) -> None:
258259
max_size = 300
259260
adata = pbmc68k_reduced()
260261
monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", max_size)
261262
rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon")
262263

263264

265+
@pytest.mark.parametrize(
266+
"method",
267+
[
268+
pytest.param(
269+
"t-test", marks=pytest.mark.xfail(reason="t-test doesn’t use numba (yet)")
270+
),
271+
"wilcoxon",
272+
],
273+
)
274+
def test_set_numba_threads_from_settings(
275+
monkeypatch: pytest.MonkeyPatch, method: Literal["t-test", "wilcoxon"]
276+
) -> None:
277+
was_set_to = []
278+
old_n_jobs = sc.settings.n_jobs
279+
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
280+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
281+
282+
try:
283+
sc.settings.n_jobs = 2
284+
adata = get_example_data(np.asarray)
285+
rank_genes_groups(adata, "true_groups", n_genes=5, method=method)
286+
finally:
287+
sc.settings.n_jobs = old_n_jobs
288+
289+
assert 2 in was_set_to, "Wilcoxon path did not use scanpy.settings.n_jobs."
290+
assert was_set_to[-1] == 8
291+
292+
264293
@pytest.mark.parametrize(
265294
("n_genes_add", "n_genes_out_add"),
266295
[pytest.param(0, 0, id="equal"), pytest.param(2, 1, id="more")],

tests/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22

33
import itertools
44
import string
5+
from contextlib import suppress
56
from operator import mul, truediv
67
from types import ModuleType
78
from typing import TYPE_CHECKING
89

10+
import numba
911
import numpy as np
1012
import pytest
1113
from anndata.tests.helpers import asarray
1214
from scipy import sparse
1315

1416
from scanpy._compat import CSBase, DaskArray
1517
from scanpy._utils import (
18+
_numba_thread_limit,
1619
axis_mul_or_truediv,
1720
check_nonnegative_integers,
1821
descend_classes_and_funcs,
@@ -245,3 +248,32 @@ def test_random_str() -> None:
245248
assert strings.dtype == np.dtype("U2")
246249
unique = np.unique(strings, axis=0)
247250
assert len(unique) == len(strings)
251+
252+
253+
@pytest.mark.parametrize("success", [True, False], ids=["success", "exception"])
254+
def test_numba_thread_limit_restores_previous_value(
255+
*, monkeypatch: pytest.MonkeyPatch, success: bool
256+
) -> None:
257+
was_set_to = []
258+
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
259+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
260+
261+
with suppress(RuntimeError), _numba_thread_limit(2):
262+
if not success:
263+
raise RuntimeError
264+
265+
assert was_set_to == [2, 8]
266+
267+
268+
def test_numba_thread_limit_clamps_to_configured_maximum(
269+
monkeypatch: pytest.MonkeyPatch,
270+
) -> None:
271+
was_set_to = []
272+
monkeypatch.setattr(numba, "get_num_threads", lambda: 3)
273+
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
274+
monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 4)
275+
276+
with _numba_thread_limit(99):
277+
pass
278+
279+
assert was_set_to == [4, 3]

0 commit comments

Comments
 (0)