Skip to content

Commit 73ab5e7

Browse files
committed
feat(utils): refactor multiprocessing codes.
1 parent 19b14ed commit 73ab5e7

File tree

6 files changed

+228
-42
lines changed

6 files changed

+228
-42
lines changed

cellseg_models_pytorch/inference/post_processor.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import Callable, Dict, List
22

33
import numpy as np
4-
from pathos.multiprocessing import ThreadPool as Pool
54
from skimage.util import img_as_ubyte
6-
from tqdm import tqdm
75

86
from ..postproc import POSTPROC_LOOKUP
97
from ..utils import (
@@ -13,6 +11,7 @@
1311
med_filt_parallel,
1412
med_filt_sequential,
1513
remove_debris_semantic,
14+
run_pool,
1615
)
1716

1817
__all__ = ["PostProcessor"]
@@ -166,30 +165,32 @@ def post_proc_pipeline(
166165
return res
167166

168167
def run_parallel(
169-
self, maps: List[Dict[str, np.ndarray]], progress_bar: bool = False
168+
self,
169+
maps: List[Dict[str, np.ndarray]],
170+
pooltype: str = "thread",
171+
maptype: str = "amap",
170172
) -> List[Dict[str, np.ndarray]]:
171173
"""Run the full post-processing pipeline in parallel for many model outputs.
172174
173175
Parameters
174176
----------
175177
maps : List[Dict[str, np.ndarray]]
176178
The model output map dictionaries in a list.
177-
progress_bar : bool, default=False
178-
If True, a tqdm progress bar is shown.
179+
pooltype : str, default="thread"
180+
The pathos pooltype. Allowed: ("process", "thread", "serial").
181+
Defaults to "thread". (Fastest in benchmarks.)
182+
maptype : str, default="amap"
183+
The map type of the pathos Pool object.
184+
Allowed: ("map", "amap", "imap", "uimap")
185+
Defaults to "amap". (Fastest in benchmarks).
179186
180187
Returns
181188
-------
182189
List[Dict[str, np.ndarray]]:
183190
The post-processed output map dictionaries in a list.
184191
"""
185-
seg_results = []
186-
with Pool() as pool:
187-
if progress_bar:
188-
it = tqdm(pool.imap(self.post_proc_pipeline, maps), total=len(maps))
189-
else:
190-
it = pool.imap(self.post_proc_pipeline, maps)
191-
192-
for x in it:
193-
seg_results.append(x)
192+
seg_results = run_pool(
193+
self.post_proc_pipeline, maps, ret=True, pooltype=pooltype, maptype=maptype
194+
)
194195

195196
return seg_results

cellseg_models_pytorch/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
soft_type_flatten,
3434
type_map_flatten,
3535
)
36+
from .multiproc import run_pool
3637
from .patching import (
3738
TilerStitcher,
3839
TilerStitcherTorch,
@@ -134,4 +135,5 @@
134135
"med_filt_parallel",
135136
"med_filt_sequential",
136137
"intersection",
138+
"run_pool",
137139
]

cellseg_models_pytorch/utils/file_manager.py

+46-28
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import re
2+
from functools import partial
23
from pathlib import Path
34
from typing import Any, Dict, List, Tuple, Union
45

56
import cv2
67
import numpy as np
78
import scipy.io as sio
8-
from pathos.multiprocessing import ThreadPool as Pool
9-
from tqdm import tqdm
109

1110
from .mask_utils import (
1211
bounding_box,
@@ -15,6 +14,7 @@
1514
get_inst_types,
1615
label_semantic,
1716
)
17+
from .multiproc import run_pool
1818

1919

2020
class FileHandler:
@@ -58,30 +58,34 @@ def read_mat(
5858
key: str = "inst_map",
5959
retype: bool = True,
6060
return_all: bool = False,
61-
) -> Union[np.ndarray, None]:
61+
) -> Union[np.ndarray, Dict[str, np.ndarray], None]:
6262
"""Read a mask from a .mat file.
6363
6464
If a mask is not found, return None
6565
6666
Parameters
6767
----------
6868
path : str or Path
69-
Path to the image file.
69+
Path to the .mat file.
7070
key : str, default="inst_map"
7171
Name/key of the mask type that is being read from .mat
7272
retype : bool, default=True
7373
Convert the matrix type.
7474
return_all : bool, default=False
7575
Return the whole dict. Overrides the `key` arg.
7676
77-
Returns
78-
-------
79-
np.ndarray or None:
80-
The mask indice matrix. Shape (H, W)
8177
8278
Raises
8379
------
8480
ValueError: If an illegal key is given.
81+
82+
Returns
83+
-------
84+
Union[np.ndarray, Dict[str, np.ndarray], None]:
85+
if return_all == False:
86+
The instance/type/semantic labelled mask. Shape: (H, W).
87+
if return_all == True:
88+
All the masks in the .mat file returned in a dictionary.
8589
"""
8690
dtypes = {
8791
"inst_map": "int32",
@@ -468,7 +472,8 @@ def save_masks_parallel(
468472
classes_type: Dict[str, str] = None,
469473
classes_sem: Dict[str, str] = None,
470474
offsets: bool = False,
471-
progress_bar: bool = False,
475+
pooltype: str = "thread",
476+
maptype: str = "amap",
472477
**kwargs,
473478
) -> None:
474479
"""Save the model output masks to a folder. (multi-threaded).
@@ -493,31 +498,44 @@ def save_masks_parallel(
493498
offsets : bool, default=False
494499
If True, geojson coords are shifted by the offsets that are encoded in
495500
the filenames (e.g. "x-1000_y-4000.png"). Ignored if `format` != ".json"
496-
progress_bar : bool, default=False
497-
If True, a tqdm progress bar is shown.
501+
pooltype : str, default="thread"
502+
The pathos pooltype. Allowed: ("process", "thread", "serial").
503+
Defaults to "thread". (Fastest in benchmarks.)
504+
maptype : str, default="amap"
505+
The map type of the pathos Pool object.
506+
Allowed: ("map", "amap", "imap", "uimap")
507+
Defaults to "amap". (Fastest in benchmarks).
498508
"""
499-
formats = [format] * len(maps)
500-
geo_formats = [geo_format] * len(maps)
501-
classes_type = [classes_type] * len(maps)
502-
classes_sem = [classes_sem] * len(maps)
503-
offsets = [offsets] * len(maps)
504-
args = tuple(
505-
zip(fnames, maps, formats, geo_formats, classes_type, classes_sem, offsets)
509+
func = partial(
510+
FileHandler._save_masks,
511+
format=format,
512+
geo_format=geo_format,
513+
classes_type=classes_type,
514+
classes_sem=classes_sem,
515+
offsets=offsets,
506516
)
507517

508-
with Pool() as pool:
509-
if progress_bar:
510-
it = tqdm(pool.imap(FileHandler._save_masks, args), total=len(maps))
511-
else:
512-
it = pool.imap(FileHandler._save_masks, args)
513-
514-
for _ in it:
515-
pass
518+
args = tuple(zip(fnames, maps))
519+
run_pool(func, args, ret=False, pooltype=pooltype, maptype=maptype)
516520

517521
@staticmethod
518-
def _save_masks(args: Tuple[Dict[str, np.ndarray], str, str]) -> None:
522+
def _save_masks(
523+
args: Tuple[str, Dict[str, np.ndarray]],
524+
format: str,
525+
geo_format: str,
526+
classes_type: Dict[str, str],
527+
classes_sem: Dict[str, str],
528+
offsets: bool,
529+
) -> None:
519530
"""Unpacks the args for `save_mask` to enable multi-threading."""
520-
return FileHandler.save_masks(*args)
531+
return FileHandler.save_masks(
532+
*args,
533+
format=format,
534+
geo_format=geo_format,
535+
classes_type=classes_type,
536+
classes_sem=classes_sem,
537+
offsets=offsets,
538+
)
521539

522540
@staticmethod
523541
def get_split(string: str) -> List[str]:
+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import Any, Callable, Generator, List, Union
2+
3+
from pathos.pools import ProcessPool, SerialPool, ThreadPool
4+
5+
__all__ = ["run_pool"]
6+
7+
8+
def iter_pool_generator(
    it: Generator, res: Union[List[Any], None] = None
) -> Union[List[Any], None]:
    """Exhaust a pool result iterable, optionally collecting its items.

    Parameters
    ----------
        it : Generator
            The iterable holding the results of a concurrent run. It is only
            iterated over, so any iterable (e.g. a plain list) works as well.
        res : List[Any] | None, default=None
            A list that the items of `it` are appended to, in iteration order.
            The list is modified in place. If None, the items are consumed and
            discarded.

    Returns
    -------
        Union[List[Any], None]:
            The very same `res` list, now holding the results, or None if `res`
            was None.
    """
    if res is None:
        # Nothing to collect, but the iterable is still run to exhaustion.
        for _ in it:
            pass
        return None

    res.extend(it)
    return res
32+
33+
34+
def run_pool(
    func: Callable,
    args: List[Any],
    ret: bool = True,
    pooltype: str = "thread",
    maptype: str = "amap",
) -> Union[List[Any], None]:
    """Map `func` over `args` with a pathos Thread, Process or Serial pool.

    NOTE: If `ret` is True and `func` does not return anything, a list of None
    values is returned.

    Parameters
    ----------
        func : Callable
            The function that is applied to every item of `args` by the pool.
        args : List[Any]
            The arguments; `func` is called once with each item.
        ret : bool, default=True
            Flag, whether to return a list of results from the pool object.
            Set to False e.g. when saving data to disk in parallel. When False,
            None is returned for every `maptype` (including "amap").
        pooltype : str, default="thread"
            The pathos pooltype. Allowed: ("process", "thread", "serial")
        maptype : str, default="amap"
            The map type of the pathos Pool object.
            Allowed: ("map", "amap", "imap", "uimap")
            The "serial" pool supports only "map" and "imap".

    Raises
    ------
        ValueError: If an illegal `pooltype` or `maptype` is given, or if
            `pooltype` is "serial" and `maptype` is "amap" or "uimap".

    Returns
    -------
        Union[List[Any], None]:
            A list of results if `ret` is True, otherwise None.
    """
    # All argument validation happens before a pool class is touched.
    allowed = ("process", "thread", "serial")
    if pooltype not in allowed:
        raise ValueError(f"Illegal `pooltype`. Got {pooltype}. Allowed: {allowed}")

    allowed = ("map", "amap", "imap", "uimap")
    if maptype not in allowed:
        raise ValueError(f"Illegal `maptype`. Got {maptype}. Allowed: {allowed}")

    if pooltype == "serial" and maptype in ("amap", "uimap"):
        raise ValueError(
            f"`SerialPool` has only `map` & `imap` implemented. Got: {maptype}."
        )

    if pooltype == "thread":
        Pool = ThreadPool
    elif pooltype == "process":
        Pool = ProcessPool
    else:
        Pool = SerialPool

    with Pool() as pool:
        if maptype == "amap":
            # `amap` hands back an async result object; `.get()` resolves it
            # into the results. Honour `ret` here too instead of always
            # returning the list.
            results = pool.amap(func, args).get()
            if not ret:
                results = None
        else:
            # `map`, `imap` and `uimap` share the same call shape; their output
            # is iterated inside the `with` block, while the pool is open.
            it = getattr(pool, maptype)(func, args)
            results = iter_pool_generator(it, [] if ret else None)

    return results
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
3+
from cellseg_models_pytorch.utils import run_pool
4+
5+
6+
def wrap1(num):
    """Worker that returns a value: `num` shifted up by two."""
    shifted = num + 2
    return shifted


def wrap2(arg):
    """Worker without a return value: ignores `arg` and returns None."""
    return None


@pytest.mark.parametrize(
    "typesets",
    [
        ("thread", "imap"),
        ("thread", "map"),
        ("thread", "uimap"),
        ("thread", "amap"),
        ("process", "map"),
        ("process", "amap"),
        ("process", "imap"),
        ("process", "uimap"),
        ("serial", "map"),
        ("serial", "imap"),
        pytest.param(("serial", "amap"), marks=pytest.mark.xfail),
        pytest.param(("serial", "uimap"), marks=pytest.mark.xfail),
    ],
)
@pytest.mark.parametrize("func", [wrap1, wrap2])
@pytest.mark.parametrize("ret", [True, False])
def test_run_pool(typesets, func, ret):
    """Run `run_pool` for every (pooltype, maptype) pair, worker and `ret` flag."""
    pooltype, maptype = typesets
    inputs = [1, 2, 3, 4, 5, 6, 7, 8]
    res = run_pool(func, inputs, ret=ret, pooltype=pooltype, maptype=maptype)

    if ret:
        # wrap1 adds two to every input; wrap2 yields one None per input.
        expected = [3, 4, 5, 6, 7, 8, 9, 10] if func == wrap1 else [None] * 8
        assert res == expected
        return

    # `amap` may hand back a result list even when `ret` is False, so that
    # combination is deliberately left unchecked.
    if maptype != "amap":
        assert res is None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
## Features
2+
3+
- Added more pathos pool options for parallel processing: `ThreadPool`, `ProcessPool` & `SerialPool`.
4+
- Added the pathos mapping methods `map`, `imap`, `uimap` and `amap` for the pool objects (`SerialPool` supports only `map` and `imap`).
5+
6+
## Refactor
7+
8+
- Refactored multiprocessing code to be reusable and moved it under `utils`.
9+
10+
## Tests
11+
12+
- added tests for the multiprocessing tools.

0 commit comments

Comments
 (0)