Skip to content

Commit d0f5f82

Browse files
committed
save working progress
1 parent 8016f49 commit d0f5f82

2 files changed

Lines changed: 241 additions & 0 deletions

File tree

play.ipynb

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%load_ext autoreload\n",
10+
"%autoreload 2"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"\"\"\"Tests for verifying process/thread usage in parallelized functions.\"\"\"\n",
20+
"\n",
21+
"from __future__ import annotations\n",
22+
"\n",
23+
"import numpy as np\n",
24+
"import pytest # type: ignore[import]\n",
25+
"import numba\n",
26+
"import dask.array as da\n",
27+
"from typing import Callable\n",
28+
"from functools import partial\n",
29+
"\n",
30+
"from squidpy._utils import parallelize, Signal\n",
31+
"\n",
32+
"\n",
33+
"\n",
34+
"# Functions to be parallelized\n",
35+
"\n",
36+
"@numba.njit(parallel=True)\n",
37+
"def numba_parallel_func(x, y) -> np.ndarray:\n",
38+
" return x * 2 + y\n",
39+
"\n",
40+
"@numba.njit(parallel=False)\n",
41+
"def numba_serial_func(x, y) -> np.ndarray:\n",
42+
" return x * 2 + y\n",
43+
"\n",
44+
"def dask_func(x, y) -> np.ndarray:\n",
45+
" return (da.from_array(x) * 2 + y).compute()\n",
46+
"\n",
47+
"def vanilla_func(x, y) -> np.ndarray:\n",
48+
" return x * 2 + y\n",
49+
"\n",
50+
"# Mock runner function\n",
51+
"\n",
52+
"def mock_runner(x, y, queue, func):\n",
53+
" for i in range(len(x)):\n",
54+
" x[i] = func(x[i], y)\n",
55+
" if queue is not None:\n",
56+
" queue.put(Signal.UPDATE)\n",
57+
" if queue is not None:\n",
58+
" queue.put(Signal.FINISH)\n",
59+
" return x\n",
60+
"\n",
61+
"\n",
62+
"@pytest.fixture(params=[\"numba_parallel\", \"numba_serial\", \"dask\", \"vanilla\"])\n",
63+
"def func(request) -> Callable:\n",
64+
" return {\n",
65+
" \"numba_parallel\": numba_parallel_func,\n",
66+
" \"numba_serial\": numba_serial_func,\n",
67+
" \"dask\": dask_func,\n",
68+
" \"vanilla\": vanilla_func,\n",
69+
" }[request.param]\n",
70+
"\n"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 8,
76+
"metadata": {},
77+
"outputs": [
78+
{
79+
"data": {
80+
"application/vnd.jupyter.widget-view+json": {
81+
"model_id": "4f5ca04ed21c48cbb923359030b6fefb",
82+
"version_major": 2,
83+
"version_minor": 0
84+
},
85+
"text/plain": [
86+
" 0%| | 0/8 [00:00<?, ?/s]"
87+
]
88+
},
89+
"metadata": {},
90+
"output_type": "display_data"
91+
},
92+
{
93+
"name": "stdout",
94+
"output_type": "stream",
95+
"text": [
96+
"8 8\n",
97+
"8 8\n"
98+
]
99+
},
100+
{
101+
"name": "stderr",
102+
"output_type": "stream",
103+
"text": [
104+
"/Users/selman/miniforge3/envs/squidpy/lib/python3.11/site-packages/dask/dataframe/__init__.py:31: FutureWarning: The legacy Dask DataFrame implementation is deprecated and will be removed in a future version. Set the configuration option `dataframe.query-planning` to `True` or None to enable the new Dask Dataframe implementation and silence this warning.\n",
105+
" warnings.warn(\n"
106+
]
107+
},
108+
{
109+
"name": "stdout",
110+
"output_type": "stream",
111+
"text": [
112+
"8 8\n",
113+
"8 8\n",
114+
"8 8\n",
115+
"8 8\n",
116+
"8 8\n",
117+
"8 8\n"
118+
]
119+
},
120+
{
121+
"ename": "AssertionError",
122+
"evalue": "Expected: [array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21])] but got [array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21])]. Length mismatch",
123+
"output_type": "error",
124+
"traceback": [
125+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
126+
"\u001b[31mAssertionError\u001b[39m Traceback (most recent call last)",
127+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m p_func = parallelize(runner, arr1, n_jobs=\u001b[32m2\u001b[39m, backend=\u001b[33m\"\u001b[39m\u001b[33mloky\u001b[39m\u001b[33m\"\u001b[39m, use_ixs=\u001b[38;5;28;01mFalse\u001b[39;00m, n_splits=\u001b[38;5;28mlen\u001b[39m(arr1))\n\u001b[32m 8\u001b[39m result = p_func(arr2)[\u001b[32m0\u001b[39m]\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(result) == \u001b[38;5;28mlen\u001b[39m(expected), \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExpected: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexpected\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. Length mismatch\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(arr1)):\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m np.all(result[i] == expected[i]), \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExpected \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexpected[i]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult[i]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n",
128+
"\u001b[31mAssertionError\u001b[39m: Expected: [array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21])] but got [array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21]), array([ 0, 3, 6, 9, 12, 15, 18, 21])]. Length mismatch"
129+
]
130+
}
131+
],
132+
"source": [
133+
"n = 8\n",
134+
"func = numba_parallel_func\n",
135+
"arr1 = [np.arange(n) for _ in range(n)]\n",
136+
"arr2 = np.arange(n)\n",
137+
"runner = partial(mock_runner, func=func)\n",
138+
"# expected = [func(arr1[i], arr2) for i in range(len(arr1))]\n",
139+
"p_func = parallelize(runner, arr1, n_jobs=2, backend=\"loky\", use_ixs=False, n_splits=len(arr1))\n",
140+
"result = p_func(arr2)[0]\n",
141+
"assert len(result) == len(expected), f\"Expected: {expected} but got {result}. Length mismatch\"\n",
142+
"for i in range(len(arr1)):\n",
143+
" assert np.all(result[i] == expected[i]), f\"Expected {expected[i]} but got {result[i]}\"\n"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": null,
149+
"metadata": {},
150+
"outputs": [],
151+
"source": []
152+
}
153+
],
154+
"metadata": {
155+
"kernelspec": {
156+
"display_name": "squidpy",
157+
"language": "python",
158+
"name": "python3"
159+
},
160+
"language_info": {
161+
"codemirror_mode": {
162+
"name": "ipython",
163+
"version": 3
164+
},
165+
"file_extension": ".py",
166+
"mimetype": "text/x-python",
167+
"name": "python",
168+
"nbconvert_exporter": "python",
169+
"pygments_lexer": "ipython3",
170+
"version": "3.11.11"
171+
}
172+
},
173+
"nbformat": 4,
174+
"nbformat_minor": 2
175+
}

tests/utils/test_parallelize.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Tests for verifying process/thread usage in parallelized functions."""
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
import pytest # type: ignore[import]
7+
import numba
8+
import dask.array as da
9+
from typing import Callable
10+
from functools import partial
11+
12+
from squidpy._utils import parallelize, Signal
13+
14+
15+
16+
# Functions to be parallelized
17+
18+
@numba.njit(parallel=True)
19+
def numba_parallel_func(x, y) -> np.ndarray:
20+
return x * 2 + y
21+
22+
@numba.njit(parallel=False)
23+
def numba_serial_func(x, y) -> np.ndarray:
24+
return x * 2 + y
25+
26+
def dask_func(x, y) -> np.ndarray:
27+
return (da.from_array(x) * 2 + y).compute()
28+
29+
def vanilla_func(x, y) -> np.ndarray:
30+
return x * 2 + y
31+
32+
# Mock runner function
33+
34+
def mock_runner(x, y, queue, func):
35+
for i in range(len(x)):
36+
print(len(x[i]), len(y))
37+
x[i] = func(x[i], y)
38+
if queue is not None:
39+
queue.put(Signal.UPDATE)
40+
if queue is not None:
41+
queue.put(Signal.FINISH)
42+
return x
43+
44+
45+
@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
46+
def func(request) -> Callable:
47+
return {
48+
"numba_parallel": numba_parallel_func,
49+
"numba_serial": numba_serial_func,
50+
"dask": dask_func,
51+
"vanilla": vanilla_func,
52+
}[request.param]
53+
54+
55+
@pytest.mark.parametrize("n_jobs", [1, 2, 8])
56+
def test_parallelize_loky(func, n_jobs):
57+
n = 8
58+
arr1 = [np.arange(n) for _ in range(n)]
59+
arr2 = np.arange(n)
60+
runner = partial(mock_runner, func=func)
61+
expected = [func(arr1[i], arr2) for i in range(len(arr1))]
62+
p_func = parallelize(runner, arr1, n_jobs=n_jobs, backend="loky", use_ixs=False)
63+
result = p_func(arr2)[0]
64+
assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
65+
for i in range(len(arr1)):
66+
assert np.all(result[i] == expected[i]), f"Expected {expected[i]} but got {result[i]}"

0 commit comments

Comments
 (0)