Skip to content

Commit b95e9a6

Browse files
committed
add parallelize tests again
1 parent 8c9f3d5 commit b95e9a6

2 files changed

Lines changed: 105 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ test = [
8989
"pytest-cov>=4",
9090
"coverage[toml]>=7",
9191
"psutil",
92+
"pytest-isolate",
9293
]
9394
docs = [
9495
"ipython",

tests/utils/test_parallelize.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Tests for verifying process/thread usage in parallelized functions."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Callable
6+
from functools import partial
7+
import time
8+
9+
import dask.array as da
10+
import numba
11+
import numpy as np
12+
import psutil
13+
import pytest # type: ignore[import]
14+
15+
from squidpy._utils import Signal, parallelize
16+
17+
# Functions to be parallelized
18+
19+
20+
@numba.njit(parallel=True)
def numba_parallel_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y``, compiled by numba with ``parallel=True``."""
    doubled = x * 2
    return doubled + y
23+
24+
25+
@numba.njit(parallel=False)
def numba_serial_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y``, compiled by numba with ``parallel=False``."""
    scaled = x * 2
    return scaled + y
28+
29+
30+
def dask_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y``, evaluated through a dask array built from ``x``."""
    # Build the lazy dask expression first, then materialize it.
    lazy_result = da.from_array(x) * 2 + y
    return lazy_result.compute()
32+
33+
34+
def vanilla_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y`` using plain Python operators (no numba, no dask)."""
    doubled = x * 2
    return doubled + y
36+
37+
38+
# Mock runner function
39+
40+
41+
def mock_runner(x, y, queue, func):
    """Apply ``func(element, y)`` to every element of ``x`` in place and return ``x``.

    When ``queue`` is not ``None``, ``Signal.UPDATE`` is put on it after each
    element is processed and ``Signal.FINISH`` is put once after the loop.
    """
    report_progress = queue is not None
    for idx, item in enumerate(x):
        # Overwrite the slot so the caller's container holds the results.
        x[idx] = func(item, y)
        if report_progress:
            queue.put(Signal.UPDATE)
    if report_progress:
        queue.put(Signal.FINISH)
    return x
49+
50+
51+
@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
def func(request) -> Callable:
    """Return the implementation of ``x * 2 + y`` selected by the fixture parameter."""
    implementations = {
        "numba_parallel": numba_parallel_func,
        "numba_serial": numba_serial_func,
        "dask": dask_func,
        "vanilla": vanilla_func,
    }
    selected = request.param
    return implementations[selected]
59+
60+
61+
@pytest.mark.isolate
@pytest.mark.parametrize("n_jobs", [1, 2, 8])
def test_parallelize_loky(func, n_jobs):
    """Check ``parallelize`` with the ``loky`` backend for each ``func`` and ``n_jobs``.

    The test verifies three things:

    * the parallelized runner returns the same values as calling ``func`` directly;
    * the numba thread count, set to 3 beforehand, reads 1 after the call;
    * the number of child processes that appeared during the call (ignoring any
      whose command line mentions ``resource_tracker``) is exactly ``n_jobs`` when
      ``n_jobs > 1`` and at most one otherwise.
    """
    start_time = time.time()
    seed = 42
    rng = np.random.RandomState(seed)
    n = 8
    arr1 = [rng.randint(0, 100, n) for _ in range(n)]
    arr2 = np.arange(n)
    runner = partial(mock_runner, func=func)
    # this is the expected result of the function
    expected = [func(row, arr2) for row in arr1]
    # Use a thread count that is none of the tested n_jobs values (1, 2, 8), so we
    # can check that setting it works and that it reads 1 after the run.
    old_num_threads = 3
    numba.set_num_threads(old_num_threads)
    # Get initial state
    initial_process = psutil.Process()
    # Keep PIDs (not Process objects): final_children below is also a set of PIDs,
    # and a PID never compares equal to a psutil.Process, so mixing the two would
    # turn the set difference into a no-op.
    initial_children = {p.pid for p in initial_process.children(recursive=True)}
    init_numba_threads = numba.get_num_threads()

    p_func = parallelize(runner, arr1, n_jobs=n_jobs, backend="loky", use_ixs=False, n_split=1)
    result = p_func(arr2)[0]

    final_children = {p.pid for p in initial_process.children(recursive=True)}
    final_numba_threads = numba.get_num_threads()

    assert init_numba_threads == old_num_threads, "Numba threads should not change"
    assert final_numba_threads == 1, "Numba threads should be 1"
    assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
    for got, want in zip(result, expected):
        assert np.all(got == want), f"Expected {want} but got {got}"

    # Children that appeared while the parallelized function ran.
    new_pids = final_children - initial_children

    processes = {psutil.Process(pid) for pid in new_pids}
    processes = {p for p in processes if not any("resource_tracker" in cl for cl in p.cmdline())}
    if n_jobs > 1:  # expect exactly n_jobs
        assert len(processes) == n_jobs, f"Unexpected processes created or not created: {processes}"
    else:  # some functions use the main process others use a new process
        processes = {p for p in processes if p.create_time() > start_time}
        assert len(processes) <= 1, f"Unexpected processes created or not created: {processes}"

0 commit comments

Comments
 (0)