Skip to content

Commit 11facd9

Browse files
authored
Improve Modal backups implementation (#238)
1 parent 647a68b commit 11facd9

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

cubed/runtime/executors/modal_async.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515

1616
# We need map_unordered for the use_backups implementation
1717
async def map_unordered(
18-
app_function, input, use_backups=False, return_stats=False, name=None, **kwargs
18+
app_function,
19+
input,
20+
use_backups=False,
21+
backup_function=None,
22+
return_stats=False,
23+
name=None,
24+
**kwargs,
1925
):
2026
"""
2127
Apply a function to items of an input list, yielding results as they are completed
@@ -42,6 +48,8 @@ async def map_unordered(
4248
yield result
4349
return
4450

51+
backup_function = backup_function or app_function
52+
4553
tasks = {
4654
asyncio.ensure_future(app_function.call.aio(i, **kwargs)): i for i in input
4755
}
@@ -55,12 +63,14 @@ async def map_unordered(
5563
finished, pending = await asyncio.wait(
5664
pending, return_when=asyncio.FIRST_COMPLETED, timeout=2
5765
)
58-
# print("finished", finished)
59-
# print("pending", pending)
60-
6166
for task in finished:
6267
# TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions
6368
if task.exception():
69+
# if the task has a backup that is not done, or is done with no exception, then don't raise this exception
70+
backup = backups.get(task, None)
71+
if backup:
72+
if not backup.done() or not backup.exception():
73+
continue
6474
raise task.exception()
6575
end_times[task] = time.monotonic()
6676
if return_stats:
@@ -76,9 +86,11 @@ async def map_unordered(
7686
if use_backups:
7787
backup = backups.get(task, None)
7888
if backup:
79-
pending.remove(backup)
89+
if backup in pending:
90+
pending.remove(backup)
8091
del backups[task]
8192
del backups[backup]
93+
backup.cancel()
8294

8395
if use_backups:
8496
now = time.monotonic()
@@ -87,8 +99,11 @@ async def map_unordered(
8799
task, now, start_times, end_times
88100
):
89101
# launch backup task
102+
print("Launching backup task")
90103
i = tasks[task]
91-
new_task = asyncio.ensure_future(app_function.call.aio(i, **kwargs))
104+
new_task = asyncio.ensure_future(
105+
backup_function.call.aio(i, **kwargs)
106+
)
92107
tasks[new_task] = i
93108
start_times[new_task] = time.monotonic()
94109
pending.add(new_task)

0 commit comments

Comments
 (0)