diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 5ae221a46d6..9ebb6ad7f63 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -66,7 +66,7 @@ from trafaret import DataError from ai.backend.common import msgpack, redis_helper -from ai.backend.common.bgtask import BackgroundTaskManager +from ai.backend.common.bgtask import BackgroundTaskManager, ProgressReporter from ai.backend.common.config import model_definition_iv from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef @@ -1609,6 +1609,15 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> Pull the given image from the given registry. """ + @abstractmethod + async def pull_image_in_background( + self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry + ) -> None: + """ + Pull the given image from the given registry. + Read the streaming response and report through the given ProgressReporter. + """ + @abstractmethod async def check_image( self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 0898eef0180..b91bc1bf064 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -26,15 +26,17 @@ List, Literal, MutableMapping, + NotRequired, Optional, Sequence, Set, Tuple, + TypedDict, Union, cast, override, ) -from uuid import UUID +from uuid import UUID, uuid4 import aiohttp import aiotools @@ -47,6 +49,7 @@ from async_timeout import timeout from ai.backend.common import redis_helper +from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.cgroup import get_cgroup_mount_point from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef from ai.backend.common.events import EventProducer, KernelLifecycleEventReason @@ -1491,6 +1494,169 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> async with closing_async(Docker()) as docker: await docker.images.pull(image_ref.canonical, auth=auth_config) + async def pull_image_in_background( + self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry + ) -> None: + auth_config = None + reg_user = registry_conf.get("username") + reg_passwd = registry_conf.get("password") + if reg_user and reg_passwd: + encoded_creds = base64.b64encode(f"{reg_user}:{reg_passwd}".encode("utf-8")).decode( + "ascii" + ) + auth_config = { + "auth": encoded_creds, + } + log.info("pulling image {} from registry", image_ref.canonical) + + class PullingProgressDetail(TypedDict): + current: int + total: int + + class EmptyPullingProgressDetail(TypedDict): + pass + + class PullResponse(TypedDict): + status: str + progressDetail: NotRequired[PullingProgressDetail | EmptyPullingProgressDetail] + progress: NotRequired[str] # ' 25.48MB/29.16MB', ' 1.399kB/1.399kB' + id: NotRequired[str] + + class ErrorDetail(TypedDict): + message: str + + class PullErrorResponse(TypedDict): + errorDetail: ErrorDetail + error: str + + bgtask_mgr = self.background_task_manager + layer_to_reporter_task_id_map: dict[str, UUID] = {} + image_pull_config = await self.etcd.get_prefix("agent/image-pull") + + if image_pull_config: + try: + use_subreporter = bool(cast(str | None, image_pull_config.get("use-subreporter"))) + except (ValueError, TypeError): + use_subreporter = False + try: + raw_val = cast(str | None, image_pull_config.get("subreporter-cool-down-sec")) + if raw_val is None: + subreporter_cool_down_sec = 1.0 + else: + subreporter_cool_down_sec = float(raw_val) + except (ValueError, TypeError): + subreporter_cool_down_sec = 1.0 + else: + use_subreporter = False + subreporter_cool_down_sec = None + + async def update_to_subreporter( + resp: PullResponse, + message: str | None = None, + current: int | None = None, + total: int | None = None, + force: bool = False, + ) -> None: + id_ = resp.get("id") + if id_ is None: + return None + if id_ not in layer_to_reporter_task_id_map: + task_id = uuid4() + subreporter = ProgressReporter( + bgtask_mgr.event_producer, task_id, cool_down_seconds=subreporter_cool_down_sec + ) + else: + task_id = layer_to_reporter_task_id_map[id_] + subreporter = reporter.subreporters[task_id] + reporter.register_subreporter(subreporter) + if current is not None: + subreporter.current_progress = current + if total is not None: + subreporter.total_progress = total + await subreporter.update(message=message, force=force) + + def register_layer_id(resp: PullResponse) -> None: + if (id_ := resp.get("id")) is not None and id_ not in layer_to_reporter_task_id_map: + dummy_id = uuid4() + layer_to_reporter_task_id_map[id_] = dummy_id + + async def handle_response(resp: PullResponse) -> None: + match resp["status"]: + case ( + "Pulling fs layer" + | "Waiting" + | "Verifying Checksum" + | "Download complete" + | "Extracting" as status + ): + if use_subreporter: + await update_to_subreporter(resp, message=status) + else: + register_layer_id(resp) + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) + await reporter.update() + case "Downloading" as status: + if use_subreporter: + detail = resp["progressDetail"] + current = cast(int | None, detail.get("current")) + total = cast(int | None, detail.get("total")) + await update_to_subreporter( + resp, message=status, current=current, total=total + ) + else: + register_layer_id(resp) + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) + await reporter.update() + case "Pull complete" as status: + if use_subreporter: + await update_to_subreporter(resp, message=status, force=True) + else: + register_layer_id(resp) + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) + await reporter.update(1, force=True) + case "Already exists": + reporter.total_progress += 1 + await reporter.update(1, force=True) + case status if status.startswith("Pulling from"): + # Pulling has started. + # Value of 'id' field in response dict does not represent layer id. + await reporter.update(message=status) + case status if status.startswith("Digest:") or status.startswith("Status:"): + # Only 'status' field exists in response dict. + await reporter.update(message=status) + case _: + await reporter.update(message=resp["status"]) + + async def handle_err_response(resp: PullErrorResponse) -> None: + await reporter.update(message=resp.get("error"), force=True) + + async with closing_async(Docker()) as docker: + async for resp in docker.images.pull( + image_ref.canonical, auth=auth_config, stream=True + ): + match resp: + case dict() if resp.get("status"): + _resp = PullResponse(status=resp["status"]) + if detail := resp.get("progressDetail"): + _resp["progressDetail"] = detail + if progress := resp.get("progress"): + _resp["progress"] = progress + if id := resp.get("id"): + _resp["id"] = id + await handle_response(_resp) + case dict() if resp.get("error"): + await handle_err_response( + PullErrorResponse( + error=resp["error"], + errorDetail=resp["errorDetail"], + ) + ) + case _: + log.warning( + f"Unable to deserialize pulling response. skip. (r:{str(resp)})" + ) + continue + async def check_image( self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior ) -> bool: diff --git a/src/ai/backend/agent/dummy/agent.py b/src/ai/backend/agent/dummy/agent.py index 09af9e401f0..1dcf45fe295 100644 --- a/src/ai/backend/agent/dummy/agent.py +++ b/src/ai/backend/agent/dummy/agent.py @@ -13,6 +13,7 @@ override, ) +from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.config import read_from_file from ai.backend.common.docker import ImageRef from ai.backend.common.events import EventProducer @@ -282,6 +283,11 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> delay = self.dummy_agent_cfg["delay"]["pull-image"] await asyncio.sleep(delay) + async def pull_image_in_background( + self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry + ) -> None: + return None + async def push_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: delay = self.dummy_agent_cfg["delay"]["push-image"] await asyncio.sleep(delay) diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py index c2e1fbc99d6..ee78d09ab90 100644 --- a/src/ai/backend/agent/kubernetes/agent.py +++ b/src/ai/backend/agent/kubernetes/agent.py @@ -33,6 +33,7 @@ from kubernetes_asyncio import config as kube_config from ai.backend.common.asyncio import current_loop +from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.docker import ImageRef from ai.backend.common.etcd import AsyncEtcd from ai.backend.common.events import EventProducer @@ -1015,6 +1016,12 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> # TODO: Add support for appropriate image pulling mechanism on K8s pass + async def pull_image_in_background( + self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry + ) -> None: + # TODO: Add support for appropriate image pulling mechanism on K8s + pass + async def check_image( self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior ) -> bool: diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index ce48ee3692b..bb65b813bbd 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -507,7 +507,7 @@ async def _pull(reporter: ProgressReporter) -> None: image=str(img_ref), ) ) - await self.agent.pull_image(img_ref, img_conf["registry"]) + await self.agent.pull_image_in_background(reporter, img_ref, img_conf["registry"]) await self.agent.produce_event( ImagePullFinishedEvent( image=str(img_ref),