|
9 | 9 |
|
10 | 10 |
|
11 | 11 | coroutine = Callable[[int], Coroutine[Any, Any, str]] |
| 12 | +OnCompleteFn = Callable[[Any, Optional[Exception]], Awaitable[None]] |
12 | 13 |
|
13 | 14 |
|
14 | | -def coroutine_in_thread(coro: coroutine, callback: Optional[coroutine] = None): |
| 15 | +def coroutine_in_thread( |
| 16 | + coro: coroutine, |
| 17 | + callback: Optional[coroutine] = None, |
| 18 | + on_complete: OnCompleteFn = None, |
| 19 | +) -> threading.Event: |
15 | 20 | """Run a coroutine in a new thread with its own event loop.""" |
| 21 | + parent_loop = asyncio.get_running_loop() |
16 | 22 | done_event = threading.Event() |
17 | 23 |
|
18 | | - def run(): |
| 24 | + def _runner(): |
19 | 25 | new_loop = asyncio.new_event_loop() |
20 | 26 | asyncio.set_event_loop(new_loop) |
21 | | - result = new_loop.run_until_complete(coro) |
22 | | - # if callback exists: |
23 | | - if callback: |
24 | | - new_loop.run_until_complete(callback(result)) |
25 | | - new_loop.close() |
26 | | - done_event.set() # Signal that the coroutine has completed |
27 | | - thread = threading.Thread(target=run, daemon=True) |
28 | | - thread.start() |
| 27 | + result, exc = None, None |
| 28 | + try: |
| 29 | + result = new_loop.run_until_complete(coro) |
| 30 | + except Exception as e: # noqa: BLE001 |
| 31 | + exc = e |
| 32 | + finally: |
| 33 | + if callback: |
| 34 | + new_loop.run_until_complete( |
| 35 | + callback(result, exc, loop=new_loop) |
| 36 | + ) |
| 37 | + new_loop.close() |
| 38 | + done_event.set() # Signal that the coroutine has completed |
| 39 | + if on_complete is None: |
| 40 | + return |
| 41 | + fut = asyncio.run_coroutine_threadsafe( |
| 42 | + on_complete(result, exc), parent_loop |
| 43 | + ) |
| 44 | + fut.result() # Wait for the completion of the callback |
29 | 45 |
|
| 46 | + threading.Thread(target=_runner, daemon=True).start() |
30 | 47 | return done_event |
31 | 48 |
|
32 | 49 |
|
@@ -115,16 +132,33 @@ async def __call__(self): |
115 | 132 | # Delay the execution by jitter seconds |
116 | 133 | await asyncio.sleep(delay) |
117 | 134 | try: |
| 135 | + async def _finish(result: Any, exc: Exception): |
| 136 | + """Callback to handle the completion of the coroutine.""" |
| 137 | + if exc: |
| 138 | + self.logger.error( |
| 139 | + f"TaskWrapper {self._name} failed with exception: {exc}" |
| 140 | + ) |
| 141 | + result = { |
| 142 | + "status": "failed", |
| 143 | + "error": str(exc) |
| 144 | + } |
| 145 | + if self.tracker: |
| 146 | + await self.tracker.set_failed(self.task_uuid, exc) |
| 147 | + else: |
| 148 | + self.logger.debug( |
| 149 | + f"TaskWrapper {self._name} completed successfully." |
| 150 | + ) |
| 151 | + result = { |
| 152 | + "status": "done", |
| 153 | + "result": result |
| 154 | + } |
| 155 | + if self.tracker: |
| 156 | + await self.tracker.set_done(self.task_uuid, result) |
| 157 | + return result |
118 | 158 | with ThreadPoolExecutor(max_workers=1) as executor: |
119 | 159 | coro = self.fn(*self.args, **self.kwargs) |
120 | | - coroutine_in_thread(coro, self._callback_) |
121 | | - result = { |
122 | | - "status": "done" |
123 | | - } |
124 | | - if self.tracker: |
125 | | - # Set the job as done in the tracker |
126 | | - await self.tracker.set_done(self.task_uuid, result) |
127 | | - return result |
| 160 | + coroutine_in_thread(coro, self._callback_, on_complete=_finish) |
| 161 | + return {"status": "running"} |
128 | 162 | except asyncio.CancelledError: |
129 | 163 | self.logger.warning( |
130 | 164 | f"TaskWrapper {self.fn.__name__} was cancelled." |
|
0 commit comments