Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .. import logging as logg
from .._compat import CSBase, DaskArray, _CSArray, pkg_version, warn
from ._numba import _numba_thread_limit

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, KeysView, Mapping
Expand All @@ -57,6 +58,7 @@
"NeighborsView",
"_choose_graph",
"_doc_params",
"_numba_thread_limit",
"_resolve_axis",
"annotate_doc_types",
"axis_mul_or_truediv",
Expand Down
25 changes: 25 additions & 0 deletions src/scanpy/_utils/_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

import numba

if TYPE_CHECKING:
from collections.abc import Generator


@contextmanager
def _numba_thread_limit(n_threads: int | None) -> Generator[None]:
"""Temporarily set Numba's thread count and restore it on exit."""
if n_threads is None:
yield
return

previous = numba.get_num_threads()
n_threads = max(1, min(n_threads, numba.config.NUMBA_NUM_THREADS))
numba.set_num_threads(n_threads)
try:
yield
finally:
numba.set_num_threads(previous)
18 changes: 10 additions & 8 deletions src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .._settings import Default
from .._settings.presets import DETest
from .._utils import (
_numba_thread_limit,
check_nonnegative_integers,
get_literal_vals,
raise_not_implemented_error_if_backed_type,
Expand Down Expand Up @@ -708,14 +709,15 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915
logg.debug(f"consider {groupby!r} groups:")
logg.debug(f"with sizes: {np.count_nonzero(test_obj.groups_masks_obs, axis=1)}")

test_obj.compute_statistics(
method,
corr_method=corr_method,
n_genes_user=n_genes_user,
rankby_abs=rankby_abs,
tie_correct=tie_correct,
**kwds,
)
with _numba_thread_limit(settings.n_jobs if method == "wilcoxon" else None):
test_obj.compute_statistics(
method,
corr_method=corr_method,
n_genes_user=n_genes_user,
rankby_abs=rankby_abs,
tie_correct=tie_correct,
**kwds,
)

if test_obj.pts is not None:
groups_names = [str(name) for name in test_obj.groups_order]
Expand Down
31 changes: 30 additions & 1 deletion tests/test_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict, cast

import numba
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -254,13 +255,41 @@ def test_wilcoxon_tie_correction(*, reference: bool) -> None:
np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals, atol=1e-5)


def test_wilcoxon_huge_data(monkeypatch):
def test_wilcoxon_huge_data(monkeypatch: pytest.MonkeyPatch) -> None:
max_size = 300
adata = pbmc68k_reduced()
monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", max_size)
rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon")


@pytest.mark.parametrize(
"method",
[
pytest.param(
"t-test", marks=pytest.mark.xfail(reason="t-test doesn’t use numba (yet)")
),
"wilcoxon",
],
)
def test_set_numba_threads_from_settings(
monkeypatch: pytest.MonkeyPatch, method: Literal["t-test", "wilcoxon"]
) -> None:
was_set_to = []
old_n_jobs = sc.settings.n_jobs
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)

try:
sc.settings.n_jobs = 2
adata = get_example_data(np.asarray)
rank_genes_groups(adata, "true_groups", n_genes=5, method=method)
finally:
sc.settings.n_jobs = old_n_jobs

assert 2 in was_set_to, "Wilcoxon path did not use scanpy.settings.n_jobs."
assert was_set_to[-1] == 8


@pytest.mark.parametrize(
("n_genes_add", "n_genes_out_add"),
[pytest.param(0, 0, id="equal"), pytest.param(2, 1, id="more")],
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

import itertools
import string
from contextlib import suppress
from operator import mul, truediv
from types import ModuleType
from typing import TYPE_CHECKING

import numba
import numpy as np
import pytest
from anndata.tests.helpers import asarray
from scipy import sparse

from scanpy._compat import CSBase, DaskArray
from scanpy._utils import (
_numba_thread_limit,
axis_mul_or_truediv,
check_nonnegative_integers,
descend_classes_and_funcs,
Expand Down Expand Up @@ -240,3 +243,32 @@ def test_random_str() -> None:
assert strings.dtype == np.dtype("U2")
unique = np.unique(strings, axis=0)
assert len(unique) == len(strings)


@pytest.mark.parametrize("success", [True, False], ids=["success", "exception"])
def test_numba_thread_limit_restores_previous_value(
*, monkeypatch: pytest.MonkeyPatch, success: bool
) -> None:
was_set_to = []
monkeypatch.setattr(numba, "get_num_threads", lambda: 8)
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)

with suppress(RuntimeError), _numba_thread_limit(2):
if not success:
raise RuntimeError

assert was_set_to == [2, 8]


def test_numba_thread_limit_clamps_to_configured_maximum(
monkeypatch: pytest.MonkeyPatch,
) -> None:
was_set_to = []
monkeypatch.setattr(numba, "get_num_threads", lambda: 3)
monkeypatch.setattr(numba, "set_num_threads", was_set_to.append)
monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 4)

with _numba_thread_limit(99):
pass

assert was_set_to == [4, 3]
Loading