1
1
import re
2
+ from functools import partial
2
3
from pathlib import Path
3
4
from typing import Any , Dict , List , Tuple , Union
4
5
5
6
import cv2
6
7
import numpy as np
7
8
import scipy .io as sio
8
- from pathos .multiprocessing import ThreadPool as Pool
9
- from tqdm import tqdm
10
9
11
10
from .mask_utils import (
12
11
bounding_box ,
15
14
get_inst_types ,
16
15
label_semantic ,
17
16
)
17
+ from .multiproc import run_pool
18
18
19
19
20
20
class FileHandler :
@@ -58,30 +58,34 @@ def read_mat(
58
58
key : str = "inst_map" ,
59
59
retype : bool = True ,
60
60
return_all : bool = False ,
61
- ) -> Union [np .ndarray , None ]:
61
+ ) -> Union [np .ndarray , Dict [ str , np . ndarray ], None ]:
62
62
"""Read a mask from a .mat file.
63
63
64
64
If a mask is not found, return None
65
65
66
66
Parameters
67
67
----------
68
68
path : str or Path
69
- Path to the image file.
69
+ Path to the .mat file.
70
70
key : str, default="inst_map"
71
71
Name/key of the mask type that is being read from .mat
72
72
retype : bool, default=True
73
73
Convert the matrix type.
74
74
return_all : bool, default=False
75
75
Return the whole dict. Overrides the `key` arg.
76
76
77
- Returns
78
- -------
79
- np.ndarray or None:
80
- The mask indice matrix. Shape (H, W)
81
77
82
78
Raises
83
79
------
84
80
ValueError: If an illegal key is given.
81
+
82
+ Returns
83
+ -------
84
+ Union[np.ndarray, List[np.ndarray], None]:
85
+ if return_all == False:
86
+ The instance/type/semantic labelled mask. Shape: (H, W).
87
+ if return_all == True:
88
+ All the masks in the .mat file returned in a dictionary.
85
89
"""
86
90
dtypes = {
87
91
"inst_map" : "int32" ,
@@ -468,7 +472,8 @@ def save_masks_parallel(
468
472
classes_type : Dict [str , str ] = None ,
469
473
classes_sem : Dict [str , str ] = None ,
470
474
offsets : bool = False ,
471
- progress_bar : bool = False ,
475
+ pooltype : str = "thread" ,
476
+ maptype : str = "amap" ,
472
477
** kwargs ,
473
478
) -> None :
474
479
"""Save the model output masks to a folder. (multi-threaded).
@@ -493,31 +498,44 @@ def save_masks_parallel(
493
498
offsets : bool, default=False
494
499
If True, geojson coords are shifted by the offsets that are encoded in
495
500
the filenames (e.g. "x-1000_y-4000.png"). Ignored if `format` != ".json"
496
- progress_bar : bool, default=False
497
- If True, a tqdm progress bar is shown.
501
+ pooltype : str, default="thread"
502
+ The pathos pooltype. Allowed: ("process", "thread", "serial").
503
+ Defaults to "thread". (Fastest in benchmarks.)
504
+ maptype : str, default="amap"
505
+ The map type of the pathos Pool object.
506
+ Allowed: ("map", "amap", "imap", "uimap")
507
+ Defaults to "amap". (Fastest in benchmarks).
498
508
"""
499
- formats = [ format ] * len ( maps )
500
- geo_formats = [ geo_format ] * len ( maps )
501
- classes_type = [ classes_type ] * len ( maps )
502
- classes_sem = [ classes_sem ] * len ( maps )
503
- offsets = [ offsets ] * len ( maps )
504
- args = tuple (
505
- zip ( fnames , maps , formats , geo_formats , classes_type , classes_sem , offsets )
509
+ func = partial (
510
+ FileHandler . _save_masks ,
511
+ format = format ,
512
+ geo_format = geo_format ,
513
+ classes_type = classes_type ,
514
+ classes_sem = classes_sem ,
515
+ offsets = offsets ,
506
516
)
507
517
508
- with Pool () as pool :
509
- if progress_bar :
510
- it = tqdm (pool .imap (FileHandler ._save_masks , args ), total = len (maps ))
511
- else :
512
- it = pool .imap (FileHandler ._save_masks , args )
513
-
514
- for _ in it :
515
- pass
518
+ args = tuple (zip (fnames , maps ))
519
+ run_pool (func , args , ret = False , pooltype = pooltype , maptype = maptype )
516
520
517
521
@staticmethod
518
- def _save_masks (args : Tuple [Dict [str , np .ndarray ], str , str ]) -> None :
522
+ def _save_masks (
523
+ args : Tuple [str , Dict [str , np .ndarray ]],
524
+ format : str ,
525
+ geo_format : str ,
526
+ classes_type : Dict [str , str ],
527
+ classes_sem : Dict [str , str ],
528
+ offsets : bool ,
529
+ ) -> None :
519
530
"""Unpacks the args for `save_mask` to enable multi-threading."""
520
- return FileHandler .save_masks (* args )
531
+ return FileHandler .save_masks (
532
+ * args ,
533
+ format = format ,
534
+ geo_format = geo_format ,
535
+ classes_type = classes_type ,
536
+ classes_sem = classes_sem ,
537
+ offsets = offsets ,
538
+ )
521
539
522
540
@staticmethod
523
541
def get_split (string : str ) -> List [str ]:
0 commit comments