diff --git a/changes/11234.enhance.md b/changes/11234.enhance.md new file mode 100644 index 00000000000..5ce3ac01c56 --- /dev/null +++ b/changes/11234.enhance.md @@ -0,0 +1 @@ +Share a single `DockerStatsStreamer` across CPU/Memory intrinsic plugins so each container opens one persistent Docker stats stream instead of two. diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index c78038409c1..725daee38f9 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -1417,41 +1417,25 @@ async def _handle_start_event(self, ev: ContainerLifecycleEvent) -> None: if ev.container_id is not None: kernel_obj.set_container_id(ev.container_id) if ev.container_id is not None: - await self._notify_compute_plugins_container_started(ev.container_id) + await self._on_container_started(ev.container_id) log.info("Kernel {0} started", ev.kernel_id) - async def _notify_compute_plugins_container_started(self, container_id: ContainerId) -> None: - """Notify all compute plugins that ``container_id`` has started. + async def _on_container_started(self, container_id: ContainerId) -> None: + """Hook for subclasses to react to a container transitioning to RUNNING. - Plugins may use this to eagerly initialise per-container state - (e.g. :class:`DockerStatsStreamer`). Plugin failures are logged but - never prevent the container from running. + The default implementation is a no-op. Concrete agents (e.g. the Docker + agent) override this to start per-container resources such as a + long-lived stats stream reader. 
""" - short_cid = str(container_id)[:13] - for device_name, computer_ctx in self.computers.items(): - try: - await computer_ctx.instance.notify_container_started(str(container_id)) - except Exception as e: - log.warning( - "compute plugin {} notify_container_started failed (cid:{}): {!r}", - device_name, - short_cid, - e, - ) + return - async def _notify_compute_plugins_container_destroyed(self, container_id: ContainerId) -> None: - """Notify all compute plugins that ``container_id`` has been cleaned up.""" - short_cid = str(container_id)[:13] - for device_name, computer_ctx in self.computers.items(): - try: - await computer_ctx.instance.notify_container_destroyed(str(container_id)) - except Exception as e: - log.warning( - "compute plugin {} notify_container_destroyed failed (cid:{}): {!r}", - device_name, - short_cid, - e, - ) + async def _on_container_destroyed(self, container_id: ContainerId) -> None: + """Hook for subclasses to react to a container being cleaned up. + + The default implementation is a no-op. Concrete agents (e.g. the Docker + agent) override this to release per-container resources. 
+ """ + return async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: log.info( @@ -1537,7 +1521,7 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: await destruction_task del destruction_task if ev.container_id is not None: - await self._notify_compute_plugins_container_destroyed(ev.container_id) + await self._on_container_destroyed(ev.container_id) await self.stat_ctx.remove_kernel_metric(ev.kernel_id, ev.container_id) async with self.registry_lock: try: diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index ab6fab49675..46839ed95e0 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -50,6 +50,9 @@ ScanImagesResult, ) from ai.backend.agent.config.unified import AgentUnifiedConfig, ContainerSandboxType, ScratchType +from ai.backend.agent.docker.intrinsic import ( + DockerStatsStreamer, +) from ai.backend.agent.etcd import AgentEtcdClientView from ai.backend.agent.exception import ( ContainerCreationError, @@ -1422,6 +1425,7 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]): docker_ptask_group: aiotools.PersistentTaskGroup gwbridge_subnet: str | None checked_invalid_images: set[str] + _stats_streamer: DockerStatsStreamer network_plugin_ctx: NetworkPluginContext @@ -1517,6 +1521,26 @@ async def __ainit__(self) -> None: ) self.docker_info = docker_info await self._kernel_recovery_adapter.adapt_recovery_data() + + # For legacy accelerator plugins and the shared stats streamer below + self.docker = Docker() + + # Single DockerStatsStreamer shared across intrinsic compute plugins. + # Keeps one long-lived ``container.stats(stream=True)`` reader per + # container instead of one per plugin, cutting open connections to + # dockerd in half for CPU + Memory plugins. + # + # NOTE: the streamer MUST be created and attached to plugins BEFORE + # ``super().__ainit__()`` runs. 
``AbstractAgent.__ainit__`` calls + # ``scan_running_kernels()`` and then starts the lifecycle handler + # task (``process_lifecycle_events``); on warm restart that pipeline + # can fire container-start events which land in + # ``_on_container_started`` -> ``self._stats_streamer.start(...)`` + # before any code AFTER ``super().__ainit__()`` gets to run. + self._stats_streamer = DockerStatsStreamer(self.docker) + for computer_ctx in self.computers.values(): + computer_ctx.instance.attach_stats_streamer(self._stats_streamer) + await super().__ainit__() try: async with Docker() as docker: @@ -1556,9 +1580,6 @@ async def __ainit__(self) -> None: self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events()) self.docker_ptask_group = aiotools.PersistentTaskGroup() - # For legacy accelerator plugins - self.docker = Docker() - self.network_plugin_ctx = NetworkPluginContext( self.etcd, self.local_config.model_dump(by_alias=True) ) @@ -1576,6 +1597,13 @@ async def shutdown(self, stop_signal: signal.Signals) -> None: if self.docker_ptask_group is not None: await self.docker_ptask_group.shutdown() + # Close the shared stats streamer before the underlying Docker client + # so in-flight reader tasks can cleanly drain their stream iterators. + # ``_stats_streamer`` is declared non-Optional and is assigned in + # ``__ainit__`` right before ``super().__ainit__()`` (after the + # ``adapt_recovery_data()`` await), so it exists once init got that far. 
+ await self._stats_streamer.close() + try: await super().shutdown(stop_signal) finally: @@ -1587,6 +1615,14 @@ async def shutdown(self, stop_signal: signal.Signals) -> None: if self.docker: await self.docker.close() + @override + async def _on_container_started(self, container_id: ContainerId) -> None: + self._stats_streamer.start(str(container_id)) + + @override + async def _on_container_destroyed(self, container_id: ContainerId) -> None: + await self._stats_streamer.stop(str(container_id)) + @override async def _load_kernel_registry_from_recovery(self) -> MutableMapping[KernelId, AbstractKernel]: return await self._kernel_recovery.load_kernel_registry() diff --git a/src/ai/backend/agent/docker/intrinsic.py b/src/ai/backend/agent/docker/intrinsic.py index 6ba0d2231bd..db60a760624 100644 --- a/src/ai/backend/agent/docker/intrinsic.py +++ b/src/ai/backend/agent/docker/intrinsic.py @@ -356,21 +356,18 @@ class CPUPlugin(AbstractComputePlugin): async def init(self, context: Any | None = None) -> None: self._docker = Docker() - # TODO(#11232): Consolidate per-plugin streamer into a single shared instance owned by the agent. - self._stats_streamer = DockerStatsStreamer(self._docker) async def cleanup(self) -> None: - await self._stats_streamer.close() await self._docker.close() async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: pass - async def notify_container_started(self, container_id: str) -> None: - self._stats_streamer.start(container_id) - - async def notify_container_destroyed(self, container_id: str) -> None: - await self._stats_streamer.stop(container_id) + def attach_stats_streamer(self, streamer: DockerStatsStreamer) -> None: + """Attach the agent-owned :class:`DockerStatsStreamer` used for reading + per-container stats. 
Called once by :class:`DockerAgent` after plugin + init so the streamer is shared across intrinsic plugins.""" + self._stats_streamer = streamer async def list_devices(self) -> Collection[CPUDevice]: cores = await libnuma.get_available_cores() @@ -689,21 +686,18 @@ class MemoryPlugin(AbstractComputePlugin): async def init(self, context: Any | None = None) -> None: self._docker = Docker() - # TODO(#11232): Consolidate per-plugin streamer into a single shared instance owned by the agent. - self._stats_streamer = DockerStatsStreamer(self._docker) async def cleanup(self) -> None: - await self._stats_streamer.close() await self._docker.close() async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: pass - async def notify_container_started(self, container_id: str) -> None: - self._stats_streamer.start(container_id) - - async def notify_container_destroyed(self, container_id: str) -> None: - await self._stats_streamer.stop(container_id) + def attach_stats_streamer(self, streamer: DockerStatsStreamer) -> None: + """Attach the agent-owned :class:`DockerStatsStreamer` used for reading + per-container stats. 
Called once by :class:`DockerAgent` after plugin + init so the streamer is shared across intrinsic plugins.""" + self._stats_streamer = streamer async def list_devices(self) -> Collection[MemoryDevice]: memory_size = psutil.virtual_memory().total diff --git a/src/ai/backend/agent/resources.py b/src/ai/backend/agent/resources.py index 8bf5bd30bfe..0692c59bf56 100644 --- a/src/ai/backend/agent/resources.py +++ b/src/ai/backend/agent/resources.py @@ -77,6 +77,8 @@ from aiofiles.threadpool.text import AsyncTextIOWrapper + from ai.backend.agent.docker.intrinsic import DockerStatsStreamer + type DeviceAllocation = Mapping[SlotName, Mapping[DeviceId, Decimal]] @@ -451,25 +453,6 @@ async def generate_resource_data(self, device_alloc: DeviceAllocation) -> Mappin """ return {} - async def notify_container_started(self, container_id: str) -> None: - """ - Lifecycle hook invoked by the agent when a container transitions to RUNNING. - - Subclasses may override this to eagerly spin up per-container resources - (e.g. start a long-lived stats stream reader) instead of relying on - lazy initialisation from the next stat collection cycle. Default: no-op. - """ - return - - async def notify_container_destroyed(self, container_id: str) -> None: - """ - Lifecycle hook invoked by the agent when a container is being cleaned up. - - Subclasses may override this to release per-container resources - (e.g. cancel the long-lived stats stream reader task). Default: no-op. - """ - return - @abstractmethod async def restore_from_container( self, @@ -535,6 +518,16 @@ def get_additional_allowed_syscalls(self) -> list[str]: """ return [] + def attach_stats_streamer(self, streamer: DockerStatsStreamer) -> None: + """ + Hook point for Docker-backed plugins that need the agent-owned + :class:`DockerStatsStreamer` to read per-container stats from a + single shared stream. The base implementation is a no-op so + plugins that do not consume the streamer (e.g. 
K8s, Dummy, + third-party accelerators) can inherit the default safely. + """ + pass + type ComputersMap = Mapping[DeviceName, ComputerContext] type SlotsMap = Mapping[SlotName, Decimal] diff --git a/tests/unit/agent/test_docker_intrinsic.py b/tests/unit/agent/test_docker_intrinsic.py index 10d93080a44..31019669399 100644 --- a/tests/unit/agent/test_docker_intrinsic.py +++ b/tests/unit/agent/test_docker_intrinsic.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, cast +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import aiohttp @@ -14,6 +14,7 @@ from aiodocker.exceptions import DockerError from ai.backend.agent.agent import AbstractAgent +from ai.backend.agent.docker.agent import DockerAgent from ai.backend.agent.docker.intrinsic import ( ContainerNetStat, CPUPlugin, @@ -21,9 +22,7 @@ MemoryPlugin, read_proc_net_dev, ) -from ai.backend.agent.resources import ComputerContext from ai.backend.agent.stats import StatModes -from ai.backend.common.types import ContainerId, DeviceName class BaseDockerIntrinsicTest: @@ -878,71 +877,170 @@ def fake_container_cls(docker: Any, id: str) -> _FakeDockerContainer: assert streamer.get_latest("cid_gone") is None -class _AgentStub: - """Minimal stand-in exposing only the attributes the notify helpers read.""" +class TestSharedStatsStreamerWiring: + """Verify the agent owns a single streamer that both intrinsic plugins share.""" - def __init__(self, computers: dict[DeviceName, ComputerContext]) -> None: - self.computers = computers + async def test_agent_owns_single_statsstreamer_shared_with_plugins(self) -> None: + """After ``attach_stats_streamer`` is called on each intrinsic + plugin with one streamer, both plugins must reference that SAME + instance (the agent passes its ``self._stats_streamer`` this way).""" + streamer = DockerStatsStreamer(AsyncMock()) + cpu_plugin = CPUPlugin.__new__(CPUPlugin) + mem_plugin = 
MemoryPlugin.__new__(MemoryPlugin) + cpu_plugin.attach_stats_streamer(streamer) + mem_plugin.attach_stats_streamer(streamer) -class TestAgentContainerLifecycleHooks: - """Tests that the agent's start/clean event handlers notify compute plugins.""" + assert cpu_plugin._stats_streamer is streamer + assert mem_plugin._stats_streamer is streamer + assert cpu_plugin._stats_streamer is mem_plugin._stats_streamer - @staticmethod - def _make_agent_stub( - computers: dict[str, ComputerContext], - ) -> AbstractAgent[Any, Any]: - """Build a stub exposing ``self.computers`` so the notify helpers can - be exercised without instantiating the full AbstractAgent (which is - abstract with many required overrides).""" - typed_computers: dict[DeviceName, ComputerContext] = { - DeviceName(k): v for k, v in computers.items() - } - return cast(AbstractAgent[Any, Any], _AgentStub(typed_computers)) - - async def test_notify_started_calls_plugin_hook(self) -> None: - """The agent's helper invokes notify_container_started on every plugin.""" - plugin = AsyncMock() - plugin.notify_container_started = AsyncMock() - plugin.notify_container_destroyed = AsyncMock() - agent = self._make_agent_stub({ - "cpu": ComputerContext(instance=plugin, devices=[], alloc_map=MagicMock()), - }) - - await AbstractAgent._notify_compute_plugins_container_started(agent, ContainerId("cid_000")) - plugin.notify_container_started.assert_awaited_once_with("cid_000") - plugin.notify_container_destroyed.assert_not_called() - - async def test_notify_destroyed_calls_plugin_hook(self) -> None: - """The agent's helper invokes notify_container_destroyed on every plugin.""" - plugin = AsyncMock() - plugin.notify_container_started = AsyncMock() - plugin.notify_container_destroyed = AsyncMock() - agent = self._make_agent_stub({ - "cpu": ComputerContext(instance=plugin, devices=[], alloc_map=MagicMock()), - }) - - await AbstractAgent._notify_compute_plugins_container_destroyed( - agent, ContainerId("cid_000") + async def 
test_agent_stats_streamer_closed_on_shutdown(self) -> None: + """``DockerStatsStreamer.close`` (awaited by ``DockerAgent.shutdown``) + must leave no reader tasks or cached stats behind. Exercises the + streamer directly; the agent's shutdown path is not run here.""" + streamer = DockerStatsStreamer(AsyncMock()) + streamer.start("cid_000") + assert "cid_000" in streamer._tasks + + await streamer.close() + + assert streamer._tasks == {} + assert streamer._latest == {} + + async def test_stats_streamer_available_before_scan_running_kernels( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Ordering invariant: ``DockerAgent.__ainit__`` must install the shared + :class:`DockerStatsStreamer` on ``self`` and attach it to the intrinsic + plugins BEFORE ``super().__ainit__()`` runs. + + ``AbstractAgent.__ainit__`` calls :meth:`scan_running_kernels` and + starts the container lifecycle handler; on warm restart either of + those can fire ``_on_container_started`` -> ``self._stats_streamer`` + before any code placed AFTER ``super().__ainit__()`` gets to execute. + If the streamer is not set by then, ``_on_container_started`` raises + ``AttributeError`` on a bare class annotation. + + This test monkeypatches :meth:`AbstractAgent.scan_running_kernels` + (the exact point where the blocker bites) and asserts the streamer + is already set when it runs. It also confirms the intrinsic plugins + received the same streamer instance before the super init is reached. + """ + observed_streamer: list[DockerStatsStreamer | None] = [] + observed_cpu_streamer: list[DockerStatsStreamer | None] = [] + observed_mem_streamer: list[DockerStatsStreamer | None] = [] + + class _StopAfterOrderingAssertion(BaseException): + """Sentinel used to short-circuit ``__ainit__`` once the ordering + invariant has been verified. 
A ``BaseException`` subclass ensures + it propagates past any ``except Exception:`` guards in the init + flow below the stubbed ``super().__ainit__()`` call.""" + + cpu_plugin = CPUPlugin.__new__(CPUPlugin) + mem_plugin = MemoryPlugin.__new__(MemoryPlugin) + + computer_ctx_cpu = MagicMock() + computer_ctx_cpu.instance = cpu_plugin + computer_ctx_mem = MagicMock() + computer_ctx_mem.instance = mem_plugin + + async def fake_scan_running_kernels(self: Any) -> None: + # Record the streamer state at the exact moment the race would + # bite on warm restart. + observed_streamer.append(getattr(self, "_stats_streamer", None)) + observed_cpu_streamer.append(getattr(cpu_plugin, "_stats_streamer", None)) + observed_mem_streamer.append(getattr(mem_plugin, "_stats_streamer", None)) + + async def fake_super_ainit(self: Any) -> None: + # Emulate the part of AbstractAgent.__ainit__ that matters for + # this ordering test: call scan_running_kernels, which is where + # the blocker actually fires on warm restart. Then bail out + # so the post-super section (which depends on ``self.id``, + # Redis, networks, etc.) is not exercised. + await self.scan_running_kernels() + raise _StopAfterOrderingAssertion() + + # Patch the exact targets of the race. + monkeypatch.setattr( + AbstractAgent, + "scan_running_kernels", + fake_scan_running_kernels, + raising=True, + ) + monkeypatch.setattr( + AbstractAgent, + "__ainit__", + fake_super_ainit, + raising=True, + ) + + # Stub the pre-super Docker interactions so __ainit__ does not + # require a live Docker daemon. Only the ordering is under test. 
+ mock_docker_client = AsyncMock() + mock_docker_client.version = AsyncMock( + return_value={"Version": "0", "ApiVersion": "0", "KernelVersion": "test"}, ) - plugin.notify_container_destroyed.assert_awaited_once_with("cid_000") - plugin.notify_container_started.assert_not_called() - - async def test_notify_tolerates_plugin_exceptions(self) -> None: - """If one plugin raises, others are still notified; the error is logged.""" - broken = AsyncMock() - broken.notify_container_started = AsyncMock(side_effect=RuntimeError("boom")) - healthy = AsyncMock() - healthy.notify_container_started = AsyncMock() - - agent = self._make_agent_stub({ - "broken": ComputerContext(instance=broken, devices=[], alloc_map=MagicMock()), - "healthy": ComputerContext(instance=healthy, devices=[], alloc_map=MagicMock()), - }) - - await AbstractAgent._notify_compute_plugins_container_started(agent, ContainerId("cid_000")) - broken.notify_container_started.assert_awaited_once_with("cid_000") - healthy.notify_container_started.assert_awaited_once_with("cid_000") + mock_docker_client.system = MagicMock() + mock_docker_client.system.info = AsyncMock(return_value={"CgroupDriver": "cgroupfs"}) + mock_docker_client.connector = aiohttp.UnixConnector(path="/var/run/docker.sock") + + def docker_factory(*args: Any, **kwargs: Any) -> AsyncMock: + return mock_docker_client + + monkeypatch.setattr( + "ai.backend.agent.docker.agent.Docker", + docker_factory, + ) + + # ``async with closing_async(Docker()) as docker`` is used in the + # pre-super section; wrap the mock in an async context manager. 
+ class _AsyncCMWrapper: + def __init__(self, obj: Any) -> None: + self._obj = obj + + async def __aenter__(self) -> Any: + return self._obj + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + monkeypatch.setattr( + "ai.backend.agent.docker.agent.closing_async", + lambda obj: _AsyncCMWrapper(obj), + ) + + # Construct a minimally-initialised DockerAgent without running + # the heavy synchronous __init__ (which needs real etcd, registries, + # etc.). Only the attributes touched before super().__ainit__() need + # to be populated. + agent = DockerAgent.__new__(DockerAgent) + agent.local_config = MagicMock() + agent.local_config.agent.docker_mode = "native" + agent._kernel_recovery_adapter = MagicMock() + agent._kernel_recovery_adapter.adapt_recovery_data = AsyncMock(return_value=None) + # Typed as ``Mapping[DeviceName, ComputerContext]`` in AbstractAgent; + # the test uses ``MagicMock`` stand-ins so the attach loop can iterate. + fake_computers: Any = { + "cpu": computer_ctx_cpu, + "mem": computer_ctx_mem, + } + agent.computers = fake_computers + + with pytest.raises(_StopAfterOrderingAssertion): + await agent.__ainit__() + + # The patched scan_running_kernels ran exactly once (via the stub + # super) and at that moment the streamer must already be set. + assert len(observed_streamer) == 1 + assert observed_streamer[0] is not None + assert isinstance(observed_streamer[0], DockerStatsStreamer) + # Both intrinsic plugins must already hold the same streamer instance. + assert observed_cpu_streamer[0] is observed_streamer[0] + assert observed_mem_streamer[0] is observed_streamer[0] + + await agent._stats_streamer.close() def aiohttp_client_connection_error(msg: str) -> Exception: