Skip to content

Commit 230349d

Browse files
Replace BaseThread's add_task with start_soon (#1300)
This PR replaces a BaseThread's add_task() method with start_soon(). The new name is less confusing as it's the same as in AnyIO, and it allows to start a task in the thread even after the thread has been started. We also get rid of _IOPubThread, which has no reason to be different than a BaseThread. (from #1291)
1 parent eb0aee6 commit 230349d

File tree

4 files changed

+70
-55
lines changed

4 files changed

+70
-55
lines changed

ipykernel/iostream.py

+9-37
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
from binascii import b2a_hex
1717
from collections import defaultdict, deque
1818
from io import StringIO, TextIOBase
19-
from threading import Event, Thread, local
19+
from threading import local
2020
from typing import Any, Callable
2121

2222
import zmq
23-
from anyio import create_task_group, run, sleep, to_thread
23+
from anyio import sleep
2424
from jupyter_client.session import extract_header
2525

26+
from .thread import BaseThread
27+
2628
# -----------------------------------------------------------------------------
2729
# Globals
2830
# -----------------------------------------------------------------------------
@@ -37,38 +39,6 @@
3739
# -----------------------------------------------------------------------------
3840

3941

40-
class _IOPubThread(Thread):
41-
"""A thread for a IOPub."""
42-
43-
def __init__(self, tasks, **kwargs):
44-
"""Initialize the thread."""
45-
super().__init__(name="IOPub", **kwargs)
46-
self._tasks = tasks
47-
self.pydev_do_not_trace = True
48-
self.is_pydev_daemon_thread = True
49-
self.daemon = True
50-
self.__stop = Event()
51-
52-
def run(self):
53-
"""Run the thread."""
54-
self.name = "IOPub"
55-
run(self._main)
56-
57-
async def _main(self):
58-
async with create_task_group() as tg:
59-
for task in self._tasks:
60-
tg.start_soon(task)
61-
await to_thread.run_sync(self.__stop.wait)
62-
tg.cancel_scope.cancel()
63-
64-
def stop(self):
65-
"""Stop the thread.
66-
67-
This method is threadsafe.
68-
"""
69-
self.__stop.set()
70-
71-
7242
class IOPubThread:
7343
"""An object for sending IOPub messages in a background thread
7444
@@ -111,7 +81,9 @@ def __init__(self, socket, pipe=False):
11181
tasks = [self._handle_event, self._run_event_pipe_gc]
11282
if pipe:
11383
tasks.append(self._handle_pipe_msgs)
114-
self.thread = _IOPubThread(tasks)
84+
self.thread = BaseThread(name="IOPub", daemon=True)
85+
for task in tasks:
86+
self.thread.start_soon(task)
11587

11688
def _setup_event_pipe(self):
11789
"""Create the PULL socket listening for events that should fire in this thread."""
@@ -181,7 +153,7 @@ async def _handle_event(self):
181153
event_f = self._events.popleft()
182154
event_f()
183155
except Exception:
184-
if self.thread.__stop.is_set():
156+
if self.thread.stopped.is_set():
185157
return
186158
raise
187159

@@ -215,7 +187,7 @@ async def _handle_pipe_msgs(self):
215187
while True:
216188
await self._handle_pipe_msg()
217189
except Exception:
218-
if self.thread.__stop.is_set():
190+
if self.thread.stopped.is_set():
219191
return
220192
raise
221193

ipykernel/kernelbase.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import uuid
1717
import warnings
1818
from datetime import datetime
19+
from functools import partial
1920
from signal import SIGINT, SIGTERM, Signals
2021

2122
from .thread import CONTROL_THREAD_NAME
@@ -529,7 +530,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
529530
self.control_stop = threading.Event()
530531
if not self._is_test and self.control_socket is not None:
531532
if self.control_thread:
532-
self.control_thread.add_task(self.control_main)
533+
self.control_thread.start_soon(self.control_main)
533534
self.control_thread.start()
534535
else:
535536
tg.start_soon(self.control_main)
@@ -544,9 +545,11 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
544545

545546
# Assign tasks to and start shell channel thread.
546547
manager = self.shell_channel_thread.manager
547-
self.shell_channel_thread.add_task(self.shell_channel_thread_main)
548-
self.shell_channel_thread.add_task(manager.listen_from_control, self.shell_main)
549-
self.shell_channel_thread.add_task(manager.listen_from_subshells)
548+
self.shell_channel_thread.start_soon(self.shell_channel_thread_main)
549+
self.shell_channel_thread.start_soon(
550+
partial(manager.listen_from_control, self.shell_main)
551+
)
552+
self.shell_channel_thread.start_soon(manager.listen_from_subshells)
550553
self.shell_channel_thread.start()
551554
else:
552555
if not self._is_test and self.shell_socket is not None:

ipykernel/subshell_manager.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import typing as t
88
import uuid
99
from dataclasses import dataclass
10+
from functools import partial
1011
from threading import Lock, current_thread, main_thread
1112

1213
import zmq
@@ -186,8 +187,8 @@ async def _create_subshell(self, subshell_task: t.Any) -> str:
186187
await self._send_stream.send(subshell_id)
187188

188189
address = self._get_inproc_socket_address(subshell_id)
189-
thread.add_task(thread.create_pair_socket, self._context, address)
190-
thread.add_task(subshell_task, subshell_id)
190+
thread.start_soon(partial(thread.create_pair_socket, self._context, address))
191+
thread.start_soon(partial(subshell_task, subshell_id))
191192
thread.start()
192193

193194
return subshell_id

ipykernel/thread.py

+51-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Base class for threads."""
2+
23
from __future__ import annotations
34

4-
import typing as t
5+
from collections.abc import Awaitable
6+
from queue import Queue
57
from threading import Event, Thread
8+
from typing import Any, Callable
69

710
from anyio import create_task_group, run, to_thread
11+
from anyio.abc import TaskGroup
812

913
CONTROL_THREAD_NAME = "Control"
1014
SHELL_CHANNEL_THREAD_NAME = "Shell channel"
@@ -16,29 +20,64 @@ class BaseThread(Thread):
1620
def __init__(self, **kwargs):
1721
"""Initialize the thread."""
1822
super().__init__(**kwargs)
23+
self.started = Event()
24+
self.stopped = Event()
1925
self.pydev_do_not_trace = True
2026
self.is_pydev_daemon_thread = True
21-
self.__stop = Event()
22-
self._tasks_and_args: list[tuple[t.Any, t.Any]] = []
27+
self._tasks: Queue[tuple[str, Callable[[], Awaitable[Any]]] | None] = Queue()
28+
self._result: Queue[Any] = Queue()
29+
self._exception: Exception | None = None
30+
31+
@property
32+
def exception(self) -> Exception | None:
33+
return self._exception
34+
35+
@property
36+
def task_group(self) -> TaskGroup:
37+
return self._task_group
2338

24-
def add_task(self, task: t.Any, *args: t.Any) -> None:
25-
# May only add tasks before the thread is started.
26-
self._tasks_and_args.append((task, args))
39+
def start_soon(self, coro: Callable[[], Awaitable[Any]]) -> None:
40+
self._tasks.put(("start_soon", coro))
2741

28-
def run(self) -> t.Any:
42+
def run_async(self, coro: Callable[[], Awaitable[Any]]) -> Any:
43+
self._tasks.put(("run_async", coro))
44+
return self._result.get()
45+
46+
def run_sync(self, func: Callable[..., Any]) -> Any:
47+
self._tasks.put(("run_sync", func))
48+
return self._result.get()
49+
50+
def run(self) -> None:
2951
"""Run the thread."""
30-
return run(self._main)
52+
try:
53+
run(self._main)
54+
except Exception as exc:
55+
self._exception = exc
3156

3257
async def _main(self) -> None:
3358
async with create_task_group() as tg:
34-
for task, args in self._tasks_and_args:
35-
tg.start_soon(task, *args)
36-
await to_thread.run_sync(self.__stop.wait)
59+
self._task_group = tg
60+
self.started.set()
61+
while True:
62+
task = await to_thread.run_sync(self._tasks.get)
63+
if task is None:
64+
break
65+
func, arg = task
66+
if func == "start_soon":
67+
tg.start_soon(arg)
68+
elif func == "run_async":
69+
res = await arg
70+
self._result.put(res)
71+
else: # func == "run_sync"
72+
res = arg()
73+
self._result.put(res)
74+
3775
tg.cancel_scope.cancel()
3876

3977
def stop(self) -> None:
4078
"""Stop the thread.
4179
4280
This method is threadsafe.
4381
"""
44-
self.__stop.set()
82+
self._tasks.put(None)
83+
self.stopped.set()

0 commit comments

Comments
 (0)