diff --git a/changes/11226.enhance.md b/changes/11226.enhance.md new file mode 100644 index 00000000000..d17e4b96ca0 --- /dev/null +++ b/changes/11226.enhance.md @@ -0,0 +1 @@ +Reuse a long-lived aiodocker client in the agent instead of opening a fresh connection per container operation. diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index ab6fab49675..b01c75b05a1 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -300,6 +300,7 @@ class DockerKernelCreationContext(AbstractKernelCreationContext[DockerKernel]): resource_lock: asyncio.Lock cluster_ssh_port_mapping: ClusterSSHPortMapping | None gwbridge_subnet: str | None + docker: Docker network_plugin_ctx: NetworkPluginContext @@ -316,6 +317,7 @@ def __init__( agent_sockpath: Path, resource_lock: asyncio.Lock, network_plugin_ctx: NetworkPluginContext, + docker: Docker, restarting: bool = False, cluster_ssh_port_mapping: ClusterSSHPortMapping | None = None, gwbridge_subnet: str | None = None, @@ -351,6 +353,7 @@ def __init__( self.gwbridge_subnet = gwbridge_subnet self.network_plugin_ctx = network_plugin_ctx + self.docker = docker def _kernel_resource_spec_read(self, filename: Path | str) -> KernelResourceSpec: filepath = Path(filename) @@ -589,8 +592,7 @@ async def get_intrinsic_mounts(self) -> Sequence[Mount]: ) # extra mounts - async with closing_async(Docker()) as docker: - extra_mount_list = await get_extra_volumes(docker, self.image_ref.short) + extra_mount_list = await get_extra_volumes(self.docker, self.image_ref.short) for v in extra_mount_list: permission = MountPermission.READ_ONLY if v.mode == "ro" else MountPermission.READ_WRITE mounts.append( @@ -838,11 +840,10 @@ async def apply_accelerator_allocation( computer: AbstractComputePlugin, device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], ) -> None: - async with closing_async(Docker()) as docker: - update_nested_dict( - self.computer_docker_args, - await computer.generate_docker_args(docker, device_alloc), - ) + update_nested_dict( + self.computer_docker_args, + await computer.generate_docker_args(self.docker, device_alloc), + ) @override async def generate_accelerator_mounts( @@ -1247,155 +1248,151 @@ async def _rollback_container_creation() -> None: self.computers[dev_name].alloc_map.free(device_alloc) # We are all set! Create and start the container. - async with closing_async(Docker()) as docker: - container: DockerContainer | None = None - try: - container = await docker.containers.create( - config=container_config, name=kernel_name + docker = self.docker + container: DockerContainer | None = None + try: + container = await docker.containers.create(config=container_config, name=kernel_name) + if container is None: + raise ContainerCreationError( + container_id="", + message="Docker API returned None when creating container", ) - if container is None: - raise ContainerCreationError( - container_id="", - message="Docker API returned None when creating container", - ) - cid = cast(str, container._id) - async with AsyncFileWriter( - target_filename=self.config_dir / "resource.txt", - access_mode="a", - ) as writer: - await writer.write(f"CID={cid}\n") - - except asyncio.CancelledError as e: - if container is not None: - raise ContainerCreationError( - container_id=ContainerId(container.id), - message="Container creation was cancelled", - ) from e - raise - except Exception as e: - # Oops, we have to restore the allocated resources! - await _rollback_container_creation() - if container is not None: - raise ContainerCreationError( - container_id=ContainerId(container.id), - message=f"Unexpected error during container creation: {e!r}", - ) from e - raise - - try: - await container.start() - except asyncio.CancelledError as e: - await _rollback_container_creation() + cid = cast(str, container._id) + async with AsyncFileWriter( + target_filename=self.config_dir / "resource.txt", + access_mode="a", + ) as writer: + await writer.write(f"CID={cid}\n") + + except asyncio.CancelledError as e: + if container is not None: raise ContainerCreationError( - container_id=cid, - message="Container start was cancelled", + container_id=ContainerId(container.id), + message="Container creation was cancelled", ) from e - except Exception as e: + raise + except Exception as e: + # Oops, we have to restore the allocated resources! + await _rollback_container_creation() + if container is not None: + raise ContainerCreationError( + container_id=ContainerId(container.id), + message=f"Unexpected error during container creation: {e!r}", + ) from e + raise + + try: + await container.start() + except asyncio.CancelledError as e: + await _rollback_container_creation() + raise ContainerCreationError( + container_id=cid, + message="Container start was cancelled", + ) from e + except Exception as e: + await _rollback_container_creation() + raise ContainerCreationError( + container_id=cid, + message=f"Unexpected error during container start: {e!r}", + ) from e + + if self.internal_data.get("sudo_session_enabled", False): + exec = await container.exec( + [ + # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user + "sh", + "-c", + 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', + ], + user="root", + ) + shell_response = await exec.start(detach=True) + if shell_response: await _rollback_container_creation() raise ContainerCreationError( container_id=cid, - message=f"Unexpected error during container start: {e!r}", - ) from e + message=f"sudoers provision failed: {shell_response.decode()}", + ) + + additional_network_names: set[str] = set() + for dev_name, device_alloc in resource_spec.allocations.items(): + n = await self.computers[dev_name].instance.get_docker_networks(device_alloc) + additional_network_names |= set(n) - if self.internal_data.get("sudo_session_enabled", False): - exec = await container.exec( + for name in additional_network_names: + network = await docker.networks.get(name) + await network.connect({"Container": container._id}) + + kernel_obj.set_container_id(ContainerId(cid)) + container_network_info: ContainerNetworkInfo | None = None + if (mode := cluster_info["network_config"].get("mode")) and mode != "bridge": + try: + plugin = self.network_plugin_ctx.plugins[mode] + except KeyError as e: + raise RuntimeError(f"Network plugin {mode} not loaded!") from e + if ContainerNetworkCapability.GLOBAL in (await plugin.get_capabilities()): + container_network_info = await plugin.expose_ports( + kernel_obj, + str(container_bind_host), [ - # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user - "sh", - "-c", - 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', + (host_port, container_port) + for host_port, container_port in zip(host_ports, exposed_ports, strict=True) ], - user="root", ) - shell_response = await exec.start(detach=True) - if shell_response: - await _rollback_container_creation() + + created_host_ports: tuple[int, ...] + repl_in_port = 0 + repl_out_port = 0 + if container_network_info: + kernel_host = container_network_info.container_host + port_map = container_network_info.services + if "replin" not in port_map or "replout" not in port_map: + raise InvalidArgumentError("replin and replout ports are required in port_map") + + repl_in_port = port_map["replin"][2000] + repl_out_port = port_map["replout"][2001] + stdin_port = 0 # left for legacy + stdout_port = 0 # left for legacy + + for sport in service_ports: + created_host_ports = tuple( + port_map[sport["name"]][cport] for cport in sport["container_ports"] + ) + sport["host_ports"] = created_host_ports + else: + kernel_host = advertised_kernel_host or container_bind_host + ctnr_host_port_map: MutableMapping[int, int] = {} + stdin_port = 0 + stdout_port = 0 + for idx, port in enumerate(exposed_ports): + ports: list[PortInfo] | None = await container.port(port) + if not ports: raise ContainerCreationError( container_id=cid, - message=f"sudoers provision failed: {shell_response.decode()}", - ) - - additional_network_names: set[str] = set() - for dev_name, device_alloc in resource_spec.allocations.items(): - n = await self.computers[dev_name].instance.get_docker_networks(device_alloc) - additional_network_names |= set(n) - - for name in additional_network_names: - network = await docker.networks.get(name) - await network.connect({"Container": container._id}) - - kernel_obj.set_container_id(ContainerId(cid)) - container_network_info: ContainerNetworkInfo | None = None - if (mode := cluster_info["network_config"].get("mode")) and mode != "bridge": - try: - plugin = self.network_plugin_ctx.plugins[mode] - except KeyError as e: - raise RuntimeError(f"Network plugin {mode} not loaded!") from e - if ContainerNetworkCapability.GLOBAL in (await plugin.get_capabilities()): - container_network_info = await plugin.expose_ports( - kernel_obj, - str(container_bind_host), - [ - (host_port, container_port) - for host_port, container_port in zip( - host_ports, exposed_ports, strict=True - ) - ], - ) - - created_host_ports: tuple[int, ...] - repl_in_port = 0 - repl_out_port = 0 - if container_network_info: - kernel_host = container_network_info.container_host - port_map = container_network_info.services - if "replin" not in port_map or "replout" not in port_map: - raise InvalidArgumentError("replin and replout ports are required in port_map") - - repl_in_port = port_map["replin"][2000] - repl_out_port = port_map["replout"][2001] - stdin_port = 0 # left for legacy - stdout_port = 0 # left for legacy - - for sport in service_ports: - created_host_ports = tuple( - port_map[sport["name"]][cport] for cport in sport["container_ports"] + message=f"Container port {port} not found in port mapping", ) - sport["host_ports"] = created_host_ports - else: - kernel_host = advertised_kernel_host or container_bind_host - ctnr_host_port_map: MutableMapping[int, int] = {} - stdin_port = 0 - stdout_port = 0 - for idx, port in enumerate(exposed_ports): - ports: list[PortInfo] | None = await container.port(port) - if not ports: - raise ContainerCreationError( - container_id=cid, - message=f"Container port {port} not found in port mapping", - ) - host_port = int(ports[0]["HostPort"]) - if host_port != host_ports[idx]: - await _rollback_container_creation() - raise ContainerCreationError( - container_id=cid, - message=f"Port mapping mismatch. {host_port = }, {host_ports[idx] = }", - ) - if port == 2000: # intrinsic - repl_in_port = host_port - elif port == 2001: # intrinsic - repl_out_port = host_port - elif port == 2002: # legacy - stdin_port = host_port - elif port == 2003: # legacy - stdout_port = host_port - else: - ctnr_host_port_map[port] = host_port - for sport in service_ports: - created_host_ports = tuple( - ctnr_host_port_map[cport] for cport in sport["container_ports"] + host_port = int(ports[0]["HostPort"]) + if host_port != host_ports[idx]: + await _rollback_container_creation() + raise ContainerCreationError( + container_id=cid, + message=f"Port mapping mismatch. {host_port = }, {host_ports[idx] = }", ) - sport["host_ports"] = created_host_ports + if port == 2000: # intrinsic + repl_in_port = host_port + elif port == 2001: # intrinsic + repl_out_port = host_port + elif port == 2002: # legacy + stdin_port = host_port + elif port == 2003: # legacy + stdout_port = host_port + else: + ctnr_host_port_map[port] = host_port + for sport in service_ports: + created_host_ports = tuple( + ctnr_host_port_map[cport] for cport in sport["container_ports"] + ) + sport["host_ports"] = created_host_ports if repl_in_port == 0: raise InvalidArgumentError("repl_in_port should have been assigned") @@ -1415,6 +1412,7 @@ async def _rollback_container_creation() -> None: class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]): + docker: Docker docker_info: Mapping[str, Any] monitor_docker_task: asyncio.Task[Any] agent_sockpath: Path @@ -1482,7 +1480,9 @@ def __init__( ) async def __ainit__(self) -> None: - async with closing_async(Docker()) as docker: + self.docker = Docker() + try: + docker = self.docker docker_host = "" match docker.connector: case aiohttp.TCPConnector(): @@ -1516,75 +1516,78 @@ async def __ainit__(self) -> None: docker_info["CgroupVersion"], ) self.docker_info = docker_info - await self._kernel_recovery_adapter.adapt_recovery_data() - await super().__ainit__() - try: - async with Docker() as docker: + await self._kernel_recovery_adapter.adapt_recovery_data() + await super().__ainit__() + try: gwbridge = await docker.networks.get("docker_gwbridge") gwbridge_info = await gwbridge.show() self.gwbridge_subnet = gwbridge_info["IPAM"]["Config"][0]["Subnet"] - except (DockerError, KeyError, IndexError): - self.gwbridge_subnet = None - ipc_base_path = self.local_config.agent.ipc_base_path - (ipc_base_path / "container").mkdir(parents=True, exist_ok=True) - self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock" - # Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs - if sys.platform != "darwin": - socket_relay_name = f"backendai-socket-relay.{self.id}" - socket_relay_container = PersistentServiceContainer( - "backendai-socket-relay:latest", - { - "Cmd": [ - f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777", - f"TCP-CONNECT:127.0.0.1:{self.local_config.agent.agent_sock_port}", - ], - "HostConfig": { - "Mounts": [ - { - "Type": "bind", - "Source": str(ipc_base_path / "container"), - "Target": "/ipc", - }, + except (DockerError, KeyError, IndexError): + self.gwbridge_subnet = None + ipc_base_path = self.local_config.agent.ipc_base_path + (ipc_base_path / "container").mkdir(parents=True, exist_ok=True) + self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock" + # Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs + if sys.platform != "darwin": + socket_relay_name = f"backendai-socket-relay.{self.id}" + socket_relay_container = PersistentServiceContainer( + "backendai-socket-relay:latest", + { + "Cmd": [ + f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777", + f"TCP-CONNECT:127.0.0.1:{self.local_config.agent.agent_sock_port}", ], - "NetworkMode": "host", + "HostConfig": { + "Mounts": [ + { + "Type": "bind", + "Source": str(ipc_base_path / "container"), + "Target": "/ipc", + }, + ], + "NetworkMode": "host", + }, }, - }, - name=socket_relay_name, - ) - await socket_relay_container.ensure_running_latest() - self.agent_sock_task = asyncio.create_task(self.handle_agent_socket()) - self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events()) - self.docker_ptask_group = aiotools.PersistentTaskGroup() - - # For legacy accelerator plugins - self.docker = Docker() + name=socket_relay_name, + ) + await socket_relay_container.ensure_running_latest() + self.agent_sock_task = asyncio.create_task(self.handle_agent_socket()) + self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events()) + self.docker_ptask_group = aiotools.PersistentTaskGroup() - self.network_plugin_ctx = NetworkPluginContext( - self.etcd, self.local_config.model_dump(by_alias=True) - ) - await self.network_plugin_ctx.init( - context=self, - allowlist=self.local_config.agent.allow_network_plugins, - blocklist=self.local_config.agent.block_network_plugins, - ) + self.network_plugin_ctx = NetworkPluginContext( + self.etcd, self.local_config.model_dump(by_alias=True) + ) + await self.network_plugin_ctx.init( + context=self, + allowlist=self.local_config.agent.allow_network_plugins, + blocklist=self.local_config.agent.block_network_plugins, + ) + except Exception: + # Release the shared aiodocker client if boot fails after its construction + # so the underlying aiohttp.ClientSession does not leak. + await self.docker.close() + raise async def shutdown(self, stop_signal: signal.Signals) -> None: - # Stop handling agent sock. - if self.agent_sock_task is not None: - self.agent_sock_task.cancel() - await self.agent_sock_task - if self.docker_ptask_group is not None: - await self.docker_ptask_group.shutdown() - try: - await super().shutdown(stop_signal) - finally: - # Stop docker event monitoring. - if self.monitor_docker_task is not None: - self.monitor_docker_task.cancel() - await self.monitor_docker_task + # Stop handling agent sock. + if self.agent_sock_task is not None: + self.agent_sock_task.cancel() + await self.agent_sock_task + if self.docker_ptask_group is not None: + await self.docker_ptask_group.shutdown() - if self.docker: + try: + await super().shutdown(stop_signal) + finally: + # Stop docker event monitoring. + if self.monitor_docker_task is not None: + self.monitor_docker_task.cancel() + await self.monitor_docker_task + finally: + # Outer finally guarantees the shared aiodocker client is released + # even if inner shutdown steps raise. await self.docker.close() @override @@ -1620,9 +1623,8 @@ def get_cgroup_version(self) -> str: @override async def extract_image_command(self, image: str) -> str | None: - async with closing_async(Docker()) as docker: - result = await docker.images.get(image) - return cast(str | None, result["Config"].get("Cmd")) + result = await self.docker.images.get(image) + return cast(str | None, result["Config"].get("Cmd")) @override async def enumerate_containers( @@ -1631,45 +1633,44 @@ async def enumerate_containers( ) -> Sequence[tuple[KernelId, Container]]: result = [] fetch_tasks = [] - async with closing_async(Docker()) as docker: - for container in await docker.containers.list(): + for container in await self.docker.containers.list(): - async def _fetch_container_info(container: DockerContainer) -> None: - kernel_id_str: str = "(unknown)" - try: - kernel_id = await get_kernel_id_from_container(container) - if kernel_id is None: - return - kernel_id_str = str(kernel_id) - if container["State"]["Status"] in status_filter: - owner_id = AgentId( - container["Config"]["Labels"].get(LabelName.OWNER_AGENT, "") - ) - if self.id == owner_id: - await container.show() - result.append( - ( - kernel_id, - container_from_docker_container(container), - ), - ) - except DockerError as e: - if e.status == HTTPStatus.NOT_FOUND: - log.warning(e.message) - return - raise - except asyncio.CancelledError: - pass - except Exception: - log.exception( - "error while fetching container information (cid:{}, k:{})", - container._id, - kernel_id_str, + async def _fetch_container_info(container: DockerContainer) -> None: + kernel_id_str: str = "(unknown)" + try: + kernel_id = await get_kernel_id_from_container(container) + if kernel_id is None: + return + kernel_id_str = str(kernel_id) + if container["State"]["Status"] in status_filter: + owner_id = AgentId( + container["Config"]["Labels"].get(LabelName.OWNER_AGENT, "") ) + if self.id == owner_id: + await container.show() + result.append( + ( + kernel_id, + container_from_docker_container(container), + ), + ) + except DockerError as e: + if e.status == HTTPStatus.NOT_FOUND: + log.warning(e.message) + return + raise + except asyncio.CancelledError: + pass + except Exception: + log.exception( + "error while fetching container information (cid:{}, k:{})", + container._id, + kernel_id_str, + ) - fetch_tasks.append(_fetch_container_info(container)) + fetch_tasks.append(_fetch_container_info(container)) - await asyncio.gather(*fetch_tasks, return_exceptions=True) + await asyncio.gather(*fetch_tasks, return_exceptions=True) return result @override @@ -1679,101 +1680,101 @@ async def resolve_image_distro(self, image: ImageConfig) -> str: if distro: return distro - async with Docker() as docker: - image_id = image["digest"].partition(":")[-1] - # check if distro data is available on redis cache - cached_distro = await self.valkey_stat_client.get_image_distro(image_id) - if cached_distro: - return cached_distro - - container_config: dict[str, Any] = { - "Image": image["canonical"], - "Tty": True, - "Privileged": False, - "AttachStdin": False, - "AttachStdout": True, - "AttachStderr": True, - "HostConfig": { - "Init": True, - }, - "Entrypoint": [""], - "Cmd": ["ldd", "--version"], - } + docker = self.docker + image_id = image["digest"].partition(":")[-1] + # check if distro data is available on redis cache + cached_distro = await self.valkey_stat_client.get_image_distro(image_id) + if cached_distro: + return cached_distro - container = await docker.containers.create(container_config) - await container.start() - await container.wait() # wait until container finishes to prevent race condition - container_log = await container.log(stdout=True, stderr=True, follow=False) - await container.stop() - await container.delete() - log.debug("response: {}", container_log) - version_lines = container_log[0].splitlines() - if m := LDD_GLIBC_REGEX.search(version_lines[0]): - version = float(m.group(1)) - if version in known_glibc_distros: - distro = known_glibc_distros[version] - else: - for idx, known_version in enumerate(known_glibc_distros.keys()): - if version < known_version: - distro = list(known_glibc_distros.values())[idx - 1] - break - else: - distro = list(known_glibc_distros.values())[-1] - elif m := LDD_MUSL_REGEX.search(version_lines[0]): - distro = "alpine3.8" + container_config: dict[str, Any] = { + "Image": image["canonical"], + "Tty": True, + "Privileged": False, + "AttachStdin": False, + "AttachStdout": True, + "AttachStderr": True, + "HostConfig": { + "Init": True, + }, + "Entrypoint": [""], + "Cmd": ["ldd", "--version"], + } + + container = await docker.containers.create(container_config) + await container.start() + await container.wait() # wait until container finishes to prevent race condition + container_log = await container.log(stdout=True, stderr=True, follow=False) + await container.stop() + await container.delete() + log.debug("response: {}", container_log) + version_lines = container_log[0].splitlines() + if m := LDD_GLIBC_REGEX.search(version_lines[0]): + version = float(m.group(1)) + if version in known_glibc_distros: + distro = known_glibc_distros[version] else: - raise RuntimeError("Could not determine the C library variant.") - await self.valkey_stat_client.set_image_distro(image_id, distro) - return distro + for idx, known_version in enumerate(known_glibc_distros.keys()): + if version < known_version: + distro = list(known_glibc_distros.values())[idx - 1] + break + else: + distro = list(known_glibc_distros.values())[-1] + elif m := LDD_MUSL_REGEX.search(version_lines[0]): + distro = "alpine3.8" + else: + raise RuntimeError("Could not determine the C library variant.") + await self.valkey_stat_client.set_image_distro(image_id, distro) + return distro @override async def scan_images(self) -> ScanImagesResult: - async with closing_async(Docker()) as docker: - all_images = await docker.images.list() - scanned_images: dict[ImageCanonical, InstalledImageInfo] = {} - removed_images: dict[ImageCanonical, InstalledImageInfo] = {} - for image in all_images: - if image["RepoTags"] is None: + docker = self.docker + all_images = await docker.images.list() + scanned_images: dict[ImageCanonical, InstalledImageInfo] = {} + removed_images: dict[ImageCanonical, InstalledImageInfo] = {} + for image in all_images: + if image["RepoTags"] is None: + continue + for repo_tag in image["RepoTags"]: + if repo_tag.endswith(""): continue - for repo_tag in image["RepoTags"]: - if repo_tag.endswith(""): - continue - try: - ImageRef.parse_image_str(repo_tag, "*") - except (InvalidImageName, InvalidImageTag) as e: - if repo_tag not in self.checked_invalid_images: - log.warning( - "Image name {} does not conform to Backend.AI's image naming rule. This image will be ignored. Details: {}", - repo_tag, - e, - ) - self.checked_invalid_images.add(repo_tag) - continue - - img_detail = await docker.images.inspect(repo_tag) - labels = (img_detail.get("Config") or {}).get("Labels") - if labels is None: - continue - - kernelspec = int(labels.get(LabelName.KERNEL_SPEC, "1")) - if MIN_KERNELSPEC <= kernelspec <= MAX_KERNELSPEC: - scanned_images[ImageCanonical(repo_tag)] = ( - InstalledImageInfo.from_inspect_result( - canonical=ImageCanonical(repo_tag), - inspect_result=img_detail, - ) + try: + ImageRef.parse_image_str(repo_tag, "*") + except (InvalidImageName, InvalidImageTag) as e: + if repo_tag not in self.checked_invalid_images: + log.warning( + "Image name {} does not conform to Backend.AI's image naming rule. This image will be ignored. Details: {}", + repo_tag, + e, ) - for added_image in scanned_images.keys() - self.images.keys(): - log.debug("found kernel image: {0}", added_image) + self.checked_invalid_images.add(repo_tag) + continue - for removed_image in self.images.keys() - scanned_images.keys(): - log.debug("removed kernel image: {0}", removed_image) - removed_images[removed_image] = self.images[removed_image] + img_detail = await docker.images.inspect(repo_tag) + labels = (img_detail.get("Config") or {}).get("Labels") + if labels is None: + continue - return ScanImagesResult( - scanned_images=scanned_images, - removed_images=removed_images, - ) + kernelspec = int(labels.get(LabelName.KERNEL_SPEC, "1")) + if MIN_KERNELSPEC <= kernelspec <= MAX_KERNELSPEC: + scanned_images[ImageCanonical(repo_tag)] = ( + InstalledImageInfo.from_inspect_result( + canonical=ImageCanonical(repo_tag), + inspect_result=img_detail, + ) + ) + for added_image in scanned_images.keys() - self.images.keys(): + log.debug("found kernel image: {0}", added_image) + + for removed_image in self.images.keys() - scanned_images.keys(): + log.debug("removed kernel image: {0}", removed_image) + removed_images[removed_image] = self.images[removed_image] + + return ScanImagesResult( + scanned_images=scanned_images, + removed_images=removed_images, + ) async def handle_agent_socket(self) -> None: """ @@ -1887,16 +1888,15 @@ async def push_image( "auth": encoded_creds, } - async with closing_async(Docker()) as docker: - kwargs: dict[str, Any] = {"auth": auth_config} - if timeout_seconds != Sentinel.TOKEN: - kwargs["timeout"] = timeout_seconds - result = await docker.images.push(image_ref.canonical, **kwargs) + kwargs: dict[str, Any] = {"auth": auth_config} + if timeout_seconds != Sentinel.TOKEN: + kwargs["timeout"] = timeout_seconds + result = await self.docker.images.push(image_ref.canonical, **kwargs) - if not result: - raise RuntimeError("Failed to push image: unexpected return value from aiodocker") - if error := result[-1].get("error"): - raise RuntimeError(f"Failed to push image: {error}") + if not result: + raise RuntimeError("Failed to push image: unexpected return value from aiodocker") + if error := result[-1].get("error"): + raise RuntimeError(f"Failed to push image: {error}") @override async def pull_image( @@ -1915,19 +1915,20 @@ async def pull_image( "auth": encoded_creds, } log.info("pulling image {} from registry", image_ref.canonical) - async with closing_async(Docker()) as docker: - result = await docker.images.pull( - image_ref.canonical, auth=auth_config, timeout=timeout_seconds - ) + result = await self.docker.images.pull( + image_ref.canonical, auth=auth_config, timeout=timeout_seconds + ) - if not result: - raise RuntimeError("Failed to pull image: unexpected return value from aiodocker") - if error := result[-1].get("error"): - raise RuntimeError(f"Failed to pull image: {error}") + if not result: + raise RuntimeError("Failed to pull image: unexpected return value from aiodocker") + if error := result[-1].get("error"): + raise RuntimeError(f"Failed to pull image: {error}") - async def _purge_image(self, docker: Docker, request: DockerPurgeImageReq) -> PurgeImageResp: + async def _purge_image(self, request: DockerPurgeImageReq) -> PurgeImageResp: try: - await docker.images.delete(request.image, force=request.force, noprune=request.noprune) + await self.docker.images.delete( + request.image, force=request.force, noprune=request.noprune + ) return PurgeImageResp.success(image=request.image) except Exception as e: log.error(f'Failed to purge image "{request.image}": {e}') @@ -1935,11 +1936,10 @@ async def _purge_image(self, docker: Docker, request: DockerPurgeImageReq) -> Pu @override async def purge_images(self, request: PurgeImagesReq) -> PurgeImagesResp: - async with closing_async(Docker()) as docker, TaskGroup() as tg: + async with TaskGroup() as tg: tasks = [ tg.create_task( self._purge_image( - docker, DockerPurgeImageReq( image=image, force=request.force, noprune=request.noprune ), @@ -1960,11 +1960,10 @@ async def check_image( self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior ) -> bool: try: - async with closing_async(Docker()) as docker: - image_info = await docker.images.inspect(image_ref.canonical) - if auto_pull == AutoPullBehavior.DIGEST: - if image_info["Id"] != image_id: - return True + image_info = await self.docker.images.inspect(image_ref.canonical) + if auto_pull == AutoPullBehavior.DIGEST: + if image_info["Id"] != image_id: + return True log.info("found the local up-to-date image for {}", image_ref.canonical) except DockerError as e: if e.status == HTTPStatus.NOT_FOUND: @@ -1999,6 +1998,7 @@ async def init_kernel_context( self.agent_sockpath, self.resource_lock, self.network_plugin_ctx, + self.docker, restarting=restarting, cluster_ssh_port_mapping=cluster_ssh_port_mapping, gwbridge_subnet=self.gwbridge_subnet, @@ -2047,11 +2047,10 @@ async def destroy_kernel( if container_id is None: return try: - async with closing_async(Docker()) as docker: - container = docker.containers.container(container_id) - # The default timeout of the docker stop API is 10 seconds - # to kill if container does not self-terminate. - await container.stop() + container = self.docker.containers.container(container_id) + # The default timeout of the docker stop API is 10 seconds + # to kill if container does not self-terminate. + await container.stop() except DockerError as e: if e.status == HTTPStatus.CONFLICT and "is not running" in e.message: # already dead @@ -2075,122 +2074,121 @@ async def clean_kernel( restarting: bool, ) -> None: loop = current_loop() - async with closing_async(Docker()) as docker: - if container_id is not None: - container = docker.containers.container(container_id) - - async def log_iter() -> AsyncGenerator[bytes, None]: - it = container.log( - stdout=True, - stderr=True, - follow=True, - ) - async with aiotools.aclosing(it): # type: ignore[type-var] - async for line in it: - yield line.encode("utf-8") + docker = self.docker + if container_id is not None: + container = docker.containers.container(container_id) + + async def log_iter() -> AsyncGenerator[bytes, None]: + it = container.log( + stdout=True, + stderr=True, + follow=True, + ) + async with aiotools.aclosing(it): # type: ignore[type-var] + async for line in it: + yield line.encode("utf-8") - try: - with timeout(60): - await self.collect_logs(kernel_id, container_id, log_iter()) - except DockerError as e: - if e.status == HTTPStatus.NOT_FOUND: - log.warning( - "container is already cleaned or missing (k:{}, cid:{})", - kernel_id, - container_id, - ) - else: - raise - except TimeoutError: - log.warning( - "timeout for collecting container logs (k:{}, cid:{})", - kernel_id, - container_id, - ) - except Exception as e: + try: + with timeout(60): + await self.collect_logs(kernel_id, container_id, log_iter()) + except DockerError as e: + if e.status == HTTPStatus.NOT_FOUND: log.warning( - "error while collecting container logs (k:{}, cid:{})", + "container is already cleaned or missing (k:{}, cid:{})", kernel_id, container_id, - exc_info=e, ) + else: + raise + except TimeoutError: + log.warning( + "timeout for collecting container logs (k:{}, cid:{})", + kernel_id, + container_id, + ) + except Exception as e: + log.warning( + "error while collecting container logs (k:{}, cid:{})", + kernel_id, + container_id, + exc_info=e, + ) - kernel_obj = self.kernel_registry.get(kernel_id) - if kernel_obj is not None: - for domain_socket_proxy in kernel_obj.get("domain_socket_proxies", []): - if domain_socket_proxy.proxy_server.is_serving(): - domain_socket_proxy.proxy_server.close() - await domain_socket_proxy.proxy_server.wait_closed() - try: - domain_socket_proxy.host_proxy_path.unlink() - except OSError: - pass + kernel_obj = self.kernel_registry.get(kernel_id) + if kernel_obj is not None: + for domain_socket_proxy in kernel_obj.get("domain_socket_proxies", []): + if domain_socket_proxy.proxy_server.is_serving(): + domain_socket_proxy.proxy_server.close() + await domain_socket_proxy.proxy_server.wait_closed() + try: + domain_socket_proxy.host_proxy_path.unlink() + except OSError: + pass - if not self.local_config.debug.skip_container_deletion and container_id is not None: - container = docker.containers.container(container_id) - try: - with timeout(90): - await container.delete(force=True, v=True) - except DockerError as e: - if ( - e.status == HTTPStatus.CONFLICT and "already in progress" in e.message - ) or e.status == HTTPStatus.NOT_FOUND: - return - log.exception( - "unexpected docker error while deleting container (k:{}, c:{})", - kernel_id, - container_id, - ) - except TimeoutError: - log.warning("container deletion timeout (k:{}, c:{})", kernel_id, container_id) - - if not restarting: - await _clean_scratch( - loop, - self.local_config.container.scratch_type, - self.local_config.container.scratch_root, + if not self.local_config.debug.skip_container_deletion and container_id is not None: + container = docker.containers.container(container_id) + try: + with timeout(90): + await container.delete(force=True, v=True) + except DockerError as e: + if ( + e.status == HTTPStatus.CONFLICT and "already in progress" in e.message + ) or e.status == HTTPStatus.NOT_FOUND: + return + log.exception( + "unexpected docker error while deleting container (k:{}, c:{})", kernel_id, + container_id, ) - if kernel_obj: - kernel = cast(DockerKernel, kernel_obj) - if kernel.network_driver != "bridge": - try: - plugin = self.network_plugin_ctx.plugins[kernel.network_driver] - except KeyError as e: - raise RuntimeError( - f"Network plugin {kernel.network_driver} not loaded!" - ) from e - await plugin.leave_network(kernel) + except TimeoutError: + log.warning("container deletion timeout (k:{}, c:{})", kernel_id, container_id) + + if not restarting: + await _clean_scratch( + loop, + self.local_config.container.scratch_type, + self.local_config.container.scratch_root, + kernel_id, + ) + if kernel_obj: + kernel = cast(DockerKernel, kernel_obj) + if kernel.network_driver != "bridge": + try: + plugin = self.network_plugin_ctx.plugins[kernel.network_driver] + except KeyError as e: + raise RuntimeError( + f"Network plugin {kernel.network_driver} not loaded!" + ) from e + await plugin.leave_network(kernel) @override async def create_local_network(self, network_name: str) -> None: - async with closing_async(Docker()) as docker: - try: - await docker.networks.get(network_name) - except DockerError as e: - if e.status == HTTPStatus.NOT_FOUND: - await docker.networks.create({ - "Name": network_name, - "Driver": "bridge", - "Labels": { - "ai.backend.cluster-network": "1", - }, - }) - else: - raise + docker = self.docker + try: + await docker.networks.get(network_name) + except DockerError as e: + if e.status == HTTPStatus.NOT_FOUND: + await docker.networks.create({ + "Name": network_name, + "Driver": "bridge", + "Labels": { + "ai.backend.cluster-network": "1", + }, + }) + else: + raise @override async def destroy_local_network(self, network_name: str) -> None: - async with closing_async(Docker()) as docker: - try: - network = await docker.networks.get(network_name) - await network.delete() - except DockerError as e: - if e.status == HTTPStatus.NOT_FOUND: - # skip silently if already removed/missing - pass - else: - raise + try: + network = await self.docker.networks.get(network_name) + await network.delete() + except DockerError as e: + if e.status == HTTPStatus.NOT_FOUND: + # skip silently if already removed/missing + pass + else: + raise @preserve_termination_log # type: ignore[misc] async def monitor_docker_events(self) -> None: diff --git a/tests/component/agent/docker/test_agent.py b/tests/component/agent/docker/test_agent.py index 445317cbb03..6098531b125 100644 --- a/tests/component/agent/docker/test_agent.py +++ b/tests/component/agent/docker/test_agent.py @@ -80,12 +80,8 @@ async def test_init(agent: DockerAgent, mocker: Any) -> None: async def test_auto_pull_digest_when_digest_matching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.DIGEST - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_matching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert not pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -93,12 +89,8 @@ async def test_auto_pull_digest_when_digest_matching(agent: DockerAgent, mocker: async def test_auto_pull_digest_when_digest_mismatching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.DIGEST - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -106,17 +98,13 @@ async def test_auto_pull_digest_when_digest_mismatching(agent: DockerAgent, mock async def test_auto_pull_digest_when_missing(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.DIGEST - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock( side_effect=DockerError( status=HTTPStatus.NOT_FOUND, data={"message": "Simulated missing image"}, ), ) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert pull inspect_mock.assert_called_with(imgref.canonical) @@ -124,12 +112,8 @@ async def test_auto_pull_digest_when_missing(agent: DockerAgent, mocker: Any) -> async def test_auto_pull_tag_when_digest_matching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.TAG - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_matching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert not pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -137,12 +121,8 @@ async def test_auto_pull_tag_when_digest_matching(agent: DockerAgent, mocker: An async def test_auto_pull_tag_when_digest_mismatching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.TAG - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert not pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -150,17 +130,13 @@ async def test_auto_pull_tag_when_digest_mismatching(agent: DockerAgent, mocker: async def test_auto_pull_tag_when_missing(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.TAG - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock( side_effect=DockerError( status=HTTPStatus.NOT_FOUND, data={"message": "Simulated missing image"}, ), ) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert pull inspect_mock.assert_called_with(imgref.canonical) @@ -168,12 +144,8 @@ async def test_auto_pull_tag_when_missing(agent: DockerAgent, mocker: Any) -> No async def test_auto_pull_none_when_digest_matching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.NONE - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_matching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert not pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -181,12 +153,8 @@ async def test_auto_pull_none_when_digest_matching(agent: DockerAgent, mocker: A async def test_auto_pull_none_when_digest_mismatching(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.NONE - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) pull = await agent.check_image(imgref, query_digest, behavior) assert not pull inspect_mock.assert_awaited_with(imgref.canonical) @@ -194,17 +162,13 @@ async def test_auto_pull_none_when_digest_mismatching(agent: DockerAgent, mocker async def test_auto_pull_none_when_missing(agent: DockerAgent, mocker: Any) -> None: behavior = AutoPullBehavior.NONE - docker_mock = MagicMock() - docker_mock.close = AsyncMock() - docker_mock.images = MagicMock() inspect_mock = AsyncMock( side_effect=DockerError( status=HTTPStatus.NOT_FOUND, data={"message": "Simulated missing image"}, ), ) - docker_mock.images.inspect = inspect_mock - mocker.patch("ai.backend.agent.docker.agent.Docker", return_value=docker_mock) + mocker.patch.object(agent.docker.images, "inspect", new=inspect_mock) with pytest.raises(ImageNotAvailable) as e: await agent.check_image(imgref, query_digest, behavior) assert e.value.args[0] is imgref @@ -219,3 +183,68 @@ async def test_save_last_registry_exception(agent: DockerAgent, mocker: Any) -> ) await agent.save_last_registry() assert not registry_state_path.exists() + + +@pytest.fixture +async def unmanaged_agent( + local_config: Any, test_id: str, mocker: Any, socket_relay_image: Any +) -> Any: + """ + Like the ``agent`` fixture, but leaves shutdown entirely to the test so it + can assert against the lifecycle of the shared aiodocker client. + """ + dummy_etcd = DummyEtcd() + mocked_etcd_get_prefix = AsyncMock(return_value={}) + mocker.patch.object(dummy_etcd, "get_prefix", new=mocked_etcd_get_prefix) + test_case_id = secrets.token_hex(8) + kernel_registry = KernelRegistry() + agent = await DockerAgent.new( + dummy_etcd, + local_config, + stats_monitor=None, + error_monitor=None, + skip_initial_scan=True, + agent_public_key=None, + kernel_registry=kernel_registry, + computers={}, + slots={}, + agent_class=AgentClass.PRIMARY, + ) + agent.local_instance_id = test_case_id + yield agent + # Best-effort cleanup: close the shared client if the test did not already + # trigger a full shutdown. ``Docker.close`` -> ``ClientSession.close`` is + # safe to call on an already-closed session. + if not agent.docker.session.closed: + await agent.docker.close() + + +async def test_shared_docker_client_open_after_ainit(unmanaged_agent: DockerAgent) -> None: + assert unmanaged_agent.docker.session.closed is False + + +async def test_shared_docker_client_closed_after_shutdown( + unmanaged_agent: DockerAgent, +) -> None: + assert unmanaged_agent.docker.session.closed is False + await unmanaged_agent.shutdown(signal.SIGTERM) + assert unmanaged_agent.docker.session.closed is True + + +async def test_shared_docker_client_closed_when_super_shutdown_raises( + unmanaged_agent: DockerAgent, mocker: Any +) -> None: + # Simulate the base Agent.shutdown raising mid-shutdown — the shared + # aiodocker client must still be closed so its aiohttp.ClientSession does + # not leak. + class _SimulatedShutdownError(Exception): + pass + + mocker.patch( + "ai.backend.agent.agent.AbstractAgent.shutdown", + new=AsyncMock(side_effect=_SimulatedShutdownError("simulated")), + ) + assert unmanaged_agent.docker.session.closed is False + with pytest.raises(_SimulatedShutdownError): + await unmanaged_agent.shutdown(signal.SIGTERM) + assert unmanaged_agent.docker.session.closed is True