Skip to content

Commit 5a1b01d

Browse files
committed
Wait for the workers to exit before exiting the interpreter
1 parent cd8c330 commit 5a1b01d

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

src/aiointerpreters/runner.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import inspect
33
from concurrent.interpreters import create, create_queue
44
from contextlib import contextmanager
5-
from enum import StrEnum, auto
65
from functools import cache, wraps
76
from textwrap import dedent
87
from threading import Thread
@@ -21,12 +20,6 @@ class InterpreterError(Exception): ...
2120
class RunnerError(Exception): ...
2221

2322

24-
class State(StrEnum):
25-
STOPPED = auto()
26-
STARTED = auto()
27-
STOPPING = auto()
28-
29-
3023
class Runner:
3124
def __init__(self, *, workers: int) -> None:
3225
self._tasks = create_queue()
@@ -54,7 +47,6 @@ def load_entry_point(entry_point_type, path_or_module, name):
5447
while True:
5548
match tasks.get():
5649
case None:
57-
results.put(None)
5850
break
5951
case id, entry_point_type, module, name, args, kwargs:
6052
try:
@@ -66,7 +58,6 @@ def load_entry_point(entry_point_type, path_or_module, name):
6658
results.put((id, True, res))
6759
""")
6860
self.workers = workers
69-
self.stopped = True
7061
self.threads = []
7162

7263
def _worker(self) -> None:
@@ -81,35 +72,38 @@ def start(self) -> Iterator[Self]:
8172
8273
This will create the workers eagerly.
8374
"""
84-
threads = [
85-
Thread(target=self._coordinator, daemon=True),
86-
*(Thread(target=self._worker, daemon=True) for _ in range(self.workers)),
87-
]
88-
for t in threads:
89-
t.start()
90-
self.stopped = False
75+
coordinator = Thread(target=self._coordinator, daemon=True)
76+
workers = [Thread(target=self._worker, daemon=True) for _ in range(self.workers)]
77+
coordinator.start()
78+
for worker in workers:
79+
worker.start()
80+
9181
try:
9282
yield self
9383
finally:
84+
# Signal to the workers
9485
for _ in range(self.workers):
9586
self._tasks.put(None)
96-
self.stopped = True
9787

98-
for t in threads:
99-
t.join()
88+
# Wait for workers to exit
89+
for worker in workers:
90+
worker.join()
91+
92+
# Signal to the coordinator
93+
self._results.put(None)
94+
95+
# Wait for coordinator to exit
96+
coordinator.join()
10097

10198
def _coordinator(self) -> None:
102-
workers = self.workers
103-
while workers > 0:
99+
while True:
104100
match self._results.get():
105101
case None:
106102
# Interpreter closed
107-
workers -= 1
103+
return
108104
case int(i), False, str(reason):
109105
future, loop = self._futures.pop(i)
110-
loop.call_soon_threadsafe(
111-
future.set_exception, InterpreterError(reason)
112-
)
106+
loop.call_soon_threadsafe(future.set_exception, InterpreterError(reason))
113107
case int(i), True, result:
114108
future, loop = self._futures.pop(i)
115109
loop.call_soon_threadsafe(future.set_result, result)
@@ -157,7 +151,6 @@ async def run_module_function(
157151
*args: Shareable,
158152
**kwargs: Shareable,
159153
) -> object:
160-
assert not self.stopped, "Runner must be started"
161154
future = asyncio.Future()
162155
id_ = id(future)
163156
self._futures[id_] = future, asyncio.get_running_loop()

0 commit comments

Comments
 (0)