Skip to content

Commit 8b85a5f

Browse files
committed
Update
[ghstack-poisoned]
1 parent 609b776 commit 8b85a5f

2 files changed

Lines changed: 93 additions & 6 deletions

File tree

test/test_inference_server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,10 +870,14 @@ def test_process_server_start_shutdown(self):
870870
mp_context=ctx,
871871
) as server:
872872
assert server.is_alive
873-
assert server.stats() == {}
874873
result = client(TensorDict({"observation": torch.ones(1)}))
874+
stats = server.stats()
875+
health = server.health()
875876
assert "action" in result.keys()
876877
assert result["action"].shape == (1,)
878+
assert stats["requests"] == 1
879+
assert stats["avg_batch_size"] == 1
880+
assert health["process_alive"]
877881
assert not server.is_alive
878882

879883
def test_process_server_exception_propagates(self):

torchrl/modules/inference_server/_server.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import queue
78
import threading
89
import time
910
from collections.abc import Callable
@@ -455,8 +456,11 @@ def _process_server_entry(
455456
server_kwargs: dict,
456457
shutdown_event: MPEvent,
457458
ready_queue,
459+
control_queue,
460+
control_response_queue,
458461
) -> None:
459462
"""Run an :class:`InferenceServer` loop inside a child process."""
463+
server = None
460464
try:
461465
model = policy_factory()
462466
server = InferenceServer(
@@ -465,11 +469,36 @@ def _process_server_entry(
465469
shutdown_event=shutdown_event,
466470
**server_kwargs,
467471
)
472+
server.start()
468473
ready_queue.put((True, None))
469-
server._run()
474+
while not shutdown_event.is_set():
475+
try:
476+
request_id, command, kwargs = control_queue.get(timeout=0.05)
477+
except queue.Empty:
478+
continue
479+
try:
480+
if command == "stats":
481+
payload = server.stats(**kwargs)
482+
elif command == "health":
483+
payload = {
484+
"alive": server.is_alive,
485+
"policy_version": server.policy_version,
486+
}
487+
elif command == "shutdown":
488+
shutdown_event.set()
489+
payload = {"accepted": True}
490+
else:
491+
raise RuntimeError(f"Unknown process-server command: {command}")
492+
except BaseException as exc:
493+
control_response_queue.put((request_id, False, repr(exc)))
494+
else:
495+
control_response_queue.put((request_id, True, payload))
470496
except BaseException as exc:
471497
ready_queue.put((False, repr(exc)))
472498
raise
499+
finally:
500+
if server is not None:
501+
server.shutdown(timeout=1.0)
473502

474503

475504
class ProcessInferenceServer:
@@ -595,6 +624,9 @@ def __init__(
595624
self._ctx = mp_context
596625
self._shutdown_event = self._ctx.Event()
597626
self._ready_queue = self._ctx.Queue()
627+
self._control_queue = self._ctx.Queue()
628+
self._control_response_queue = self._ctx.Queue()
629+
self._next_control_request_id = 0
598630
self._process: mp.Process | None = None
599631
self._server_kwargs = {
600632
"max_batch_size": max_batch_size,
@@ -625,6 +657,8 @@ def start(self) -> ProcessInferenceServer:
625657
"server_kwargs": self._server_kwargs,
626658
"shutdown_event": self._shutdown_event,
627659
"ready_queue": self._ready_queue,
660+
"control_queue": self._control_queue,
661+
"control_response_queue": self._control_response_queue,
628662
},
629663
daemon=True,
630664
name="ProcessInferenceServer",
@@ -636,8 +670,44 @@ def start(self) -> ProcessInferenceServer:
636670
raise RuntimeError(f"ProcessInferenceServer failed to start: {payload}")
637671
return self
638672

673+
def _request_control(
    self, command: str, kwargs: dict | None = None, timeout: float = 5.0
):
    """Send one control command to the child process and wait for its reply.

    Args:
        command (str): command name placed on the control queue.
        kwargs (dict, optional): keyword arguments sent along with the
            command. ``None`` is sent as an empty dict.
        timeout (float, optional): maximum number of seconds to wait for
            the matching response. Defaults to ``5.0``.

    Returns:
        The payload of the response whose id matches this request.

    Raises:
        RuntimeError: if the process was never started, is not alive when
            the request is issued, or the response is flagged as failed.
        TimeoutError: if no matching response arrives within ``timeout``.
    """
    process = self._process
    if process is None:
        raise RuntimeError("ProcessInferenceServer is not running.")
    if not process.is_alive():
        raise RuntimeError(
            "ProcessInferenceServer process is not alive "
            f"(exitcode={process.exitcode})."
        )

    request_id = self._next_control_request_id
    self._next_control_request_id += 1
    self._control_queue.put((request_id, command, kwargs or {}))

    timeout_message = f"Timed out waiting for ProcessInferenceServer {command!r}."
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(timeout_message)
        try:
            response_id, ok, payload = self._control_response_queue.get(
                timeout=remaining
            )
        except queue.Empty:
            # ``Queue.get`` signals an expired wait with ``queue.Empty``;
            # surface it as the same TimeoutError as the deadline check so
            # callers only have one timeout exception to handle.
            raise TimeoutError(timeout_message) from None
        if response_id != request_id:
            # Not the reply to this request (presumably a late answer to an
            # earlier request that already timed out) -- drop it and keep
            # waiting until the deadline.
            continue
        if not ok:
            raise RuntimeError(
                f"ProcessInferenceServer {command!r} failed: {payload}"
            )
        return payload
703+
639704
def shutdown(self, timeout: float | None = 5.0) -> None:
640705
"""Signal the child process to stop and wait for it to exit."""
706+
if self.is_alive:
707+
try:
708+
self._request_control("shutdown", timeout=timeout or 5.0)
709+
except Exception:
710+
pass
641711
self._shutdown_event.set()
642712
process = self._process
643713
if process is None:
@@ -654,12 +724,25 @@ def is_alive(self) -> bool:
654724
return self._process is not None and self._process.is_alive()
655725

656726
def stats(self, *, reset: bool = False) -> dict[str, float | int]:
    """Fetch a stats snapshot from the child process.

    The request is sent as a ``"stats"`` control command and the reply
    payload is returned unchanged.

    Args:
        reset (bool, optional): if ``True``, reset counters in the child
            process after taking the snapshot.
    """
    control_kwargs = {"reset": reset}
    return self._request_control("stats", control_kwargs)
734+
735+
def health(self) -> dict[str, int | bool | None]:
    """Return a lightweight child-process health snapshot.

    The snapshot always carries ``"process_alive"``, ``"pid"`` and
    ``"exitcode"``. When the process is alive, the payload of a
    ``"health"`` control request is merged on top of those keys.
    """
    proc = self._process
    if proc is None:
        # Never started: nothing to query.
        return {"process_alive": False, "pid": None, "exitcode": None}

    snapshot = {
        "process_alive": proc.is_alive(),
        "pid": proc.pid,
        "exitcode": proc.exitcode,
    }
    # Liveness is queried again here, so the decision to contact the child
    # reflects its state at this point rather than the value stored above.
    if proc.is_alive():
        snapshot.update(self._request_control("health"))
    return snapshot
663746

664747
def __enter__(self) -> ProcessInferenceServer:
665748
return self.start()

0 commit comments

Comments
 (0)