Skip to content

Commit ba290e1

Browse files
jan-janssen, pyiron-runner, pre-commit-ci[bot]
authored
[Feature] Faster batching (#1015)
* [Feature] Faster batching * Format black * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add docstring --------- Co-authored-by: pyiron-runner <pyiron@mpie.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b62cede commit ba290e1

3 files changed

Lines changed: 51 additions & 21 deletions

File tree

src/executorlib/standalone/interactive/arguments.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,21 @@ def find_future_in_list(lst):
2727

2828
find_future_in_list(lst=args)
2929
find_future_in_list(lst=kwargs.values())
30-
boolean_flag = len([future for future in future_lst if future.done()]) == len(
31-
future_lst
32-
)
33-
return future_lst, boolean_flag
30+
31+
return future_lst
32+
33+
34+
def check_list_of_futures_is_done(future_lst: list[Future]) -> bool:
    """
    Check if all future objects in the list of future objects are done

    Args:
        future_lst (list): list of future objects

    Returns:
        bool: True if all future objects in the list of future objects are done, False otherwise.
            An empty list yields True.
    """
    # all() stops at the first unfinished future and avoids building a temporary list
    return all(future.done() for future in future_lst)
3445

3546

3647
def get_exception_lst(future_lst: list[Future]) -> list:

src/executorlib/task_scheduler/interactive/dependency.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from executorlib.standalone.batched import batched_futures
88
from executorlib.standalone.interactive.arguments import (
99
check_exception_was_raised,
10+
check_list_of_futures_is_done,
1011
get_exception_lst,
1112
get_future_objects_from_input,
1213
update_futures_in_input,
@@ -185,6 +186,7 @@ def batched(
185186
"args": (),
186187
"kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst},
187188
"future": f,
189+
"future_lst": iterable,
188190
"future_skip": f_skip,
189191
"resource_dict": {},
190192
}
@@ -249,7 +251,7 @@ def _execute_tasks_with_dependencies(
249251
executor (TaskSchedulerBase): Executor to execute the tasks with after the dependencies are resolved.
250252
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
251253
"""
252-
wait_lst: list = []
254+
future_dependency_lst: list = []
253255
while True:
254256
try:
255257
task_dict = future_queue.get_nowait()
@@ -258,10 +260,10 @@ def _execute_tasks_with_dependencies(
258260
if ( # shutdown the executor
259261
task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"]
260262
):
261-
while len(wait_lst) > 0:
263+
while len(future_dependency_lst) > 0:
262264
# Check functions in the wait list and execute them if all future objects are now ready
263-
wait_lst = _update_waiting_task(
264-
wait_lst=wait_lst,
265+
future_dependency_lst = _handle_future_dependencies(
266+
future_dependency_lst=future_dependency_lst,
265267
executor_queue=executor_queue,
266268
refresh_rate=refresh_rate,
267269
)
@@ -283,12 +285,24 @@ def _execute_tasks_with_dependencies(
283285
task_dict["future"].set_result(False)
284286
else:
285287
task_dict["future"].set_result(True)
288+
elif ( # handle batched function submitted to the executor
289+
task_dict is not None
290+
and "fn" in task_dict
291+
and task_dict["fn"] == "batched"
292+
and "future" in task_dict
293+
):
294+
future_dependency_lst.append(task_dict)
295+
future_queue.task_done()
286296
elif ( # handle function submitted to the executor
287-
task_dict is not None and "fn" in task_dict and "future" in task_dict
297+
task_dict is not None
298+
and "fn" in task_dict
299+
and task_dict["fn"] != "batched"
300+
and "future" in task_dict
288301
):
289-
future_lst, ready_flag = get_future_objects_from_input(
302+
future_lst = get_future_objects_from_input(
290303
args=task_dict["args"], kwargs=task_dict["kwargs"]
291304
)
305+
ready_flag = check_list_of_futures_is_done(future_lst=future_lst)
292306
exception_lst = get_exception_lst(future_lst=future_lst)
293307
if not check_exception_was_raised(future_obj=task_dict["future"]):
294308
if len(exception_lst) > 0:
@@ -301,12 +315,12 @@ def _execute_tasks_with_dependencies(
301315
executor_queue.put(task_dict)
302316
else: # Otherwise add the function to the wait list
303317
task_dict["future_lst"] = future_lst
304-
wait_lst.append(task_dict)
318+
future_dependency_lst.append(task_dict)
305319
future_queue.task_done()
306-
elif len(wait_lst) > 0:
320+
elif len(future_dependency_lst) > 0:
307321
# Check functions in the wait list and execute them if all future objects are now ready
308-
wait_lst = _update_waiting_task(
309-
wait_lst=wait_lst,
322+
future_dependency_lst = _handle_future_dependencies(
323+
future_dependency_lst=future_dependency_lst,
310324
executor_queue=executor_queue,
311325
refresh_rate=refresh_rate,
312326
)
@@ -315,22 +329,24 @@ def _execute_tasks_with_dependencies(
315329
sleep(refresh_rate)
316330

317331

318-
def _update_waiting_task(
319-
wait_lst: list[dict], executor_queue: queue.Queue, refresh_rate: float = 0.01
332+
def _handle_future_dependencies(
333+
future_dependency_lst: list[dict],
334+
executor_queue: queue.Queue,
335+
refresh_rate: float = 0.01,
320336
) -> list:
321337
"""
322338
Submit the waiting tasks, which future inputs have been completed, to the executor
323339
324340
Args:
325-
wait_lst (list): List of waiting tasks
341+
future_dependency_lst (list): List of waiting tasks
326342
executor_queue (Queue): Queue of the internal executor
327343
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
328344
329345
Returns:
330346
list: list tasks which future inputs have not been completed
331347
"""
332348
wait_tmp_lst = []
333-
for task_wait_dict in wait_lst:
349+
for task_wait_dict in future_dependency_lst:
334350
exception_lst = get_exception_lst(future_lst=task_wait_dict["future_lst"])
335351
if len(exception_lst) > 0 and task_wait_dict["fn"] != "batched":
336352
task_wait_dict["future"].set_exception(exception_lst[0])
@@ -360,6 +376,6 @@ def _update_waiting_task(
360376
task_wait_dict["future_skip"].set_result([id(f) for f in done_lst])
361377
else:
362378
wait_tmp_lst.append(task_wait_dict)
363-
if len(wait_lst) == len(wait_tmp_lst):
379+
if len(future_dependency_lst) == len(wait_tmp_lst):
364380
sleep(refresh_rate)
365381
return wait_tmp_lst

tests/unit/standalone/interactive/test_arguments.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from executorlib.standalone.interactive.arguments import (
55
check_exception_was_raised,
6+
check_list_of_futures_is_done,
67
get_exception_lst,
78
get_future_objects_from_input,
89
update_futures_in_input,
@@ -13,14 +14,16 @@ class TestSerial(unittest.TestCase):
1314
def test_get_future_objects_from_input_with_future(self):
1415
input_args = (1, 2, Future(), [Future()], {3: Future()})
1516
input_kwargs = {"a": 1, "b": [Future()], "c": {"d": Future()}, "e": Future()}
16-
future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
17+
future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
18+
boolean_flag = check_list_of_futures_is_done(future_lst=future_lst)
1719
self.assertEqual(len(future_lst), 6)
1820
self.assertFalse(boolean_flag)
1921

2022
def test_get_future_objects_from_input_without_future(self):
2123
input_args = (1, 2)
2224
input_kwargs = {"a": 1}
23-
future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
25+
future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
26+
boolean_flag = check_list_of_futures_is_done(future_lst=future_lst)
2427
self.assertEqual(len(future_lst), 0)
2528
self.assertTrue(boolean_flag)
2629

0 commit comments

Comments
 (0)