@@ -56,6 +56,11 @@ def _unique_order_preserving(
5656 return [i for i in iterable if not (i in seen or seen_add (i ))], seen
5757
5858
59+ def _callback_wrapper (chosen_runner : Callable [..., Any ], * args : Any , ** kwargs : Any ) -> Any :
60+ numba .set_num_threads (1 )
61+ return chosen_runner (* args , ** kwargs )
62+
63+
5964class Signal (Enum ):
6065 """Signaling values when informing parallelizer."""
6166
@@ -164,9 +169,7 @@ def update(pbar: tqdm.std.tqdm, queue: SigQueue, n_total: int) -> None:
164169 if pbar is not None :
165170 pbar .close ()
166171
167- def callback_wrapper (* args : Any , ** kwargs : Any ) -> Any :
168- numba .set_num_threads (1 )
169- return (runner if use_runner else callback )(* args , ** kwargs )
172+ chosen_runner = runner if use_runner else callback
170173
171174 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
172175 numba .set_num_threads (1 )
@@ -179,8 +182,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
179182 pbar , queue , thread = None , None , None
180183
181184 res = jl .Parallel (n_jobs = n_jobs , backend = backend )(
182- jl .delayed (callback_wrapper )(
183- * ((i , cs ) if use_ixs else (cs , )),
185+ jl .delayed (_callback_wrapper )(
186+ * ((chosen_runner , i , cs ) if use_ixs else (chosen_runner , cs )),
184187 * args ,
185188 ** kwargs ,
186189 queue = queue ,
0 commit comments