Skip to content

Commit

Permalink
feat: progress report for image pulling
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Sep 5, 2024
1 parent 870f3de commit 135ab2e
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from trafaret import DataError

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.bgtask import BackgroundTaskManager
from ai.backend.common.bgtask import BackgroundTaskManager, ProgressReporter
from ai.backend.common.config import model_definition_iv
from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
Expand Down Expand Up @@ -1609,6 +1609,15 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
Pull the given image from the given registry.
"""

@abstractmethod
async def pull_image_in_background(
    self,
    reporter: ProgressReporter,
    image_ref: ImageRef,
    registry_conf: ImageRegistry,
) -> None:
    """
    Pull ``image_ref`` from the registry described by ``registry_conf``.

    Implementations consume the streaming pull response and relay what they
    read through the supplied ``reporter``.
    """

@abstractmethod
async def check_image(
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
Expand Down
168 changes: 167 additions & 1 deletion src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
List,
Literal,
MutableMapping,
NotRequired,
Optional,
Sequence,
Set,
Tuple,
TypedDict,
Union,
cast,
override,
)
from uuid import UUID
from uuid import UUID, uuid4

import aiohttp
import aiotools
Expand All @@ -47,6 +49,7 @@
from async_timeout import timeout

from ai.backend.common import redis_helper
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.cgroup import get_cgroup_mount_point
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
from ai.backend.common.events import EventProducer, KernelLifecycleEventReason
Expand Down Expand Up @@ -1491,6 +1494,169 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
async with closing_async(Docker()) as docker:
await docker.images.pull(image_ref.canonical, auth=auth_config)

async def pull_image_in_background(
self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry
) -> None:
auth_config = None
reg_user = registry_conf.get("username")
reg_passwd = registry_conf.get("password")
if reg_user and reg_passwd:
encoded_creds = base64.b64encode(f"{reg_user}:{reg_passwd}".encode("utf-8")).decode(
"ascii"
)
auth_config = {
"auth": encoded_creds,
}
log.info("pulling image {} from registry", image_ref.canonical)

class PullingProgressDetail(TypedDict):
current: int
total: int

class EmptyPullingProgressDetail(TypedDict):
pass

class PullResponse(TypedDict):
status: str
progressDetail: NotRequired[PullingProgressDetail | EmptyPullingProgressDetail]
progress: NotRequired[str] # ' 25.48MB/29.16MB', ' 1.399kB/1.399kB'
id: NotRequired[str]

class ErrorDetail(TypedDict):
message: str

class PullErrorResponse(TypedDict):
errorDetail: ErrorDetail
error: str

bgtask_mgr = self.background_task_manager
layer_to_reporter_task_id_map: dict[str, UUID] = {}
image_pull_config = await self.etcd.get_prefix("agent/image-pull")

if image_pull_config:
try:
use_subreporter = bool(cast(str | None, image_pull_config.get("use-subreporter")))
except (ValueError, TypeError):
use_subreporter = False
try:
raw_val = cast(str | None, image_pull_config.get("subreporter-cool-down-sec"))
if raw_val is None:
subreporter_cool_down_sec = 1.0
else:
subreporter_cool_down_sec = float(raw_val)
except (ValueError, TypeError):
subreporter_cool_down_sec = 1.0
else:
use_subreporter = False
subreporter_cool_down_sec = None

async def update_to_subreporter(
resp: PullResponse,
message: str | None = None,
current: int | None = None,
total: int | None = None,
force: bool = False,
) -> None:
id_ = resp.get("id")
if id_ is None:
return None
if id_ not in layer_to_reporter_task_id_map:
task_id = uuid4()
subreporter = ProgressReporter(
bgtask_mgr.event_producer, task_id, cool_down_seconds=subreporter_cool_down_sec
)
else:
task_id = layer_to_reporter_task_id_map[id_]
subreporter = reporter.subreporters[task_id]
reporter.register_subreporter(subreporter)
if current is not None:
subreporter.current_progress = current
if total is not None:
subreporter.total_progress = total
await subreporter.update(message=message, force=force)

def register_layer_id(resp: PullResponse) -> None:
if (id_ := resp.get("id")) is not None and id_ not in layer_to_reporter_task_id_map:
dummy_id = uuid4()
layer_to_reporter_task_id_map[id_] = dummy_id

async def handle_response(resp: PullResponse) -> None:
match resp["status"]:
case (
"Pulling fs layer"
| "Waiting"
| "Verifying Checksum"
| "Download complete"
| "Extracting" as status
):
if use_subreporter:
await update_to_subreporter(resp, message=status)
else:
register_layer_id(resp)
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
await reporter.update()
case "Downloading" as status:
if use_subreporter:
detail = resp["progressDetail"]
current = cast(int | None, detail.get("current"))
total = cast(int | None, detail.get("total"))
await update_to_subreporter(
resp, message=status, current=current, total=total
)
else:
register_layer_id(resp)
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
await reporter.update()
case "Pull complete" as status:
if use_subreporter:
await update_to_subreporter(resp, message=status, force=True)
else:
register_layer_id(resp)
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
await reporter.update(1, force=True)
case "Already exists":
reporter.total_progress += 1
await reporter.update(1, force=True)
case status if status.startswith("Pulling from"):
# Pulling has started.
# Value of 'id' field in response dict does not represent layer id.
await reporter.update(message=status)
case status if status.startswith("Digest:") or status.startswith("Status:"):
# Only 'status' field exists in response dict.
await reporter.update(message=status)
case _:
await reporter.update(message=resp["status"])

async def handle_err_response(resp: PullErrorResponse) -> None:
await reporter.update(message=resp.get("error"), force=True)

async with closing_async(Docker()) as docker:
async for resp in docker.images.pull(
image_ref.canonical, auth=auth_config, stream=True
):
match resp:
case dict() if resp.get("status"):
_resp = PullResponse(status=resp["status"])
if detail := resp.get("progressDetail"):
_resp["progressDetail"] = detail
if progress := resp.get("progress"):
_resp["progress"] = progress
if id := resp.get("id"):
_resp["id"] = id
await handle_response(_resp)
case dict() if resp.get("error"):
await handle_err_response(
PullErrorResponse(
error=resp["error"],
errorDetail=resp["errorDetail"],
)
)
case _:
log.warning(
f"Unable to deserialize pulling response. skip. (r:{str(resp)})"
)
continue

async def check_image(
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
) -> bool:
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
override,
)

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.config import read_from_file
from ai.backend.common.docker import ImageRef
from ai.backend.common.events import EventProducer
Expand Down Expand Up @@ -282,6 +283,11 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
delay = self.dummy_agent_cfg["delay"]["pull-image"]
await asyncio.sleep(delay)

async def pull_image_in_background(
    self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry
) -> None:
    """
    Simulate a background image pull.

    Delegates to ``pull_image()`` so that the delay it applies still takes
    effect for callers that use the background variant.  ``reporter`` is
    accepted for interface compatibility and is not updated.
    """
    await self.pull_image(image_ref, registry_conf)

async def push_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None:
delay = self.dummy_agent_cfg["delay"]["push-image"]
await asyncio.sleep(delay)
Expand Down
7 changes: 7 additions & 0 deletions src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from kubernetes_asyncio import config as kube_config

from ai.backend.common.asyncio import current_loop
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.docker import ImageRef
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.events import EventProducer
Expand Down Expand Up @@ -1015,6 +1016,12 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
# TODO: Add support for appropriate image pulling mechanism on K8s
pass

async def pull_image_in_background(
    self,
    reporter: ProgressReporter,
    image_ref: ImageRef,
    registry_conf: ImageRegistry,
) -> None:
    """Placeholder that does nothing: no image is pulled and ``reporter`` is left untouched."""
    # TODO: Add support for appropriate image pulling mechanism on K8s
    return None

async def check_image(
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ async def _pull(reporter: ProgressReporter) -> None:
image=str(img_ref),
)
)
await self.agent.pull_image(img_ref, img_conf["registry"])
await self.agent.pull_image_in_background(reporter, img_ref, img_conf["registry"])
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
Expand Down

0 comments on commit 135ab2e

Please sign in to comment.