diff --git a/src/executorlib/standalone/batched.py b/src/executorlib/standalone/batched.py index 3c44c0c7..8c4ab149 100644 --- a/src/executorlib/standalone/batched.py +++ b/src/executorlib/standalone/batched.py @@ -3,7 +3,7 @@ def batched_futures( lst: list[Future], nested_skip_lst: list[Future[list]], n: int -) -> list[list]: +) -> tuple[bool, list[Future]]: """ Batch n completed future objects. If the number of completed futures is smaller than n and the end of the batch is not reached yet, then an empty list is returned. If n future objects are done, which are not included in the skip_set @@ -11,19 +11,29 @@ def batched_futures( Args: lst (list): list of all future objects - nested_skip_lst (list): nest list of individual results already assigned to previous batches + nested_skip_lst (list): list of future objects, which contain the list of future objects ids which should be skipped for the batch n (int): batch size Returns: list: results of the batched futures """ - skip_set = {id(item) for f in nested_skip_lst for item in f.result()} + skip_set = {fid for f in nested_skip_lst for fid in f.result()} done_lst = [] + failed_lst = [] n_expected = min(n, len(lst) - len(skip_set)) for v in lst: - if v.done() and id(v.result()) not in skip_set: - done_lst.append(v.result()) - if len(done_lst) == n_expected: - return done_lst - return [] + if id(v) not in skip_set and v.done(): + if v.exception() is not None: + failed_lst.append(v) + elif id(v) not in skip_set and v.done(): + done_lst.append(v) + if len(done_lst) == n_expected: + return True, done_lst + if (len(lst) - len(skip_set)) == len(failed_lst): + return ( + False, + failed_lst[:n_expected], + ) # raise the exception only after all futures have failed + else: + return True, [] diff --git a/src/executorlib/task_scheduler/interactive/dependency.py b/src/executorlib/task_scheduler/interactive/dependency.py index 349b3f1c..9f2a69dd 100644 --- a/src/executorlib/task_scheduler/interactive/dependency.py +++ b/src/executorlib/task_scheduler/interactive/dependency.py @@ -177,6 +177,7 @@ def batched( future_lst: list[Future] = [] for _ in range(len(iterable) // n + (1 if len(iterable) % n > 0 else 0)): f: Future = Future() + f_skip: Future = Future() if self._future_queue is not None: self._future_queue.put( { @@ -184,10 +185,11 @@ def batched( "args": (), "kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst}, "future": f, + "future_skip": f_skip, "resource_dict": {}, } ) - skip_lst = skip_lst.copy() + [f] # be careful + skip_lst = skip_lst.copy() + [f_skip] # be careful future_lst.append(f) return future_lst @@ -330,7 +332,7 @@ def _update_waiting_task( wait_tmp_lst = [] for task_wait_dict in wait_lst: exception_lst = get_exception_lst(future_lst=task_wait_dict["future_lst"]) - if len(exception_lst) > 0: + if len(exception_lst) > 0 and task_wait_dict["fn"] != "batched": task_wait_dict["future"].set_exception(exception_lst[0]) elif task_wait_dict["fn"] != "batched" and all( future.done() for future in task_wait_dict["future_lst"] @@ -343,15 +345,19 @@ def _update_waiting_task( elif task_wait_dict["fn"] == "batched" and all( future.done() for future in task_wait_dict["kwargs"]["skip_lst"] ): - done_lst = batched_futures( + success, done_lst = batched_futures( lst=task_wait_dict["kwargs"]["lst"], n=task_wait_dict["kwargs"]["n"], nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"], ) - if len(done_lst) == 0: + if success and len(done_lst) == 0: wait_tmp_lst.append(task_wait_dict) + elif success and len(done_lst) > 0: + task_wait_dict["future"].set_result([f.result() for f in done_lst]) + task_wait_dict["future_skip"].set_result([id(f) for f in done_lst]) else: - task_wait_dict["future"].set_result(done_lst) + task_wait_dict["future"].set_exception(done_lst[0].exception()) + task_wait_dict["future_skip"].set_result([id(f) for f in done_lst]) else: wait_tmp_lst.append(task_wait_dict) if len(wait_lst) == len(wait_tmp_lst): diff --git a/tests/unit/executor/test_single_dependencies.py b/tests/unit/executor/test_single_dependencies.py index 98e1b14c..83de011a 100644 --- a/tests/unit/executor/test_single_dependencies.py +++ b/tests/unit/executor/test_single_dependencies.py @@ -82,6 +82,30 @@ def test_batched(self): self.assertEqual(len(result_lst), 4) self.assertTrue(t3-t2 > t2-t1) + def test_batched_error_future(self): + with SingleNodeExecutor() as exe: + t1 = time() + future_first_lst = [] + for i in range(10): + if i % 3 == 0: + future_first_lst.append(exe.submit(raise_error, parameter=0)) + else: + future_first_lst.append(exe.submit(return_input_dict, i)) + future_second_lst = exe.batched(future_first_lst, n=3) + + future_third_lst = [] + for f in future_second_lst: + future_third_lst.append(exe.submit(sum, f)) + + t2 = time() + self.assertEqual(future_third_lst[0].result() + future_third_lst[1].result(), 27) + with self.assertRaises(RuntimeError): + future_third_lst[2].result() + with self.assertRaises(RuntimeError): + future_third_lst[3].result() + t3 = time() + self.assertTrue(t3-t2 > t2-t1) + def test_batched_error(self): with self.assertRaises(TypeError): with SingleNodeExecutor() as exe: diff --git a/tests/unit/standalone/test_batched.py b/tests/unit/standalone/test_batched.py index 31e3d578..58388746 100644 --- a/tests/unit/standalone/test_batched.py +++ b/tests/unit/standalone/test_batched.py @@ -1,27 +1,82 @@ -from unittest import TestCase +import unittest from concurrent.futures import Future -from executorlib.standalone.batched import batched_futures +from executorlib.task_scheduler.interactive.dependency import batched_futures -class TestBatched(TestCase): + +class TestBatched(unittest.TestCase): def test_batched_futures(self): lst = [] - for i in list(range(10)): + for i in range(10): f = Future() f.set_result(i) lst.append(f) batched_lst = [Future(), Future(), Future()] - batched_lst[0].set_result([0, 1, 2]) - batched_lst[1].set_result([3, 4, 5]) - batched_lst[2].set_result([6, 7, 8]) - self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [0, 1, 2]) - self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3), [3, 4, 5]) - self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3), [6, 7, 8]) - self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3), [9]) + batched_lst[0].set_result([id(lst[0]), id(lst[1]), id(lst[2])]) + batched_lst[1].set_result([id(lst[3]), id(lst[4]), id(lst[5])]) + batched_lst[2].set_result([id(lst[6]), id(lst[7]), id(lst[8])]) + success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set()) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [0, 1, 2]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [3, 4, 5]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [6, 7, 8]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [9]) + + def test_batched_futures_duplicated(self): + lst = [] + for i in range(1,4): + for _ in range(3): + f = Future() + f.set_result(i) + lst.append(f) + batched_lst = [Future(), Future(), Future()] + batched_lst[0].set_result([id(lst[0]), id(lst[1]), id(lst[2])]) + batched_lst[1].set_result([id(lst[3]), id(lst[4]), id(lst[5])]) + batched_lst[2].set_result([id(lst[6]), id(lst[7]), id(lst[8])]) + success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set()) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [1, 1, 1]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [2, 2, 2]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [3, 3, 3]) + + def test_batched_futures(self): + lst = [] + for i in range(10): + f = Future() + if i % 3 == 0: + f.set_exception(ValueError(f"Error for {i}")) + else: + f.set_result(i) + lst.append(f) + batched_lst = [Future(), Future()] + batched_lst[0].set_result([id(lst[1]), id(lst[2]), id(lst[4])]) + batched_lst[1].set_result([id(lst[5]), id(lst[7]), id(lst[8])]) + success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set()) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [1, 2, 4]) + success, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3) + self.assertTrue(success) + self.assertEqual([f.result() for f in done_lst], [5, 7, 8]) + succss, done_lst = batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3) + self.assertFalse(succss) + with self.assertRaises(ValueError): + raise done_lst[0].exception() def test_batched_futures_not_finished(self): lst = [] for _ in list(range(10)): f = Future() lst.append(f) - self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), []) + success, done_lst = batched_futures(lst=lst, n=3, nested_skip_lst=set()) + self.assertTrue(success) + self.assertEqual(done_lst, [])