|
4 | 4 | from pathlib import Path |
5 | 5 | from typing import TYPE_CHECKING, TypedDict, cast |
6 | 6 |
|
| 7 | +import numba |
7 | 8 | import numpy as np |
8 | 9 | import pandas as pd |
9 | 10 | import pytest |
|
16 | 17 | from scanpy._utils import select_groups |
17 | 18 | from scanpy.get import rank_genes_groups_df |
18 | 19 | from scanpy.tools import rank_genes_groups |
19 | | -from scanpy.tools._rank_genes_groups import _numba_thread_limit, _RankGenes |
| 20 | +from scanpy.tools._rank_genes_groups import _RankGenes |
20 | 21 | from testing.scanpy._helpers import random_mask |
21 | 22 | from testing.scanpy._helpers.data import pbmc68k_reduced |
22 | 23 | from testing.scanpy._pytest.params import ARRAY_TYPES, ARRAY_TYPES_MEM |
@@ -254,82 +255,39 @@ def test_wilcoxon_tie_correction(*, reference: bool) -> None: |
254 | 255 | np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals, atol=1e-5) |
255 | 256 |
|
256 | 257 |
|
257 | | -def test_wilcoxon_huge_data(monkeypatch): |
| 258 | +def test_wilcoxon_huge_data(monkeypatch: pytest.MonkeyPatch) -> None: |
258 | 259 | max_size = 300 |
259 | 260 | adata = pbmc68k_reduced() |
260 | 261 | monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", max_size) |
261 | 262 | rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon") |
262 | 263 |
|
263 | 264 |
|
264 | | -def test_numba_thread_limit_restores_previous_value(monkeypatch): |
265 | | - calls = [] |
266 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8) |
267 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append) |
268 | | - |
269 | | - with _numba_thread_limit(2): |
270 | | - pass |
271 | | - |
272 | | - assert calls == [2, 8] |
273 | | - |
274 | | - |
275 | | -def test_numba_thread_limit_restores_previous_value_on_exception(monkeypatch): |
276 | | - calls = [] |
277 | | - msg = "synthetic failure" |
278 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8) |
279 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append) |
280 | | - |
281 | | - with ( |
282 | | - pytest.raises(RuntimeError, match=msg), |
283 | | - _numba_thread_limit(2), |
284 | | - ): |
285 | | - raise RuntimeError(msg) |
286 | | - |
287 | | - assert calls == [2, 8] |
288 | | - |
289 | | - |
290 | | -def test_numba_thread_limit_clamps_to_configured_maximum(monkeypatch): |
291 | | - calls = [] |
292 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 3) |
293 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append) |
294 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba.config, "NUMBA_NUM_THREADS", 4) |
295 | | - |
296 | | - with _numba_thread_limit(99): |
297 | | - pass |
298 | | - |
299 | | - assert calls[0] == 4 |
300 | | - assert calls[-1] == 3 |
301 | | - |
302 | | - |
303 | | -def test_wilcoxon_sets_numba_threads_from_settings(monkeypatch): |
304 | | - calls = [] |
305 | | - old_n_jobs = sc.settings.n_jobs |
306 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "get_num_threads", lambda: 8) |
307 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append) |
308 | | - |
309 | | - try: |
310 | | - sc.settings.n_jobs = 2 |
311 | | - adata = get_example_data(np.asarray) |
312 | | - rank_genes_groups(adata, "true_groups", n_genes=5, method="wilcoxon") |
313 | | - finally: |
314 | | - sc.settings.n_jobs = old_n_jobs |
315 | | - |
316 | | - assert 2 in calls, "Wilcoxon path did not use scanpy.settings.n_jobs." |
317 | | - assert calls[-1] == 8 |
318 | | - |
319 | | - |
320 | | -def test_t_test_does_not_set_numba_threads_from_settings(monkeypatch): |
321 | | - calls = [] |
| 265 | +@pytest.mark.parametrize( |
| 266 | + "method", |
| 267 | + [ |
| 268 | + pytest.param( |
| 269 | + "t-test", marks=pytest.mark.xfail(reason="t-test doesn’t use numba (yet)") |
| 270 | + ), |
| 271 | + "wilcoxon", |
| 272 | + ], |
| 273 | +) |
| 274 | +def test_set_numba_threads_from_settings( |
| 275 | + monkeypatch: pytest.MonkeyPatch, method: Literal["t-test", "wilcoxon"] |
| 276 | +) -> None: |
| 277 | + was_set_to = [] |
322 | 278 | old_n_jobs = sc.settings.n_jobs |
323 | | - monkeypatch.setattr(sc.tl._rank_genes_groups.numba, "set_num_threads", calls.append) |
| 279 | + monkeypatch.setattr(numba, "get_num_threads", lambda: 8) |
| 280 | + monkeypatch.setattr(numba, "set_num_threads", was_set_to.append) |
324 | 281 |
|
325 | 282 | try: |
326 | 283 | sc.settings.n_jobs = 2 |
327 | 284 | adata = get_example_data(np.asarray) |
328 | | - rank_genes_groups(adata, "true_groups", n_genes=5, method="t-test") |
| 285 | + rank_genes_groups(adata, "true_groups", n_genes=5, method=method) |
329 | 286 | finally: |
330 | 287 | sc.settings.n_jobs = old_n_jobs |
331 | 288 |
|
332 | | - assert calls == [] |
| 289 | + assert 2 in was_set_to, "Wilcoxon path did not use scanpy.settings.n_jobs." |
| 290 | + assert was_set_to[-1] == 8 |
333 | 291 |
|
334 | 292 |
|
335 | 293 | @pytest.mark.parametrize( |
|
0 commit comments