44# LICENSE file in the root directory of this source tree.
55from __future__ import annotations
66
7+ import queue
78import threading
89import time
910from collections .abc import Callable
@@ -455,8 +456,11 @@ def _process_server_entry(
455456 server_kwargs : dict ,
456457 shutdown_event : MPEvent ,
457458 ready_queue ,
459+ control_queue ,
460+ control_response_queue ,
458461) -> None :
459462 """Run an :class:`InferenceServer` loop inside a child process."""
463+ server = None
460464 try :
461465 model = policy_factory ()
462466 server = InferenceServer (
@@ -465,11 +469,36 @@ def _process_server_entry(
465469 shutdown_event = shutdown_event ,
466470 ** server_kwargs ,
467471 )
472+ server .start ()
468473 ready_queue .put ((True , None ))
469- server ._run ()
474+ while not shutdown_event .is_set ():
475+ try :
476+ request_id , command , kwargs = control_queue .get (timeout = 0.05 )
477+ except queue .Empty :
478+ continue
479+ try :
480+ if command == "stats" :
481+ payload = server .stats (** kwargs )
482+ elif command == "health" :
483+ payload = {
484+ "alive" : server .is_alive ,
485+ "policy_version" : server .policy_version ,
486+ }
487+ elif command == "shutdown" :
488+ shutdown_event .set ()
489+ payload = {"accepted" : True }
490+ else :
491+ raise RuntimeError (f"Unknown process-server command: { command } " )
492+ except BaseException as exc :
493+ control_response_queue .put ((request_id , False , repr (exc )))
494+ else :
495+ control_response_queue .put ((request_id , True , payload ))
470496 except BaseException as exc :
471497 ready_queue .put ((False , repr (exc )))
472498 raise
499+ finally :
500+ if server is not None :
501+ server .shutdown (timeout = 1.0 )
473502
474503
475504class ProcessInferenceServer :
@@ -595,6 +624,9 @@ def __init__(
595624 self ._ctx = mp_context
596625 self ._shutdown_event = self ._ctx .Event ()
597626 self ._ready_queue = self ._ctx .Queue ()
627+ self ._control_queue = self ._ctx .Queue ()
628+ self ._control_response_queue = self ._ctx .Queue ()
629+ self ._next_control_request_id = 0
598630 self ._process : mp .Process | None = None
599631 self ._server_kwargs = {
600632 "max_batch_size" : max_batch_size ,
@@ -625,6 +657,8 @@ def start(self) -> ProcessInferenceServer:
625657 "server_kwargs" : self ._server_kwargs ,
626658 "shutdown_event" : self ._shutdown_event ,
627659 "ready_queue" : self ._ready_queue ,
660+ "control_queue" : self ._control_queue ,
661+ "control_response_queue" : self ._control_response_queue ,
628662 },
629663 daemon = True ,
630664 name = "ProcessInferenceServer" ,
@@ -636,8 +670,44 @@ def start(self) -> ProcessInferenceServer:
636670 raise RuntimeError (f"ProcessInferenceServer failed to start: { payload } " )
637671 return self
638672
673+ def _request_control (
674+ self , command : str , kwargs : dict | None = None , timeout : float = 5.0
675+ ):
676+ if self ._process is None :
677+ raise RuntimeError ("ProcessInferenceServer is not running." )
678+ if not self ._process .is_alive ():
679+ raise RuntimeError (
680+ "ProcessInferenceServer process is not alive "
681+ f"(exitcode={ self ._process .exitcode } )."
682+ )
683+ request_id = self ._next_control_request_id
684+ self ._next_control_request_id += 1
685+ self ._control_queue .put ((request_id , command , kwargs or {}))
686+ deadline = time .monotonic () + timeout
687+ while True :
688+ remaining = deadline - time .monotonic ()
689+ if remaining <= 0 :
690+ raise TimeoutError (
691+ f"Timed out waiting for ProcessInferenceServer { command !r} ."
692+ )
693+ response_id , ok , payload = self ._control_response_queue .get (
694+ timeout = remaining
695+ )
696+ if response_id != request_id :
697+ continue
698+ if not ok :
699+ raise RuntimeError (
700+ f"ProcessInferenceServer { command !r} failed: { payload } "
701+ )
702+ return payload
703+
639704 def shutdown (self , timeout : float | None = 5.0 ) -> None :
640705 """Signal the child process to stop and wait for it to exit."""
706+ if self .is_alive :
707+ try :
708+ self ._request_control ("shutdown" , timeout = timeout or 5.0 )
709+ except Exception :
710+ pass
641711 self ._shutdown_event .set ()
642712 process = self ._process
643713 if process is None :
@@ -654,12 +724,25 @@ def is_alive(self) -> bool:
654724 return self ._process is not None and self ._process .is_alive ()
655725
656726 def stats (self , * , reset : bool = False ) -> dict [str , float | int ]:
657- """Return process-server stats.
727+ """Return process-server stats from the child process .
658728
659- Live stats are not shared across processes yet, so this currently
660- returns an empty dictionary.
729+ Args:
730+ reset (bool, optional): if ``True``, reset counters in the child
731+ process after taking the snapshot.
661732 """
662- return {}
733+ return self ._request_control ("stats" , {"reset" : reset })
734+
735+ def health (self ) -> dict [str , int | bool | None ]:
736+ """Return a lightweight child-process health snapshot."""
737+ process = self ._process
738+ result = {
739+ "process_alive" : process .is_alive () if process is not None else False ,
740+ "pid" : process .pid if process is not None else None ,
741+ "exitcode" : process .exitcode if process is not None else None ,
742+ }
743+ if process is not None and process .is_alive ():
744+ result .update (self ._request_control ("health" ))
745+ return result
663746
664747 def __enter__ (self ) -> ProcessInferenceServer :
665748 return self .start ()
0 commit comments