Skip to content

Commit 04f2983

Browse files
jan-janssen, pre-commit-ci[bot], pyiron-runner
authored
[Feature] Check which future objects are skipped and keep track of their IDs (#1014)
* [Test] Add duplicated test for batched * Handling exceptions * revert refactor * Raise an exception when futures failed * fix * add unit test for failing futures * compare to valueerror * fixes * two futures worked two failed * order is not guaranteed * Keep track of the futures * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Format black * fix type annotation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: pyiron-runner <pyiron@mpie.de>
1 parent 65f1ef3 commit 04f2983

4 files changed

Lines changed: 120 additions & 25 deletions

File tree

src/executorlib/standalone/batched.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,37 @@
33

44
def batched_futures(
55
lst: list[Future], nested_skip_lst: list[Future[list]], n: int
6-
) -> list[list]:
6+
) -> tuple[bool, list[Future]]:
77
"""
88
Batch n completed future objects. If the number of completed futures is smaller than n and the end of the batch is
99
not reached yet, then an empty list is returned. If n future objects are done, which are not included in the skip_set
1010
then they are returned as batch.
1111
1212
Args:
1313
lst (list): list of all future objects
14-
nested_skip_lst (list): nest list of individual results already assigned to previous batches
14+
nested_skip_lst (list): list of future objects, which contain the list of future objects ids which should be skipped for the batch
1515
n (int): batch size
1616
1717
Returns:
1818
list: results of the batched futures
1919
"""
20-
skip_set = {id(item) for f in nested_skip_lst for item in f.result()}
20+
skip_set = {fid for f in nested_skip_lst for fid in f.result()}
2121

2222
done_lst = []
23+
failed_lst = []
2324
n_expected = min(n, len(lst) - len(skip_set))
2425
for v in lst:
25-
if v.done() and id(v.result()) not in skip_set:
26-
done_lst.append(v.result())
27-
if len(done_lst) == n_expected:
28-
return done_lst
29-
return []
26+
if id(v) not in skip_set and v.done():
27+
if v.exception() is not None:
28+
failed_lst.append(v)
29+
elif id(v) not in skip_set and v.done():
30+
done_lst.append(v)
31+
if len(done_lst) == n_expected:
32+
return True, done_lst
33+
if (len(lst) - len(skip_set)) == len(failed_lst):
34+
return (
35+
False,
36+
failed_lst[:n_expected],
37+
) # raise the exception only after all futures have failed
38+
else:
39+
return True, []

src/executorlib/task_scheduler/interactive/dependency.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,19 @@ def batched(
177177
future_lst: list[Future] = []
178178
for _ in range(len(iterable) // n + (1 if len(iterable) % n > 0 else 0)):
179179
f: Future = Future()
180+
f_skip: Future = Future()
180181
if self._future_queue is not None:
181182
self._future_queue.put(
182183
{
183184
"fn": "batched",
184185
"args": (),
185186
"kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst},
186187
"future": f,
188+
"future_skip": f_skip,
187189
"resource_dict": {},
188190
}
189191
)
190-
skip_lst = skip_lst.copy() + [f] # be careful
192+
skip_lst = skip_lst.copy() + [f_skip] # be careful
191193
future_lst.append(f)
192194

193195
return future_lst
@@ -330,7 +332,7 @@ def _update_waiting_task(
330332
wait_tmp_lst = []
331333
for task_wait_dict in wait_lst:
332334
exception_lst = get_exception_lst(future_lst=task_wait_dict["future_lst"])
333-
if len(exception_lst) > 0:
335+
if len(exception_lst) > 0 and task_wait_dict["fn"] != "batched":
334336
task_wait_dict["future"].set_exception(exception_lst[0])
335337
elif task_wait_dict["fn"] != "batched" and all(
336338
future.done() for future in task_wait_dict["future_lst"]
@@ -343,15 +345,19 @@ def _update_waiting_task(
343345
elif task_wait_dict["fn"] == "batched" and all(
344346
future.done() for future in task_wait_dict["kwargs"]["skip_lst"]
345347
):
346-
done_lst = batched_futures(
348+
success, done_lst = batched_futures(
347349
lst=task_wait_dict["kwargs"]["lst"],
348350
n=task_wait_dict["kwargs"]["n"],
349351
nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"],
350352
)
351-
if len(done_lst) == 0:
353+
if success and len(done_lst) == 0:
352354
wait_tmp_lst.append(task_wait_dict)
355+
elif success and len(done_lst) > 0:
356+
task_wait_dict["future"].set_result([f.result() for f in done_lst])
357+
task_wait_dict["future_skip"].set_result([id(f) for f in done_lst])
353358
else:
354-
task_wait_dict["future"].set_result(done_lst)
359+
task_wait_dict["future"].set_exception(done_lst[0].exception())
360+
task_wait_dict["future_skip"].set_result([id(f) for f in done_lst])
355361
else:
356362
wait_tmp_lst.append(task_wait_dict)
357363
if len(wait_lst) == len(wait_tmp_lst):

tests/unit/executor/test_single_dependencies.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,30 @@ def test_batched(self):
8282
self.assertEqual(len(result_lst), 4)
8383
self.assertTrue(t3-t2 > t2-t1)
8484

85+
def test_batched_error_future(self):
86+
with SingleNodeExecutor() as exe:
87+
t1 = time()
88+
future_first_lst = []
89+
for i in range(10):
90+
if i % 3 == 0:
91+
future_first_lst.append(exe.submit(raise_error, parameter=0))
92+
else:
93+
future_first_lst.append(exe.submit(return_input_dict, i))
94+
future_second_lst = exe.batched(future_first_lst, n=3)
95+
96+
future_third_lst = []
97+
for f in future_second_lst:
98+
future_third_lst.append(exe.submit(sum, f))
99+
100+
t2 = time()
101+
self.assertEqual(future_third_lst[0].result() + future_third_lst[1].result(), 27)
102+
with self.assertRaises(RuntimeError):
103+
future_third_lst[2].result()
104+
with self.assertRaises(RuntimeError):
105+
future_third_lst[3].result()
106+
t3 = time()
107+
self.assertTrue(t3-t2 > t2-t1)
108+
85109
def test_batched_error(self):
86110
with self.assertRaises(TypeError):
87111
with SingleNodeExecutor() as exe:
Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,82 @@
1-
from unittest import TestCase
1+
import unittest
22
from concurrent.futures import Future
3-
from executorlib.standalone.batched import batched_futures
43

4+
from executorlib.task_scheduler.interactive.dependency import batched_futures
55

6-
class TestBatched(TestCase):
6+
7+
class TestBatched(unittest.TestCase):
78
def test_batched_futures(self):
89
lst = []
9-
for i in list(range(10)):
10+
for i in range(10):
1011
f = Future()
1112
f.set_result(i)
1213
lst.append(f)
1314
batched_lst = [Future(), Future(), Future()]
14-
batched_lst[0].set_result([0, 1, 2])
15-
batched_lst[1].set_result([3, 4, 5])
16-
batched_lst[2].set_result([6, 7, 8])
17-
self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [0, 1, 2])
18-
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3), [3, 4, 5])
19-
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3), [6, 7, 8])
20-
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3), [9])
15+
batched_lst[0].set_result([id(lst[0]), id(lst[1]), id(lst[2])])
16+
batched_lst[1].set_result([id(lst[3]), id(lst[4]), id(lst[5])])
17+
batched_lst[2].set_result([id(lst[6]), id(lst[7]), id(lst[8])])
18+
success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set())
19+
self.assertTrue(success)
20+
self.assertEqual([f.result() for f in done_lst], [0, 1, 2])
21+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3)
22+
self.assertTrue(success)
23+
self.assertEqual([f.result() for f in done_lst], [3, 4, 5])
24+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3)
25+
self.assertTrue(success)
26+
self.assertEqual([f.result() for f in done_lst], [6, 7, 8])
27+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3)
28+
self.assertTrue(success)
29+
self.assertEqual([f.result() for f in done_lst], [9])
30+
31+
def test_batched_futures_duplicated(self):
32+
lst = []
33+
for i in range(1,4):
34+
for _ in range(3):
35+
f = Future()
36+
f.set_result(i)
37+
lst.append(f)
38+
batched_lst = [Future(), Future(), Future()]
39+
batched_lst[0].set_result([id(lst[0]), id(lst[1]), id(lst[2])])
40+
batched_lst[1].set_result([id(lst[3]), id(lst[4]), id(lst[5])])
41+
batched_lst[2].set_result([id(lst[6]), id(lst[7]), id(lst[8])])
42+
success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set())
43+
self.assertTrue(success)
44+
self.assertEqual([f.result() for f in done_lst], [1, 1, 1])
45+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3)
46+
self.assertTrue(success)
47+
self.assertEqual([f.result() for f in done_lst], [2, 2, 2])
48+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3)
49+
self.assertTrue(success)
50+
self.assertEqual([f.result() for f in done_lst], [3, 3, 3])
51+
52+
def test_batched_futures(self):
53+
lst = []
54+
for i in range(10):
55+
f = Future()
56+
if i % 3 == 0:
57+
f.set_exception(ValueError(f"Error for {i}"))
58+
else:
59+
f.set_result(i)
60+
lst.append(f)
61+
batched_lst = [Future(), Future()]
62+
batched_lst[0].set_result([id(lst[1]), id(lst[2]), id(lst[4])])
63+
batched_lst[1].set_result([id(lst[5]), id(lst[7]), id(lst[8])])
64+
success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set())
65+
self.assertTrue(success)
66+
self.assertEqual([f.result() for f in done_lst], [1, 2, 4])
67+
success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3)
68+
self.assertTrue(success)
69+
self.assertEqual([f.result() for f in done_lst], [5, 7, 8])
70+
succss, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3)
71+
self.assertFalse(succss)
72+
with self.assertRaises(ValueError):
73+
raise done_lst[0].exception()
2174

2275
def test_batched_futures_not_finished(self):
2376
lst = []
2477
for _ in list(range(10)):
2578
f = Future()
2679
lst.append(f)
27-
self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [])
80+
success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set())
81+
self.assertTrue(success)
82+
self.assertEqual(done_lst, [])

0 commit comments

Comments
 (0)