44from enum import Enum
55from importlib .metadata import PackageNotFoundError , version
66from fastcore .basics import nested_idx , store_attr
7- from microio import ActorCore , CancelScope , CloseScope , ServiceGroup
7+ from microio import ActorCore , CloseScope , ScopeGroup , ServiceGroup , WorkTracker
88import zmq
99from jupyter_client .session import Session
1010from .shell import MiniShell
1111from .comms import get_comm_manager
12+ from .unlock import _release as _unlock_release
1213from .debug import DebugFlags , setup_debug , trace_msg
1314from .zmqthread import AsyncRouterThread , HeartbeatThread , IOPubThread , StdinRouterThread
1415
@@ -128,17 +129,17 @@ def __init__(self, kernel: "MiniKernel", subshell_id:str|None, user_ns: dict,
128129 self .thread = None if not run_in_thread else threading .Thread (target = self ._run_loop , daemon = True , name = name )
129130 self .loop = None
130131 self .loop_ready = threading .Event ()
131- self .actor = ActorCore (self ._handle_actor_item )
132+ self .actor = ActorCore (self ._handle_actor_item , concurrent = True )
132133 self .aborting = False
133134 self .abort_handle = None
134135 self ._shell = None
135136 self .shell_ready = threading .Event ()
136137 self .parent_header_var = contextvars .ContextVar ("parent_header" , default = None )
137138 self .parent_idents_var = contextvars .ContextVar ("parent_idents" , default = None )
138- self .executing = threading .Event ()
139- self .exec_scope = None
139+ self .exec_tracker = WorkTracker ()
140+ self .executing = self .exec_tracker .busy
141+ self .exec_scopes = ScopeGroup ()
140142 self .sync_executing = threading .Event ()
141- self .exec_state = ExecState .IDLE
142143 self .last_exec_state = None
143144 self .state_lock = threading .Lock ()
144145 self .shell_handlers = dict (kernel_info_request = self ._handle_kernel_info , connect_request = self ._handle_connect ,
@@ -157,7 +158,7 @@ def _init_shell(self):
157158 if self ._shell is not None : return
158159 self ._shell = MiniShell (request_input = self .request_input , debug_event_callback = self .send_debug_event ,
159160 zmq_context = self .kernel .context , user_ns = self .user_ns , use_singleton = self .use_singleton ,
160- async_cancel_scope = self ._async_cancel_scope , sync_execution_context = self ._sync_execution_context )
161+ exec_scopes = self .exec_scopes , sync_execution_context = self ._sync_execution_context )
161162 self ._shell .ipy .kernel = self .kernel
162163 self ._shell .set_stream_sender (self ._send_stream )
163164 self ._shell .set_display_sender (self ._send_display_event )
@@ -177,45 +178,46 @@ def stop(self, interrupt: bool = False):
177178
178179 def join (self , timeout :float | None = None )-> bool : return _join_or_log (self .thread , timeout = timeout )
179180
180- def submit (self , msg : dict , idents : list [bytes ]| None )-> bool : return bool (self .actor .submit ((msg , idents )))
181-
182- def _set_exec_state (self , state : ExecState ):
183- with self .state_lock : self .exec_state = state
181+ def submit (self , msg : dict , idents : list [bytes ]| None )-> bool :
182+ "Queue executes behind the cell baton; run other messages on the loop directly."
183+ if nested_idx (msg , "header" , "msg_type" ) == "execute_request" : return bool (self .actor .submit ((msg , idents )))
184+ return self ._submit_direct (msg , idents )
185+
186+ def _submit_direct (self , msg : dict , idents : list [bytes ]| None )-> bool :
187+ "Handle a non-execute message promptly, without queueing behind busy cells."
188+ if self .scope .closed : return False
189+ loop = self .loop
190+ if loop is None : return bool (self .actor .submit ((msg , idents ))) # loop not started yet; deliver via mailbox
191+ item = (msg , idents )
192+ def _run (): loop .create_task (self ._handle_actor_item (item , lambda : None ))
193+ try : loop .call_soon_threadsafe (_run )
194+ except RuntimeError : return False
195+ return True
184196
185197 def _set_last_exec_state (self , state : ExecState ):
186198 with self .state_lock : self .last_exec_state = state
187199
188- def _get_exec_state (self )-> ExecState :
189- with self .state_lock : return self .exec_state
190-
191200 def interrupt (self )-> bool :
192- "Raise KeyboardInterrupt in subshell thread if executing."
193- if self ._get_exec_state () != ExecState .RUNNING : return False
194- self ._set_exec_state (ExecState .CANCELLING )
195- if self .cancel_async_execution (): return True
201+ "Cancel running async executions, or raise KeyboardInterrupt in a sync one."
202+ if not self .executing .is_set (): return False
203+ already_cancelling = self .exec_scopes .cancelling
204+ if self .exec_scopes .cancel ("interrupt" , latch = True ): return True
205+ if already_cancelling : return True # still cancelling from a previous interrupt; don't double-inject
196206 if not self .sync_executing .is_set (): return True
197207 if self .thread is None : return False
198208 thread_id = self .thread .ident
199209 if thread_id is None : return False
200210 return _raise_async_exception (thread_id , KeyboardInterrupt )
201211
202212 def cancel_async_execution (self , * , wake : bool = False )-> bool :
203- "Cancel the current async cell execution, if there is one."
204- scope = self .exec_scope
205- if scope is None : return False
206- return scope .cancelled or scope .cancel ("interrupt" , wake = wake )
207-
208- def _async_cancel_scope (self ):
209- scope = self .exec_scope
210- if scope is None : self .exec_scope = scope = CancelScope ()
211- if self ._get_exec_state () == ExecState .CANCELLING : scope .cancel ("interrupt" )
212- return scope
213+ "Cancel all active async cell executions, if any."
214+ return self .exec_scopes .cancel ("interrupt" , wake = wake , latch = True )
213215
214216 @contextmanager
215217 def _sync_execution_context (self ):
216218 self .sync_executing .set ()
217219 try :
218- if self ._get_exec_state () == ExecState . CANCELLING : raise KeyboardInterrupt
220+ if self .exec_scopes . cancelling : raise KeyboardInterrupt
219221 yield
220222 finally : self .sync_executing .clear ()
221223
@@ -285,7 +287,7 @@ async def _main(self):
285287 dbg (f"SUBSHELL started id={ self .subshell_id } " )
286288 await self .actor .run (bind = False )
287289
288- async def _handle_actor_item (self , item ):
290+ async def _handle_actor_item (self , item , release ):
289291 msg , idents = item
290292 if msg is subshell_abort_clear :
291293 self ._stop_aborting ()
@@ -294,7 +296,7 @@ async def _handle_actor_item(self, item):
294296 msg_type = nested_idx (msg , "header" , "msg_type" ) or "?"
295297 msg_id = (nested_idx (msg , "header" , "msg_id" ) or "?" )[:8 ]
296298 dbg (f"EXEC { msg_type } id={ msg_id } " )
297- try : await self ._handle_message (msg , idents )
299+ try : await self ._handle_message (msg , idents , release )
298300 except Exception as exc : self ._handle_internal_error (msg , idents , exc )
299301 dbg (f"DONE { msg_type } id={ msg_id } " )
300302
@@ -317,7 +319,7 @@ def _send_error_reply(self, msg_type:str, error:dict, msg: dict, idents: list[by
317319 self .send_reply (_reply_type (msg_type ), reply , msg , idents )
318320 if msg_type == "execute_request" : self .kernel .send_status ("idle" , msg )
319321
320- async def _handle_message (self , msg : dict , idents : list [bytes ]| None ):
322+ async def _handle_message (self , msg : dict , idents : list [bytes ]| None , release ):
321323 msg_type = msg ["header" ]["msg_type" ]
322324 msg_id = msg ["header" ].get ("msg_id" , "?" )[:8 ]
323325 dbg (f"HANDLE_MSG { msg_type } id={ msg_id } " )
@@ -334,7 +336,7 @@ async def _handle_message(self, msg: dict, idents: list[bytes]|None):
334336 return
335337 if msg_type == "execute_request" :
336338 dbg (f"DISPATCH_EXEC id={ msg_id } " )
337- await self ._handle_execute (msg , idents )
339+ await self ._handle_execute (msg , idents , release )
338340 return
339341 self ._dispatch_shell_non_execute (msg , idents )
340342 finally :
@@ -399,7 +401,17 @@ def _handle_connect(self, msg: dict, idents: list[bytes]|None):
399401 stdin_port = self .kernel .connection .stdin_port , control_port = self .kernel .connection .control_port , hb_port = self .kernel .connection .hb_port )
400402 self .send_reply ("connect_reply" , content , msg , idents )
401403
402- async def _handle_execute (self , msg : dict , idents : list [bytes ]| None ):
404+ def _safe_release (self , release ):
405+ "Wrap an actor release callback so it is safe to call from any thread."
406+ loop = asyncio .get_running_loop ()
407+ def _do ():
408+ try : running = asyncio .get_running_loop ()
409+ except RuntimeError : running = None
410+ if running is loop : release ()
411+ else : loop .call_soon_threadsafe (release )
412+ return _do
413+
414+ async def _handle_execute (self , msg : dict , idents : list [bytes ]| None , release ):
403415 msg_id = (nested_idx (msg , "header" , "msg_id" ) or "?" )[:8 ]
404416 content = msg .get ("content" , {})
405417 code = content .get ("code" , "" )
@@ -410,8 +422,9 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
410422 allow_stdin = bool (content .get ("allow_stdin" , False ))
411423
412424 dbg (f"HANDLE_EXEC id={ msg_id } code={ code [:30 ]!r} ..." )
413- self ._set_exec_state (ExecState .RUNNING )
414- self .executing .set ()
425+ _unlock_release .set (self ._safe_release (release ))
426+ self .exec_scopes .clear () # a new execute ends any previous cancelling window
427+ self .exec_tracker .add ()
415428 terminal_state = ExecState .COMPLETED
416429 iopub = self .kernel .iopub
417430 sent_reply = sent_error = False
@@ -432,12 +445,11 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
432445 result = await self .shell .execute (code , silent = silent , store_history = store_history ,
433446 user_expressions = user_expressions , allow_stdin = allow_stdin )
434447 finally :
435- self .exec_scope = None
436448 if timeout_handle : timeout_handle .cancel ()
437449 dbg (f"BRIDGE_DONE id={ msg_id } " )
438450
439451 error = result .get ("error" )
440- if error and error .get ("ename" ) == "CancelledError" and self ._get_exec_state () == ExecState . CANCELLING :
452+ if error and error .get ("ename" ) == "CancelledError" and self .exec_scopes . cancelling :
441453 error = dict (error ) | dict (ename = "KeyboardInterrupt" , evalue = "" )
442454 exec_count = result .get ("execution_count" )
443455
@@ -470,8 +482,7 @@ async def _handle_execute(self, msg: dict, idents: list[bytes]|None):
470482 finally :
471483 self .kernel .send_status ("idle" , msg )
472484 self ._set_last_exec_state (terminal_state )
473- self .executing .clear ()
474- self ._set_exec_state (ExecState .IDLE )
485+ self .exec_tracker .done ()
475486
476487 def _shell_handler (self , msg : dict , idents : list [bytes ]| None ):
477488 msg_type = nested_idx (msg , "header" , "msg_type" ) or None
@@ -714,7 +725,6 @@ def handle_sigint(self, signum, frame):
714725 for subshell in children : subshell .interrupt ()
715726 parent = self .subshells .parent
716727 if not parent .executing .is_set (): return
717- parent ._set_exec_state (ExecState .CANCELLING )
718728 if parent .cancel_async_execution (wake = True ): return
719729 if not parent .sync_executing .is_set (): return
720730 raise KeyboardInterrupt
0 commit comments