Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions ipymini/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def __init__(self, kernel: "MiniKernel", subshell_id:str|None, user_ns: dict,
self.parent_header_var = contextvars.ContextVar("parent_header", default=None)
self.parent_idents_var = contextvars.ContextVar("parent_idents", default=None)
self.executing = threading.Event()
self.exec_scope = None
self.exec_scopes = set()
self.exec_scope_var = contextvars.ContextVar("exec_scope", default=None)
self.exec_count = 0
self.tasks = set()
self.sync_executing = threading.Event()
self.exec_state = ExecState.IDLE
self.last_exec_state = None
Expand Down Expand Up @@ -200,14 +203,15 @@ def interrupt(self)->bool:
return _raise_async_exception(thread_id, KeyboardInterrupt)

def cancel_async_execution(self, *, wake: bool = False)->bool:
"Cancel the current async cell execution, if there is one."
scope = self.exec_scope
if scope is None: return False
return scope.cancelled or scope.cancel("interrupt", wake=wake)
"Cancel all current async cell executions, if there are any."
scopes = list(self.exec_scopes)
if not scopes: return False
return any([scope.cancelled or scope.cancel("interrupt", wake=wake) for scope in scopes])

def _async_cancel_scope(self):
scope = self.exec_scope
if scope is None: self.exec_scope = scope = CancelScope()
scope = CancelScope()
self.exec_scopes.add(scope)
self.exec_scope_var.set(scope)
if self._get_exec_state() == ExecState.CANCELLING: scope.cancel("interrupt")
return scope

Expand Down Expand Up @@ -284,8 +288,23 @@ async def _main(self):
self.loop_ready.set()
dbg(f"SUBSHELL started id={self.subshell_id}")
await self.actor.run(bind=False)

async def _handle_actor_item(self, item):
await self._cancel_tasks()

async def _cancel_tasks(self):
if not self.tasks: return
tasks = list(self.tasks)
for task in tasks: task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)

def _handle_actor_item(self, item):
# Handle each message in its own task so the mailbox keeps draining while an async cell awaits
# (e.g. a cell calling back into the kernel via a server roundtrip). Sync cells never suspend, so
# they still run strictly in order; tasks start in mailbox order, keeping the aborting check ordered.
task = asyncio.create_task(self._handle_item(item))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

async def _handle_item(self, item):
msg, idents = item
if msg is subshell_abort_clear:
self._stop_aborting()
Expand Down Expand Up @@ -410,6 +429,7 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
allow_stdin = bool(content.get("allow_stdin", False))

dbg(f"HANDLE_EXEC id={msg_id} code={code[:30]!r}...")
with self.state_lock: self.exec_count += 1
self._set_exec_state(ExecState.RUNNING)
self.executing.set()
terminal_state = ExecState.COMPLETED
Expand All @@ -432,7 +452,8 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
result = await self.shell.execute(code, silent=silent, store_history=store_history,
user_expressions=user_expressions, allow_stdin=allow_stdin)
finally:
self.exec_scope = None
self.exec_scopes.discard(self.exec_scope_var.get())
self.exec_scope_var.set(None)
if timeout_handle: timeout_handle.cancel()
dbg(f"BRIDGE_DONE id={msg_id}")

Expand Down Expand Up @@ -470,8 +491,10 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
finally:
self.kernel.send_status("idle", msg)
self._set_last_exec_state(terminal_state)
self.executing.clear()
self._set_exec_state(ExecState.IDLE)
with self.state_lock:
self.exec_count -= 1
if (idle := self.exec_count == 0): self.exec_state = ExecState.IDLE
if idle: self.executing.clear()

def _shell_handler(self, msg: dict, idents: list[bytes]|None):
msg_type = nested_idx(msg, "header", "msg_type") or None
Expand Down
38 changes: 38 additions & 0 deletions tests/kernel/test_concurrent_execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import time
from ..kernel_utils import *

# Mimics solveit's `load_dialog` re-entrancy: a cell awaits an HTTP call to the solveit server,
# which sends further execute_requests back to the same kernel (`import_code`). The first cell
# only completes once the later requests have run, so the kernel must keep processing shell
# messages while an async cell is awaiting (ipyku_launcher/UnlockKernel semantics).

_waiter = """import asyncio
async def _wait_flag():
for _ in range(160):
if globals().get('_flag'): return True
await asyncio.sleep(0.05)
return False
assert await _wait_flag(), 'second execute never ran while first cell awaited'
"""


def test_execute_processed_while_async_cell_awaits():
with start_kernel() as (_, kc):
start = time.monotonic()
mid1 = kc.execute(_waiter)
mid2 = kc.execute("_flag = True")
replies = collect_shell_replies(kc, {mid1, mid2}, timeout=12)
assert replies[mid2]["content"]["status"] == "ok", replies[mid2]["content"]
assert replies[mid1]["content"]["status"] == "ok", replies[mid1]["content"]
assert time.monotonic() - start < 5, "execute_requests were serialized behind the awaiting cell"


def test_sync_cells_still_run_in_order():
with start_kernel() as (_, kc):
mids = [kc.execute(code) for code in ("order = []", "order.append(1)", "order.append(2)", "order.append(3)")]
replies = collect_shell_replies(kc, set(mids), timeout=10)
for mid in mids: assert replies[mid]["content"]["status"] == "ok", replies[mid]["content"]
_, reply, outputs = kc.exec_drain("print(order)")
assert reply["content"]["status"] == "ok"
texts = "".join(m["content"]["text"] for m in iopub_streams(outputs))
assert "[1, 2, 3]" in texts, texts
3 changes: 2 additions & 1 deletion tests/kernel/test_subshells.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def test_subshell_stop_on_error_isolated():
subshell_ids = [_create_subshell(kc) if is_subshell else None for is_subshell in are_subshells]

msg_ids = []
msg_id = _send_execute(kc, "import asyncio; await asyncio.sleep(0.1); raise ValueError()", subshell_id=subshell_ids[0])
# Sync cell: an async failing cell no longer aborts later submissions, since they run while it awaits
msg_id = _send_execute(kc, "import time; time.sleep(0.1); raise ValueError()", subshell_id=subshell_ids[0])
msg_ids.append(msg_id)
msg_id = _send_execute(kc, "print('hello')", subshell_id=subshell_ids[0])
msg_ids.append(msg_id)
Expand Down