Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11235.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Retry once on stale-connection errors from the shared aiodocker client after dockerd restarts, so user-visible operations do not spuriously fail on the first post-restart call.
211 changes: 176 additions & 35 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import signal
import struct
import sys
from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping, Sequence
from collections.abc import (
AsyncGenerator,
Awaitable,
Callable,
Iterable,
Mapping,
MutableMapping,
Sequence,
)
from dataclasses import dataclass
from decimal import Decimal
from functools import partial
Expand Down Expand Up @@ -194,7 +202,11 @@


async def get_extra_volumes(docker: Docker, lang: str) -> list[VolumeInfo]:
avail_volumes = (await docker.volumes.list())["Volumes"] # type: ignore[no-untyped-call]
volumes_result = await _retry_on_stale_connection(
lambda: docker.volumes.list(), # type: ignore[no-untyped-call]
operation="list_volumes",
)
avail_volumes = volumes_result["Volumes"]
if not avail_volumes:
return []
avail_volume_names = {v["Name"] for v in avail_volumes}
Expand Down Expand Up @@ -280,6 +292,58 @@ def _DockerContainerError_reduce(self: DockerContainerError) -> tuple[type, tupl
)


# Aim: catch "the pooled socket was reset by dockerd restart / keepalive timeout".
# - ``ServerDisconnectedError`` is the canonical case (dockerd closed our pooled
# keepalive socket; aiohttp surfaces it on the next request over that socket).
# - ``ClientOSError`` covers the kernel-level path (EPIPE / ECONNRESET) that can
# appear before aiohttp's own server-side detection kicks in.
# Explicitly NOT caught (by design — they fall through):
# - ``ServerTimeoutError`` (subclass of ``ServerConnectionError``): a legitimate
# long-running ``images.pull`` / ``images.push`` blowing its ``timeout=`` is a
# response-too-slow signal, not a stale-pool symptom. Retrying silently would
# mask real slowness and emit a misleading "stale aiodocker connection" log.
# - ``ClientSSLError`` (TLS / certificate issues): a retry will not help.
# Note: ``ClientConnectorError`` (fresh connection refused — dockerd is actually
# down) inherits from ``ClientOSError`` in aiohttp, so it IS technically caught
# by the tuple below. The one-shot retry is cheap and, if dockerd is truly down,
# the second attempt still fails and the error propagates with its original type.
_STALE_CONNECTION_ERRORS: Final[tuple[type[BaseException], ...]] = (
aiohttp.ServerDisconnectedError,
aiohttp.ClientOSError,
)


async def _retry_on_stale_connection[T](
coro_factory: Callable[[], Awaitable[T]],
*,
operation: str,
) -> T:
"""Run ``coro_factory()``; on stale-connection errors, retry exactly once.

The shared aiodocker client pools keepalive sockets inside its
``aiohttp.ClientSession``. After ``systemctl restart docker``, the first
post-restart call can pick a stale socket and fail with
``aiohttp.ServerDisconnectedError`` or an OS-level socket error
(``aiohttp.ClientOSError``, e.g. EPIPE / ECONNRESET); aiohttp reconnects
transparently on the next attempt, so a single retry is sufficient to
absorb the one-shot failure.

For persistent connection failures (e.g., dockerd actually down), the
second attempt fails and the exception propagates normally. Response-
too-slow timeouts (``ServerTimeoutError``) are intentionally NOT retried
here — see the comment above ``_STALE_CONNECTION_ERRORS``.
"""
try:
return await coro_factory()
except _STALE_CONNECTION_ERRORS as e:
log.warning(
"stale aiodocker connection on {}; retrying once: {!r}",
operation,
e,
)
return await coro_factory()


@dataclass
class DockerPurgeImageReq:
image: str
Expand Down Expand Up @@ -1251,7 +1315,10 @@ async def _rollback_container_creation() -> None:
docker = self.docker
container: DockerContainer | None = None
try:
container = await docker.containers.create(config=container_config, name=kernel_name)
container = await _retry_on_stale_connection(
lambda: docker.containers.create(config=container_config, name=kernel_name),
operation="create_kernel_container",
)
if container is None:
raise ContainerCreationError(
container_id="",
Expand Down Expand Up @@ -1282,7 +1349,10 @@ async def _rollback_container_creation() -> None:
raise

try:
await container.start()
await _retry_on_stale_connection(
lambda: container.start(),
operation="start_kernel_container",
)
except asyncio.CancelledError as e:
await _rollback_container_creation()
raise ContainerCreationError(
Expand Down Expand Up @@ -1320,8 +1390,14 @@ async def _rollback_container_creation() -> None:
additional_network_names |= set(n)

for name in additional_network_names:
network = await docker.networks.get(name)
await network.connect({"Container": container._id})
network = await _retry_on_stale_connection(
lambda: docker.networks.get(name),
operation="get_network",
)
await _retry_on_stale_connection(
lambda: network.connect({"Container": container._id}),
operation="connect_network",
)

kernel_obj.set_container_id(ContainerId(cid))
container_network_info: ContainerNetworkInfo | None = None
Expand Down Expand Up @@ -1623,7 +1699,10 @@ def get_cgroup_version(self) -> str:

@override
async def extract_image_command(self, image: str) -> str | None:
result = await self.docker.images.get(image)
result = await _retry_on_stale_connection(
lambda: self.docker.images.get(image),
operation="get_image",
)
return cast(str | None, result["Config"].get("Cmd"))

@override
Expand All @@ -1633,7 +1712,11 @@ async def enumerate_containers(
) -> Sequence[tuple[KernelId, Container]]:
result = []
fetch_tasks = []
for container in await self.docker.containers.list():
containers = await _retry_on_stale_connection(
lambda: self.docker.containers.list(),
operation="list_containers",
)
for container in containers:

async def _fetch_container_info(container: DockerContainer) -> None:
kernel_id_str: str = "(unknown)"
Expand All @@ -1647,7 +1730,10 @@ async def _fetch_container_info(container: DockerContainer) -> None:
container["Config"]["Labels"].get(LabelName.OWNER_AGENT, "")
)
if self.id == owner_id:
await container.show()
await _retry_on_stale_connection(
lambda: container.show(),
operation="show_container",
)
result.append(
(
kernel_id,
Expand Down Expand Up @@ -1701,12 +1787,31 @@ async def resolve_image_distro(self, image: ImageConfig) -> str:
"Cmd": ["ldd", "--version"],
}

container = await docker.containers.create(container_config)
await container.start()
await container.wait() # wait until container finishes to prevent race condition
container_log = await container.log(stdout=True, stderr=True, follow=False)
await container.stop()
await container.delete()
container = await _retry_on_stale_connection(
lambda: docker.containers.create(container_config),
operation="create_distro_probe_container",
)
await _retry_on_stale_connection(
lambda: container.start(),
operation="start_distro_probe_container",
)
# wait until container finishes to prevent race condition
await _retry_on_stale_connection(
lambda: container.wait(),
operation="wait_distro_probe_container",
)
container_log = await _retry_on_stale_connection(
lambda: container.log(stdout=True, stderr=True, follow=False),
operation="log_distro_probe_container",
)
await _retry_on_stale_connection(
lambda: container.stop(),
operation="stop_distro_probe_container",
)
await _retry_on_stale_connection(
lambda: container.delete(),
operation="delete_distro_probe_container",
)
log.debug("response: {}", container_log)
version_lines = container_log[0].splitlines()
if m := LDD_GLIBC_REGEX.search(version_lines[0]):
Expand All @@ -1730,7 +1835,10 @@ async def resolve_image_distro(self, image: ImageConfig) -> str:
@override
async def scan_images(self) -> ScanImagesResult:
docker = self.docker
all_images = await docker.images.list()
all_images = await _retry_on_stale_connection(
lambda: docker.images.list(),
operation="list_images",
)
scanned_images: dict[ImageCanonical, InstalledImageInfo] = {}
removed_images: dict[ImageCanonical, InstalledImageInfo] = {}
for image in all_images:
Expand All @@ -1751,7 +1859,10 @@ async def scan_images(self) -> ScanImagesResult:
self.checked_invalid_images.add(repo_tag)
continue

img_detail = await docker.images.inspect(repo_tag)
img_detail = await _retry_on_stale_connection(
lambda: docker.images.inspect(repo_tag),
operation="inspect_image",
)
labels = (img_detail.get("Config") or {}).get("Labels")
if labels is None:
continue
Expand Down Expand Up @@ -1891,7 +2002,10 @@ async def push_image(
kwargs: dict[str, Any] = {"auth": auth_config}
if timeout_seconds != Sentinel.TOKEN:
kwargs["timeout"] = timeout_seconds
result = await self.docker.images.push(image_ref.canonical, **kwargs)
result = await _retry_on_stale_connection(
lambda: self.docker.images.push(image_ref.canonical, **kwargs),
operation="push_image",
)

if not result:
raise RuntimeError("Failed to push image: unexpected return value from aiodocker")
Expand All @@ -1915,8 +2029,11 @@ async def pull_image(
"auth": encoded_creds,
}
log.info("pulling image {} from registry", image_ref.canonical)
result = await self.docker.images.pull(
image_ref.canonical, auth=auth_config, timeout=timeout_seconds
result = await _retry_on_stale_connection(
lambda: self.docker.images.pull(
image_ref.canonical, auth=auth_config, timeout=timeout_seconds
),
operation="pull_image",
)

if not result:
Expand All @@ -1926,8 +2043,11 @@ async def pull_image(

async def _purge_image(self, request: DockerPurgeImageReq) -> PurgeImageResp:
try:
await self.docker.images.delete(
request.image, force=request.force, noprune=request.noprune
await _retry_on_stale_connection(
lambda: self.docker.images.delete(
request.image, force=request.force, noprune=request.noprune
),
operation="delete_image",
)
return PurgeImageResp.success(image=request.image)
except Exception as e:
Expand Down Expand Up @@ -1960,7 +2080,10 @@ async def check_image(
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
) -> bool:
try:
image_info = await self.docker.images.inspect(image_ref.canonical)
image_info = await _retry_on_stale_connection(
lambda: self.docker.images.inspect(image_ref.canonical),
operation="inspect_image",
)
if auto_pull == AutoPullBehavior.DIGEST:
if image_info["Id"] != image_id:
return True
Expand Down Expand Up @@ -2050,7 +2173,10 @@ async def destroy_kernel(
container = self.docker.containers.container(container_id)
# The default timeout of the docker stop API is 10 seconds
# to kill if container does not self-terminate.
await container.stop()
await _retry_on_stale_connection(
lambda: container.stop(),
operation="stop_container",
)
except DockerError as e:
if e.status == HTTPStatus.CONFLICT and "is not running" in e.message:
# already dead
Expand Down Expand Up @@ -2129,7 +2255,10 @@ async def log_iter() -> AsyncGenerator[bytes, None]:
container = docker.containers.container(container_id)
try:
with timeout(90):
await container.delete(force=True, v=True)
await _retry_on_stale_connection(
lambda: container.delete(force=True, v=True),
operation="delete_container",
)
except DockerError as e:
if (
e.status == HTTPStatus.CONFLICT and "already in progress" in e.message
Expand Down Expand Up @@ -2165,24 +2294,36 @@ async def log_iter() -> AsyncGenerator[bytes, None]:
async def create_local_network(self, network_name: str) -> None:
docker = self.docker
try:
await docker.networks.get(network_name)
await _retry_on_stale_connection(
lambda: docker.networks.get(network_name),
operation="get_network",
)
except DockerError as e:
if e.status == HTTPStatus.NOT_FOUND:
await docker.networks.create({
"Name": network_name,
"Driver": "bridge",
"Labels": {
"ai.backend.cluster-network": "1",
},
})
await _retry_on_stale_connection(
lambda: docker.networks.create({
"Name": network_name,
"Driver": "bridge",
"Labels": {
"ai.backend.cluster-network": "1",
},
}),
operation="create_network",
)
else:
raise

@override
async def destroy_local_network(self, network_name: str) -> None:
try:
network = await self.docker.networks.get(network_name)
await network.delete()
network = await _retry_on_stale_connection(
lambda: self.docker.networks.get(network_name),
operation="get_network",
)
await _retry_on_stale_connection(
lambda: network.delete(),
operation="delete_network",
)
except DockerError as e:
if e.status == HTTPStatus.NOT_FOUND:
# skip silently if already removed/missing
Expand Down
Loading
Loading