diff --git a/changes/11235.enhance.md b/changes/11235.enhance.md new file mode 100644 index 00000000000..41a741378ef --- /dev/null +++ b/changes/11235.enhance.md @@ -0,0 +1 @@ +Retry once on stale-connection errors from the shared aiodocker client after dockerd restarts, so user-visible operations do not spuriously fail on the first post-restart call. diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index b01c75b05a1..46c43ee3c7c 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -11,7 +11,15 @@ import signal import struct import sys -from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping, Sequence +from collections.abc import ( + AsyncGenerator, + Awaitable, + Callable, + Iterable, + Mapping, + MutableMapping, + Sequence, +) from dataclasses import dataclass from decimal import Decimal from functools import partial @@ -194,7 +202,11 @@ async def get_extra_volumes(docker: Docker, lang: str) -> list[VolumeInfo]: - avail_volumes = (await docker.volumes.list())["Volumes"] # type: ignore[no-untyped-call] + volumes_result = await _retry_on_stale_connection( + lambda: docker.volumes.list(), # type: ignore[no-untyped-call] + operation="list_volumes", + ) + avail_volumes = volumes_result["Volumes"] if not avail_volumes: return [] avail_volume_names = {v["Name"] for v in avail_volumes} @@ -280,6 +292,58 @@ def _DockerContainerError_reduce(self: DockerContainerError) -> tuple[type, tupl ) +# Aim: catch "the pooled socket was reset by dockerd restart / keepalive timeout". +# - ``ServerDisconnectedError`` is the canonical case (dockerd closed our pooled +# keepalive socket; aiohttp surfaces it on the next request over that socket). +# - ``ClientOSError`` covers the kernel-level path (EPIPE / ECONNRESET) that can +# appear before aiohttp's own server-side detection kicks in. 
+# Explicitly NOT caught (by design — it falls through): +# - ``ServerTimeoutError`` (subclass of ``ServerConnectionError``): a legitimate +# long-running ``images.pull`` / ``images.push`` blowing its ``timeout=`` is a +# response-too-slow signal, not a stale-pool symptom. Retrying silently would +# mask real slowness and emit a misleading "stale aiodocker connection" log. +# Note: ``ClientConnectorError`` (fresh connection refused — dockerd is actually +# down) and its subclass ``ClientSSLError`` (TLS / certificate issues) inherit +# from ``ClientOSError`` in aiohttp, so both ARE caught by the tuple below even +# though a retry will not help. The one-shot retry is cheap: if the failure +# persists, the second attempt fails too and the error keeps its original type. +_STALE_CONNECTION_ERRORS: Final[tuple[type[BaseException], ...]] = ( + aiohttp.ServerDisconnectedError, + aiohttp.ClientOSError, +) + + +async def _retry_on_stale_connection[T]( + coro_factory: Callable[[], Awaitable[T]], + *, + operation: str, +) -> T: + """Run ``coro_factory()``; on stale-connection errors, retry exactly once. + + The shared aiodocker client pools keepalive sockets inside its + ``aiohttp.ClientSession``. After ``systemctl restart docker``, the first + post-restart call can pick a stale socket and fail with + ``aiohttp.ServerDisconnectedError`` or an OS-level socket error + (``aiohttp.ClientOSError``, e.g. EPIPE / ECONNRESET); aiohttp reconnects + transparently on the next attempt, so a single retry is sufficient to + absorb the one-shot failure. + + For persistent connection failures (e.g., dockerd actually down), the + second attempt fails and the exception propagates normally. Response- + too-slow timeouts (``ServerTimeoutError``) are intentionally NOT retried + here — see the comment above ``_STALE_CONNECTION_ERRORS``. 
+ """ + try: + return await coro_factory() + except _STALE_CONNECTION_ERRORS as e: + log.warning( + "stale aiodocker connection on {}; retrying once: {!r}", + operation, + e, + ) + return await coro_factory() + + @dataclass class DockerPurgeImageReq: image: str @@ -1251,7 +1315,10 @@ async def _rollback_container_creation() -> None: docker = self.docker container: DockerContainer | None = None try: - container = await docker.containers.create(config=container_config, name=kernel_name) + container = await _retry_on_stale_connection( + lambda: docker.containers.create(config=container_config, name=kernel_name), + operation="create_kernel_container", + ) if container is None: raise ContainerCreationError( container_id="", @@ -1282,7 +1349,10 @@ async def _rollback_container_creation() -> None: raise try: - await container.start() + await _retry_on_stale_connection( + lambda: container.start(), + operation="start_kernel_container", + ) except asyncio.CancelledError as e: await _rollback_container_creation() raise ContainerCreationError( @@ -1320,8 +1390,14 @@ async def _rollback_container_creation() -> None: additional_network_names |= set(n) for name in additional_network_names: - network = await docker.networks.get(name) - await network.connect({"Container": container._id}) + network = await _retry_on_stale_connection( + lambda: docker.networks.get(name), + operation="get_network", + ) + await _retry_on_stale_connection( + lambda: network.connect({"Container": container._id}), + operation="connect_network", + ) kernel_obj.set_container_id(ContainerId(cid)) container_network_info: ContainerNetworkInfo | None = None @@ -1623,7 +1699,10 @@ def get_cgroup_version(self) -> str: @override async def extract_image_command(self, image: str) -> str | None: - result = await self.docker.images.get(image) + result = await _retry_on_stale_connection( + lambda: self.docker.images.get(image), + operation="get_image", + ) return cast(str | None, result["Config"].get("Cmd")) @override 
@@ -1633,7 +1712,11 @@ async def enumerate_containers( ) -> Sequence[tuple[KernelId, Container]]: result = [] fetch_tasks = [] - for container in await self.docker.containers.list(): + containers = await _retry_on_stale_connection( + lambda: self.docker.containers.list(), + operation="list_containers", + ) + for container in containers: async def _fetch_container_info(container: DockerContainer) -> None: kernel_id_str: str = "(unknown)" @@ -1647,7 +1730,10 @@ async def _fetch_container_info(container: DockerContainer) -> None: container["Config"]["Labels"].get(LabelName.OWNER_AGENT, "") ) if self.id == owner_id: - await container.show() + await _retry_on_stale_connection( + lambda: container.show(), + operation="show_container", + ) result.append( ( kernel_id, @@ -1701,12 +1787,31 @@ async def resolve_image_distro(self, image: ImageConfig) -> str: "Cmd": ["ldd", "--version"], } - container = await docker.containers.create(container_config) - await container.start() - await container.wait() # wait until container finishes to prevent race condition - container_log = await container.log(stdout=True, stderr=True, follow=False) - await container.stop() - await container.delete() + container = await _retry_on_stale_connection( + lambda: docker.containers.create(container_config), + operation="create_distro_probe_container", + ) + await _retry_on_stale_connection( + lambda: container.start(), + operation="start_distro_probe_container", + ) + # wait until container finishes to prevent race condition + await _retry_on_stale_connection( + lambda: container.wait(), + operation="wait_distro_probe_container", + ) + container_log = await _retry_on_stale_connection( + lambda: container.log(stdout=True, stderr=True, follow=False), + operation="log_distro_probe_container", + ) + await _retry_on_stale_connection( + lambda: container.stop(), + operation="stop_distro_probe_container", + ) + await _retry_on_stale_connection( + lambda: container.delete(), + 
operation="delete_distro_probe_container", + ) log.debug("response: {}", container_log) version_lines = container_log[0].splitlines() if m := LDD_GLIBC_REGEX.search(version_lines[0]): @@ -1730,7 +1835,10 @@ async def resolve_image_distro(self, image: ImageConfig) -> str: @override async def scan_images(self) -> ScanImagesResult: docker = self.docker - all_images = await docker.images.list() + all_images = await _retry_on_stale_connection( + lambda: docker.images.list(), + operation="list_images", + ) scanned_images: dict[ImageCanonical, InstalledImageInfo] = {} removed_images: dict[ImageCanonical, InstalledImageInfo] = {} for image in all_images: @@ -1751,7 +1859,10 @@ async def scan_images(self) -> ScanImagesResult: self.checked_invalid_images.add(repo_tag) continue - img_detail = await docker.images.inspect(repo_tag) + img_detail = await _retry_on_stale_connection( + lambda: docker.images.inspect(repo_tag), + operation="inspect_image", + ) labels = (img_detail.get("Config") or {}).get("Labels") if labels is None: continue @@ -1891,7 +2002,10 @@ async def push_image( kwargs: dict[str, Any] = {"auth": auth_config} if timeout_seconds != Sentinel.TOKEN: kwargs["timeout"] = timeout_seconds - result = await self.docker.images.push(image_ref.canonical, **kwargs) + result = await _retry_on_stale_connection( + lambda: self.docker.images.push(image_ref.canonical, **kwargs), + operation="push_image", + ) if not result: raise RuntimeError("Failed to push image: unexpected return value from aiodocker") @@ -1915,8 +2029,11 @@ async def pull_image( "auth": encoded_creds, } log.info("pulling image {} from registry", image_ref.canonical) - result = await self.docker.images.pull( - image_ref.canonical, auth=auth_config, timeout=timeout_seconds + result = await _retry_on_stale_connection( + lambda: self.docker.images.pull( + image_ref.canonical, auth=auth_config, timeout=timeout_seconds + ), + operation="pull_image", ) if not result: @@ -1926,8 +2043,11 @@ async def pull_image( 
async def _purge_image(self, request: DockerPurgeImageReq) -> PurgeImageResp: try: - await self.docker.images.delete( - request.image, force=request.force, noprune=request.noprune + await _retry_on_stale_connection( + lambda: self.docker.images.delete( + request.image, force=request.force, noprune=request.noprune + ), + operation="delete_image", ) return PurgeImageResp.success(image=request.image) except Exception as e: @@ -1960,7 +2080,10 @@ async def check_image( self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior ) -> bool: try: - image_info = await self.docker.images.inspect(image_ref.canonical) + image_info = await _retry_on_stale_connection( + lambda: self.docker.images.inspect(image_ref.canonical), + operation="inspect_image", + ) if auto_pull == AutoPullBehavior.DIGEST: if image_info["Id"] != image_id: return True @@ -2050,7 +2173,10 @@ async def destroy_kernel( container = self.docker.containers.container(container_id) # The default timeout of the docker stop API is 10 seconds # to kill if container does not self-terminate. 
- await container.stop() + await _retry_on_stale_connection( + lambda: container.stop(), + operation="stop_container", + ) except DockerError as e: if e.status == HTTPStatus.CONFLICT and "is not running" in e.message: # already dead @@ -2129,7 +2255,10 @@ async def log_iter() -> AsyncGenerator[bytes, None]: container = docker.containers.container(container_id) try: with timeout(90): - await container.delete(force=True, v=True) + await _retry_on_stale_connection( + lambda: container.delete(force=True, v=True), + operation="delete_container", + ) except DockerError as e: if ( e.status == HTTPStatus.CONFLICT and "already in progress" in e.message @@ -2165,24 +2294,36 @@ async def log_iter() -> AsyncGenerator[bytes, None]: async def create_local_network(self, network_name: str) -> None: docker = self.docker try: - await docker.networks.get(network_name) + await _retry_on_stale_connection( + lambda: docker.networks.get(network_name), + operation="get_network", + ) except DockerError as e: if e.status == HTTPStatus.NOT_FOUND: - await docker.networks.create({ - "Name": network_name, - "Driver": "bridge", - "Labels": { - "ai.backend.cluster-network": "1", - }, - }) + await _retry_on_stale_connection( + lambda: docker.networks.create({ + "Name": network_name, + "Driver": "bridge", + "Labels": { + "ai.backend.cluster-network": "1", + }, + }), + operation="create_network", + ) else: raise @override async def destroy_local_network(self, network_name: str) -> None: try: - network = await self.docker.networks.get(network_name) - await network.delete() + network = await _retry_on_stale_connection( + lambda: self.docker.networks.get(network_name), + operation="get_network", + ) + await _retry_on_stale_connection( + lambda: network.delete(), + operation="delete_network", + ) except DockerError as e: if e.status == HTTPStatus.NOT_FOUND: # skip silently if already removed/missing diff --git a/tests/component/agent/docker/test_agent.py b/tests/component/agent/docker/test_agent.py 
index 6098531b125..10270c3fc7f 100644 --- a/tests/component/agent/docker/test_agent.py +++ b/tests/component/agent/docker/test_agent.py @@ -8,16 +8,17 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock +import aiohttp import pytest from aiodocker.exceptions import DockerError from ai.backend.agent.agent import AgentClass -from ai.backend.agent.docker.agent import DockerAgent +from ai.backend.agent.docker.agent import DockerAgent, _retry_on_stale_connection from ai.backend.agent.kernel import KernelRegistry from ai.backend.common.arch import DEFAULT_IMAGE_ARCH from ai.backend.common.docker import ImageRef from ai.backend.common.exception import ImageNotAvailable -from ai.backend.common.types import AutoPullBehavior +from ai.backend.common.types import AutoPullBehavior, ImageConfig class DummyEtcd: @@ -248,3 +249,188 @@ class _SimulatedShutdownError(Exception): with pytest.raises(_SimulatedShutdownError): await unmanaged_agent.shutdown(signal.SIGTERM) assert unmanaged_agent.docker.session.closed is True + + +class TestRetryOnStaleConnection: + """Unit tests for the stale-connection retry helper.""" + + async def test_retry_on_stale_connection_retries_once_then_succeeds(self) -> None: + calls = 0 + sentinel = object() + + async def factory() -> Any: + nonlocal calls + calls += 1 + if calls == 1: + # ServerDisconnectedError is the canonical stale-pool signal + # and is a concrete subclass of ClientConnectionError. 
+ raise aiohttp.ServerDisconnectedError() + return sentinel + + result = await _retry_on_stale_connection(factory, operation="test_op") + assert result is sentinel + assert calls == 2 + + async def test_retry_on_stale_connection_retries_once_on_server_disconnected(self) -> None: + calls = 0 + sentinel = object() + + async def factory() -> Any: + nonlocal calls + calls += 1 + if calls == 1: + raise aiohttp.ServerDisconnectedError() + return sentinel + + result = await _retry_on_stale_connection(factory, operation="test_op") + assert result is sentinel + assert calls == 2 + + async def test_retry_on_stale_connection_retries_once_on_client_os_error(self) -> None: + calls = 0 + sentinel = object() + + async def factory() -> Any: + nonlocal calls + calls += 1 + if calls == 1: + # Kernel-level reset (ECONNRESET) that can appear before + # aiohttp's own server-side detection kicks in. + raise aiohttp.ClientOSError("connection reset by peer") + return sentinel + + result = await _retry_on_stale_connection(factory, operation="test_op") + assert result is sentinel + assert calls == 2 + + async def test_retry_on_stale_connection_does_not_retry_other_errors(self) -> None: + calls = 0 + docker_error = DockerError( + status=HTTPStatus.CONFLICT, + data={"message": "simulated conflict"}, + ) + + async def factory() -> Any: + nonlocal calls + calls += 1 + raise docker_error + + with pytest.raises(DockerError) as exc_info: + await _retry_on_stale_connection(factory, operation="test_op") + assert exc_info.value is docker_error + assert calls == 1 + + async def test_retry_on_stale_connection_propagates_persistent_failure(self) -> None: + calls = 0 + + async def factory() -> Any: + nonlocal calls + calls += 1 + raise aiohttp.ServerDisconnectedError() + + with pytest.raises(aiohttp.ServerDisconnectedError): + await _retry_on_stale_connection(factory, operation="test_op") + # First attempt + one retry = 2 invocations total. 
+ assert calls == 2 + + async def test_server_timeout_is_not_retried(self) -> None: + """``ServerTimeoutError`` must propagate on the first attempt. + + Regression guard for the narrowed catch tuple: a long-running + ``images.push`` / ``images.pull`` that blows its ``timeout=`` is a + response-too-slow signal, not a stale-pool symptom. It must NOT be + silently retried (which would emit a misleading "stale aiodocker + connection" warning and mask real slowness). + """ + calls = 0 + + async def factory() -> Any: + nonlocal calls + calls += 1 + raise aiohttp.ServerTimeoutError("simulated response timeout") + + with pytest.raises(aiohttp.ServerTimeoutError): + await _retry_on_stale_connection(factory, operation="test_op") + assert calls == 1 + + +async def test_check_image_retries_on_stale_connection(agent: DockerAgent, mocker: Any) -> None: + """``check_image`` should absorb a one-shot stale-socket error.""" + behavior = AutoPullBehavior.DIGEST + inspect_mock = AsyncMock( + side_effect=[ + aiohttp.ServerDisconnectedError(), + digest_matching_image_info, + ], + ) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + assert inspect_mock.await_count == 2 + + +async def test_container_create_retries_on_stale_connection( + agent: DockerAgent, mocker: Any +) -> None: + """``resolve_image_distro`` (a wrapped ``containers.create`` call-site) must + absorb a one-shot stale-socket error on the create call. + + This exercises a real wrapped method — not the helper in isolation — so the + narrowed retry tuple is validated end-to-end on the ``containers.create`` + path (the smallest wrapped call-site of ``create`` on ``DockerAgent``). + """ + # Mock valkey_stat_client so the cache miss path is taken and the distro + # write at the end is a no-op. ``close`` is also awaitable so shutdown can + # run cleanly in the ``agent`` fixture teardown. 
+ valkey_client = MagicMock() + valkey_client.get_image_distro = AsyncMock(return_value=None) + valkey_client.set_image_distro = AsyncMock(return_value=None) + valkey_client.close = AsyncMock(return_value=None) + valkey_client.set_agent_container_count = AsyncMock(return_value=None) + mocker.patch.object(agent, "valkey_stat_client", new=valkey_client) + + # The probe container mock: start/wait/stop/delete are no-ops, log returns + # a musl-identifying line so resolve_image_distro short-circuits on alpine. + probe_container = MagicMock() + probe_container.start = AsyncMock(return_value=None) + probe_container.wait = AsyncMock(return_value=None) + probe_container.log = AsyncMock(return_value=["musl libc (x86_64)"]) + probe_container.stop = AsyncMock(return_value=None) + probe_container.delete = AsyncMock(return_value=None) + + # First create attempt hits a stale pooled socket; second attempt succeeds. + create_mock = AsyncMock( + side_effect=[ + aiohttp.ServerDisconnectedError(), + probe_container, + ], + ) + mocker.patch.object(agent.docker.containers, "create", new=create_mock) + + image_config: ImageConfig = { + "canonical": "lablup/lua:5.3-alpine3.8", + "project": "lablup", + "architecture": DEFAULT_IMAGE_ARCH, + "digest": "sha256:b000000000000000000000000000000000000000000000000000000000000001", + "repo_digest": None, + "registry": { + "name": "index.docker.io", + "url": "https://index.docker.io", + "username": None, + "password": None, + }, + "labels": {}, + "is_local": False, + "auto_pull": AutoPullBehavior.DIGEST, + } + distro = await agent.resolve_image_distro(image_config) + + assert distro == "alpine3.8" + assert create_mock.await_count == 2 + # Downstream container lifecycle methods must have been invoked exactly + # once against the (successfully created) probe container. 
+ probe_container.start.assert_awaited_once() + probe_container.wait.assert_awaited_once() + probe_container.log.assert_awaited_once() + probe_container.stop.assert_awaited_once() + probe_container.delete.assert_awaited_once()