Skip to content

Commit 135ab2e

Browse files
committed
feat: progress report for image pulling
1 parent 870f3de commit 135ab2e

File tree

5 files changed

+191
-3
lines changed

5 files changed

+191
-3
lines changed

src/ai/backend/agent/agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from trafaret import DataError
6767

6868
from ai.backend.common import msgpack, redis_helper
69-
from ai.backend.common.bgtask import BackgroundTaskManager
69+
from ai.backend.common.bgtask import BackgroundTaskManager, ProgressReporter
7070
from ai.backend.common.config import model_definition_iv
7171
from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB
7272
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
@@ -1609,6 +1609,15 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
16091609
Pull the given image from the given registry.
16101610
"""
16111611

1612+
@abstractmethod
async def pull_image_in_background(
    self,
    reporter: ProgressReporter,
    image_ref: ImageRef,
    registry_conf: ImageRegistry,
) -> None:
    """
    Pull the given image from the given registry while reporting progress.

    Implementations are expected to consume the registry's streaming pull
    response and forward what they read to ``reporter``.

    :param reporter: Progress reporter that receives the pull progress.
    :param image_ref: Reference of the image to pull.
    :param registry_conf: Configuration of the registry to pull from.
    """
    raise NotImplementedError
1620+
16121621
@abstractmethod
16131622
async def check_image(
16141623
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior

src/ai/backend/agent/docker/agent.py

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@
2626
List,
2727
Literal,
2828
MutableMapping,
29+
NotRequired,
2930
Optional,
3031
Sequence,
3132
Set,
3233
Tuple,
34+
TypedDict,
3335
Union,
3436
cast,
3537
override,
3638
)
37-
from uuid import UUID
39+
from uuid import UUID, uuid4
3840

3941
import aiohttp
4042
import aiotools
@@ -47,6 +49,7 @@
4749
from async_timeout import timeout
4850

4951
from ai.backend.common import redis_helper
52+
from ai.backend.common.bgtask import ProgressReporter
5053
from ai.backend.common.cgroup import get_cgroup_mount_point
5154
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
5255
from ai.backend.common.events import EventProducer, KernelLifecycleEventReason
@@ -1491,6 +1494,169 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
14911494
async with closing_async(Docker()) as docker:
14921495
await docker.images.pull(image_ref.canonical, auth=auth_config)
14931496

1497+
async def pull_image_in_background(
1498+
self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry
1499+
) -> None:
1500+
auth_config = None
1501+
reg_user = registry_conf.get("username")
1502+
reg_passwd = registry_conf.get("password")
1503+
if reg_user and reg_passwd:
1504+
encoded_creds = base64.b64encode(f"{reg_user}:{reg_passwd}".encode("utf-8")).decode(
1505+
"ascii"
1506+
)
1507+
auth_config = {
1508+
"auth": encoded_creds,
1509+
}
1510+
log.info("pulling image {} from registry", image_ref.canonical)
1511+
1512+
class PullingProgressDetail(TypedDict):
1513+
current: int
1514+
total: int
1515+
1516+
class EmptyPullingProgressDetail(TypedDict):
1517+
pass
1518+
1519+
class PullResponse(TypedDict):
1520+
status: str
1521+
progressDetail: NotRequired[PullingProgressDetail | EmptyPullingProgressDetail]
1522+
progress: NotRequired[str] # ' 25.48MB/29.16MB', ' 1.399kB/1.399kB'
1523+
id: NotRequired[str]
1524+
1525+
class ErrorDetail(TypedDict):
1526+
message: str
1527+
1528+
class PullErrorResponse(TypedDict):
1529+
errorDetail: ErrorDetail
1530+
error: str
1531+
1532+
bgtask_mgr = self.background_task_manager
1533+
layer_to_reporter_task_id_map: dict[str, UUID] = {}
1534+
image_pull_config = await self.etcd.get_prefix("agent/image-pull")
1535+
1536+
if image_pull_config:
1537+
try:
1538+
use_subreporter = bool(cast(str | None, image_pull_config.get("use-subreporter")))
1539+
except (ValueError, TypeError):
1540+
use_subreporter = False
1541+
try:
1542+
raw_val = cast(str | None, image_pull_config.get("subreporter-cool-down-sec"))
1543+
if raw_val is None:
1544+
subreporter_cool_down_sec = 1.0
1545+
else:
1546+
subreporter_cool_down_sec = float(raw_val)
1547+
except (ValueError, TypeError):
1548+
subreporter_cool_down_sec = 1.0
1549+
else:
1550+
use_subreporter = False
1551+
subreporter_cool_down_sec = None
1552+
1553+
async def update_to_subreporter(
1554+
resp: PullResponse,
1555+
message: str | None = None,
1556+
current: int | None = None,
1557+
total: int | None = None,
1558+
force: bool = False,
1559+
) -> None:
1560+
id_ = resp.get("id")
1561+
if id_ is None:
1562+
return None
1563+
if id_ not in layer_to_reporter_task_id_map:
1564+
task_id = uuid4()
1565+
subreporter = ProgressReporter(
1566+
bgtask_mgr.event_producer, task_id, cool_down_seconds=subreporter_cool_down_sec
1567+
)
1568+
else:
1569+
task_id = layer_to_reporter_task_id_map[id_]
1570+
subreporter = reporter.subreporters[task_id]
1571+
reporter.register_subreporter(subreporter)
1572+
if current is not None:
1573+
subreporter.current_progress = current
1574+
if total is not None:
1575+
subreporter.total_progress = total
1576+
await subreporter.update(message=message, force=force)
1577+
1578+
def register_layer_id(resp: PullResponse) -> None:
1579+
if (id_ := resp.get("id")) is not None and id_ not in layer_to_reporter_task_id_map:
1580+
dummy_id = uuid4()
1581+
layer_to_reporter_task_id_map[id_] = dummy_id
1582+
1583+
async def handle_response(resp: PullResponse) -> None:
1584+
match resp["status"]:
1585+
case (
1586+
"Pulling fs layer"
1587+
| "Waiting"
1588+
| "Verifying Checksum"
1589+
| "Download complete"
1590+
| "Extracting" as status
1591+
):
1592+
if use_subreporter:
1593+
await update_to_subreporter(resp, message=status)
1594+
else:
1595+
register_layer_id(resp)
1596+
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
1597+
await reporter.update()
1598+
case "Downloading" as status:
1599+
if use_subreporter:
1600+
detail = resp["progressDetail"]
1601+
current = cast(int | None, detail.get("current"))
1602+
total = cast(int | None, detail.get("total"))
1603+
await update_to_subreporter(
1604+
resp, message=status, current=current, total=total
1605+
)
1606+
else:
1607+
register_layer_id(resp)
1608+
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
1609+
await reporter.update()
1610+
case "Pull complete" as status:
1611+
if use_subreporter:
1612+
await update_to_subreporter(resp, message=status, force=True)
1613+
else:
1614+
register_layer_id(resp)
1615+
reporter.total_progress = len(layer_to_reporter_task_id_map.keys())
1616+
await reporter.update(1, force=True)
1617+
case "Already exists":
1618+
reporter.total_progress += 1
1619+
await reporter.update(1, force=True)
1620+
case status if status.startswith("Pulling from"):
1621+
# Pulling has started.
1622+
# Value of 'id' field in response dict does not represent layer id.
1623+
await reporter.update(message=status)
1624+
case status if status.startswith("Digest:") or status.startswith("Status:"):
1625+
# Only 'status' field exists in response dict.
1626+
await reporter.update(message=status)
1627+
case _:
1628+
await reporter.update(message=resp["status"])
1629+
1630+
async def handle_err_response(resp: PullErrorResponse) -> None:
1631+
await reporter.update(message=resp.get("error"), force=True)
1632+
1633+
async with closing_async(Docker()) as docker:
1634+
async for resp in docker.images.pull(
1635+
image_ref.canonical, auth=auth_config, stream=True
1636+
):
1637+
match resp:
1638+
case dict() if resp.get("status"):
1639+
_resp = PullResponse(status=resp["status"])
1640+
if detail := resp.get("progressDetail"):
1641+
_resp["progressDetail"] = detail
1642+
if progress := resp.get("progress"):
1643+
_resp["progress"] = progress
1644+
if id := resp.get("id"):
1645+
_resp["id"] = id
1646+
await handle_response(_resp)
1647+
case dict() if resp.get("error"):
1648+
await handle_err_response(
1649+
PullErrorResponse(
1650+
error=resp["error"],
1651+
errorDetail=resp["errorDetail"],
1652+
)
1653+
)
1654+
case _:
1655+
log.warning(
1656+
f"Unable to deserialize pulling response. skip. (r:{str(resp)})"
1657+
)
1658+
continue
1659+
14941660
async def check_image(
14951661
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
14961662
) -> bool:

src/ai/backend/agent/dummy/agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
override,
1414
)
1515

16+
from ai.backend.common.bgtask import ProgressReporter
1617
from ai.backend.common.config import read_from_file
1718
from ai.backend.common.docker import ImageRef
1819
from ai.backend.common.events import EventProducer
@@ -282,6 +283,11 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
282283
delay = self.dummy_agent_cfg["delay"]["pull-image"]
283284
await asyncio.sleep(delay)
284285

286+
async def pull_image_in_background(
    self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry
) -> None:
    """
    Simulate pulling an image by waiting for the configured delay.

    Uses the same ``delay.pull-image`` setting of the dummy agent
    configuration as ``pull_image`` so that the simulated pull time is kept
    when the server pulls images through this method.  ``reporter`` is not
    updated.

    :param reporter: Progress reporter of the enclosing background task (unused).
    :param image_ref: Reference of the image to pull (unused).
    :param registry_conf: Configuration of the registry to pull from (unused).
    """
    delay = self.dummy_agent_cfg["delay"]["pull-image"]
    await asyncio.sleep(delay)
290+
285291
async def push_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None:
286292
delay = self.dummy_agent_cfg["delay"]["push-image"]
287293
await asyncio.sleep(delay)

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from kubernetes_asyncio import config as kube_config
3434

3535
from ai.backend.common.asyncio import current_loop
36+
from ai.backend.common.bgtask import ProgressReporter
3637
from ai.backend.common.docker import ImageRef
3738
from ai.backend.common.etcd import AsyncEtcd
3839
from ai.backend.common.events import EventProducer
@@ -1015,6 +1016,12 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
10151016
# TODO: Add support for appropriate image pulling mechanism on K8s
10161017
pass
10171018

1019+
async def pull_image_in_background(
    self,
    reporter: ProgressReporter,
    image_ref: ImageRef,
    registry_conf: ImageRegistry,
) -> None:
    """
    Placeholder that does nothing: no image is pulled and ``reporter`` is
    never updated.
    """
    # TODO: Add support for appropriate image pulling mechanism on K8s
    return None
1024+
10181025
async def check_image(
10191026
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
10201027
) -> bool:

src/ai/backend/agent/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ async def _pull(reporter: ProgressReporter) -> None:
507507
image=str(img_ref),
508508
)
509509
)
510-
await self.agent.pull_image(img_ref, img_conf["registry"])
510+
await self.agent.pull_image_in_background(reporter, img_ref, img_conf["registry"])
511511
await self.agent.produce_event(
512512
ImagePullFinishedEvent(
513513
image=str(img_ref),

0 commit comments

Comments
 (0)