77from executorlib .standalone .batched import batched_futures
88from executorlib .standalone .interactive .arguments import (
99 check_exception_was_raised ,
10+ check_list_of_futures_is_done ,
1011 get_exception_lst ,
1112 get_future_objects_from_input ,
1213 update_futures_in_input ,
@@ -185,6 +186,7 @@ def batched(
185186 "args" : (),
186187 "kwargs" : {"lst" : iterable , "n" : n , "skip_lst" : skip_lst },
187188 "future" : f ,
189+ "future_lst" : iterable ,
188190 "future_skip" : f_skip ,
189191 "resource_dict" : {},
190192 }
@@ -249,7 +251,7 @@ def _execute_tasks_with_dependencies(
249251 executor (TaskSchedulerBase): Executor to execute the tasks with after the dependencies are resolved.
250252 refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
251253 """
252- wait_lst : list = []
254+ future_dependency_lst : list = []
253255 while True :
254256 try :
255257 task_dict = future_queue .get_nowait ()
@@ -258,10 +260,10 @@ def _execute_tasks_with_dependencies(
258260 if ( # shutdown the executor
259261 task_dict is not None and "shutdown" in task_dict and task_dict ["shutdown" ]
260262 ):
261- while len (wait_lst ) > 0 :
263+ while len (future_dependency_lst ) > 0 :
262264 # Check functions in the wait list and execute them if all future objects are now ready
263- wait_lst = _update_waiting_task (
264- wait_lst = wait_lst ,
265+ future_dependency_lst = _handle_future_dependencies (
266+ future_dependency_lst = future_dependency_lst ,
265267 executor_queue = executor_queue ,
266268 refresh_rate = refresh_rate ,
267269 )
@@ -283,12 +285,24 @@ def _execute_tasks_with_dependencies(
283285 task_dict ["future" ].set_result (False )
284286 else :
285287 task_dict ["future" ].set_result (True )
288+ elif ( # handle batched function submitted to the executor
289+ task_dict is not None
290+ and "fn" in task_dict
291+ and task_dict ["fn" ] == "batched"
292+ and "future" in task_dict
293+ ):
294+ future_dependency_lst .append (task_dict )
295+ future_queue .task_done ()
286296 elif ( # handle function submitted to the executor
287- task_dict is not None and "fn" in task_dict and "future" in task_dict
297+ task_dict is not None
298+ and "fn" in task_dict
299+ and task_dict ["fn" ] != "batched"
300+ and "future" in task_dict
288301 ):
289- future_lst , ready_flag = get_future_objects_from_input (
302+ future_lst = get_future_objects_from_input (
290303 args = task_dict ["args" ], kwargs = task_dict ["kwargs" ]
291304 )
305+ ready_flag = check_list_of_futures_is_done (future_lst = future_lst )
292306 exception_lst = get_exception_lst (future_lst = future_lst )
293307 if not check_exception_was_raised (future_obj = task_dict ["future" ]):
294308 if len (exception_lst ) > 0 :
@@ -301,12 +315,12 @@ def _execute_tasks_with_dependencies(
301315 executor_queue .put (task_dict )
302316 else : # Otherwise add the function to the wait list
303317 task_dict ["future_lst" ] = future_lst
304- wait_lst .append (task_dict )
318+ future_dependency_lst .append (task_dict )
305319 future_queue .task_done ()
306- elif len (wait_lst ) > 0 :
320+ elif len (future_dependency_lst ) > 0 :
307321 # Check functions in the wait list and execute them if all future objects are now ready
308- wait_lst = _update_waiting_task (
309- wait_lst = wait_lst ,
322+ future_dependency_lst = _handle_future_dependencies (
323+ future_dependency_lst = future_dependency_lst ,
310324 executor_queue = executor_queue ,
311325 refresh_rate = refresh_rate ,
312326 )
@@ -315,22 +329,24 @@ def _execute_tasks_with_dependencies(
315329 sleep (refresh_rate )
316330
317331
318- def _update_waiting_task (
319- wait_lst : list [dict ], executor_queue : queue .Queue , refresh_rate : float = 0.01
332+ def _handle_future_dependencies (
333+ future_dependency_lst : list [dict ],
334+ executor_queue : queue .Queue ,
335+ refresh_rate : float = 0.01 ,
320336) -> list :
321337 """
322338 Submit the waiting tasks, which future inputs have been completed, to the executor
323339
324340 Args:
325- wait_lst (list): List of waiting tasks
341+ future_dependency_lst (list): List of waiting tasks
326342 executor_queue (Queue): Queue of the internal executor
327343 refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
328344
329345 Returns:
330346 list: list tasks which future inputs have not been completed
331347 """
332348 wait_tmp_lst = []
333- for task_wait_dict in wait_lst :
349+ for task_wait_dict in future_dependency_lst :
334350 exception_lst = get_exception_lst (future_lst = task_wait_dict ["future_lst" ])
335351 if len (exception_lst ) > 0 and task_wait_dict ["fn" ] != "batched" :
336352 task_wait_dict ["future" ].set_exception (exception_lst [0 ])
@@ -360,6 +376,6 @@ def _update_waiting_task(
360376 task_wait_dict ["future_skip" ].set_result ([id (f ) for f in done_lst ])
361377 else :
362378 wait_tmp_lst .append (task_wait_dict )
363- if len (wait_lst ) == len (wait_tmp_lst ):
379+ if len (future_dependency_lst ) == len (wait_tmp_lst ):
364380 sleep (refresh_rate )
365381 return wait_tmp_lst
0 commit comments