22import inspect
33from concurrent .interpreters import create , create_queue
44from contextlib import contextmanager
5- from enum import StrEnum , auto
65from functools import cache , wraps
76from textwrap import dedent
87from threading import Thread
@@ -21,12 +20,6 @@ class InterpreterError(Exception): ...
2120class RunnerError (Exception ): ...
2221
2322
24- class State (StrEnum ):
25- STOPPED = auto ()
26- STARTED = auto ()
27- STOPPING = auto ()
28-
29-
3023class Runner :
3124 def __init__ (self , * , workers : int ) -> None :
3225 self ._tasks = create_queue ()
@@ -54,7 +47,6 @@ def load_entry_point(entry_point_type, path_or_module, name):
5447 while True:
5548 match tasks.get():
5649 case None:
57- results.put(None)
5850 break
5951 case id, entry_point_type, module, name, args, kwargs:
6052 try:
@@ -66,7 +58,6 @@ def load_entry_point(entry_point_type, path_or_module, name):
6658 results.put((id, True, res))
6759 """ )
6860 self .workers = workers
69- self .stopped = True
7061 self .threads = []
7162
7263 def _worker (self ) -> None :
@@ -81,35 +72,38 @@ def start(self) -> Iterator[Self]:
8172
8273 This will create the workers eagerly.
8374 """
84- threads = [
85- Thread (target = self ._coordinator , daemon = True ),
86- * (Thread (target = self ._worker , daemon = True ) for _ in range (self .workers )),
87- ]
88- for t in threads :
89- t .start ()
90- self .stopped = False
75+ coordinator = Thread (target = self ._coordinator , daemon = True )
76+ workers = [Thread (target = self ._worker , daemon = True ) for _ in range (self .workers )]
77+ coordinator .start ()
78+ for worker in workers :
79+ worker .start ()
80+
9181 try :
9282 yield self
9383 finally :
84+ # Signal to the workers
9485 for _ in range (self .workers ):
9586 self ._tasks .put (None )
96- self .stopped = True
9787
98- for t in threads :
99- t .join ()
88+ # Wait for workers to exit
89+ for worker in workers :
90+ worker .join ()
91+
92+ # Signal to the coordinator
93+ self ._results .put (None )
94+
95+ # Wait for coordinator to exit
96+ coordinator .join ()
10097
10198 def _coordinator (self ) -> None :
102- workers = self .workers
103- while workers > 0 :
99+ while True :
104100 match self ._results .get ():
105101 case None :
106102 # Interpreter closed
107- workers -= 1
103+ return
108104 case int (i ), False , str (reason ):
109105 future , loop = self ._futures .pop (i )
110- loop .call_soon_threadsafe (
111- future .set_exception , InterpreterError (reason )
112- )
106+ loop .call_soon_threadsafe (future .set_exception , InterpreterError (reason ))
113107 case int (i ), True , result :
114108 future , loop = self ._futures .pop (i )
115109 loop .call_soon_threadsafe (future .set_result , result )
@@ -157,7 +151,6 @@ async def run_module_function(
157151 * args : Shareable ,
158152 ** kwargs : Shareable ,
159153 ) -> object :
160- assert not self .stopped , "Runner must be started"
161154 future = asyncio .Future ()
162155 id_ = id (future )
163156 self ._futures [id_ ] = future , asyncio .get_running_loop ()
0 commit comments