|
2 | 2 |
|
3 | 3 | import itertools |
4 | 4 | import string |
| 5 | +from contextlib import suppress |
5 | 6 | from operator import mul, truediv |
6 | 7 | from types import ModuleType |
7 | 8 | from typing import TYPE_CHECKING |
8 | 9 |
|
| 10 | +import numba |
9 | 11 | import numpy as np |
10 | 12 | import pytest |
11 | 13 | from anndata.tests.helpers import asarray |
12 | 14 | from scipy import sparse |
13 | 15 |
|
14 | 16 | from scanpy._compat import CSBase, DaskArray |
15 | 17 | from scanpy._utils import ( |
| 18 | + _numba_thread_limit, |
16 | 19 | axis_mul_or_truediv, |
17 | 20 | check_nonnegative_integers, |
18 | 21 | descend_classes_and_funcs, |
@@ -240,3 +243,32 @@ def test_random_str() -> None: |
240 | 243 | assert strings.dtype == np.dtype("U2") |
241 | 244 | unique = np.unique(strings, axis=0) |
242 | 245 | assert len(unique) == len(strings) |
| 246 | + |
| 247 | + |
| 248 | +@pytest.mark.parametrize("success", [True, False], ids=["success", "exception"]) |
| 249 | +def test_numba_thread_limit_restores_previous_value( |
| 250 | + *, monkeypatch: pytest.MonkeyPatch, success: bool |
| 251 | +) -> None: |
| 252 | + was_set_to = [] |
| 253 | + monkeypatch.setattr(numba, "get_num_threads", lambda: 8) |
| 254 | + monkeypatch.setattr(numba, "set_num_threads", was_set_to.append) |
| 255 | + |
| 256 | + with suppress(RuntimeError), _numba_thread_limit(2): |
| 257 | + if not success: |
| 258 | + raise RuntimeError |
| 259 | + |
| 260 | + assert was_set_to == [2, 8] |
| 261 | + |
| 262 | + |
| 263 | +def test_numba_thread_limit_clamps_to_configured_maximum( |
| 264 | + monkeypatch: pytest.MonkeyPatch, |
| 265 | +) -> None: |
| 266 | + was_set_to = [] |
| 267 | + monkeypatch.setattr(numba, "get_num_threads", lambda: 3) |
| 268 | + monkeypatch.setattr(numba, "set_num_threads", was_set_to.append) |
| 269 | + monkeypatch.setattr(numba.config, "NUMBA_NUM_THREADS", 4) |
| 270 | + |
| 271 | + with _numba_thread_limit(99): |
| 272 | + pass |
| 273 | + |
| 274 | + assert was_set_to == [4, 3] |
0 commit comments