Skip to content
Merged
20 changes: 14 additions & 6 deletions src/executorlib/standalone/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def batched_futures(
lst: list[Future], nested_skip_lst: list[Future[list]], n: int
) -> list[list]:
) -> list[list] | BaseException:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
"""
Batch n completed future objects. If the number of completed futures is smaller than n and the end of the batch is
not reached yet, then an empty list is returned. If n future objects are done, which are not included in the skip_set
Expand All @@ -20,10 +20,18 @@ def batched_futures(
skip_set = {id(item) for f in nested_skip_lst for item in f.result()}

done_lst = []
failed_lst = []
n_expected = min(n, len(lst) - len(skip_set))
for v in lst:
if v.done() and id(v.result()) not in skip_set:
done_lst.append(v.result())
if len(done_lst) == n_expected:
return done_lst
return []
if v.done():
excp = v.exception()
if excp is not None:
failed_lst.append(excp)
elif id(v.result()) not in skip_set:
done_lst.append(v.result())
if len(done_lst) == n_expected:
return done_lst
if len(failed_lst) == len(lst) - len(skip_set) and len(failed_lst) > 0:
return failed_lst[0] # raise the exception only after all futures have failed
else:
return []
6 changes: 4 additions & 2 deletions src/executorlib/task_scheduler/interactive/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,12 @@ def _update_waiting_task(
n=task_wait_dict["kwargs"]["n"],
nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"],
)
if len(done_lst) == 0:
if isinstance(done_lst, list) and len(done_lst) == 0:
wait_tmp_lst.append(task_wait_dict)
else:
elif isinstance(done_lst, list) and len(done_lst) > 0:
task_wait_dict["future"].set_result(done_lst)
else:
task_wait_dict["future"].set_exception(done_lst)
else:
wait_tmp_lst.append(task_wait_dict)
if len(wait_lst) == len(wait_tmp_lst):
Expand Down
41 changes: 37 additions & 4 deletions tests/unit/standalone/test_batched.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from unittest import TestCase
import unittest
from concurrent.futures import Future
from executorlib.standalone.batched import batched_futures

from executorlib.task_scheduler.interactive.dependency import batched_futures

class TestBatched(TestCase):

class TestBatched(unittest.TestCase):
def test_batched_futures(self):
lst = []
for i in list(range(10)):
for i in range(10):
f = Future()
f.set_result(i)
lst.append(f)
Expand All @@ -19,6 +20,38 @@ def test_batched_futures(self):
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3), [6, 7, 8])
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3), [9])

def test_batched_futures_duplicated(self):
lst = []
for i in range(1,4):
for _ in range(3):
f = Future()
f.set_result(i)
lst.append(f)
batched_lst = [Future(), Future(), Future()]
batched_lst[0].set_result([1, 1, 1])
batched_lst[1].set_result([2, 2, 2])
batched_lst[2].set_result([3, 3, 3])
self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [1, 1, 1])
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3), [2, 2, 2])
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3), [3, 3, 3])

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def test_batched_futures(self):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
lst = []
for i in range(10):
f = Future()
if i % 3 == 0:
f.set_exception(ValueError(f"Error for {i}"))
else:
f.set_result(i)
lst.append(f)
batched_lst = [Future(), Future()]
batched_lst[0].set_result([1, 2, 4])
batched_lst[1].set_result([5, 7, 8])
self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [1, 2, 4])
self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3), [5, 7, 8])
with self.assertRaises(ValueError):
raise batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3)

def test_batched_futures_not_finished(self):
lst = []
for _ in list(range(10)):
Expand Down
Loading