|
26 | 26 | List,
|
27 | 27 | Literal,
|
28 | 28 | MutableMapping,
|
| 29 | + NotRequired, |
29 | 30 | Optional,
|
30 | 31 | Sequence,
|
31 | 32 | Set,
|
32 | 33 | Tuple,
|
| 34 | + TypedDict, |
33 | 35 | Union,
|
34 | 36 | cast,
|
35 | 37 | override,
|
36 | 38 | )
|
37 |
| -from uuid import UUID |
| 39 | +from uuid import UUID, uuid4 |
38 | 40 |
|
39 | 41 | import aiohttp
|
40 | 42 | import aiotools
|
|
47 | 49 | from async_timeout import timeout
|
48 | 50 |
|
49 | 51 | from ai.backend.common import redis_helper
|
| 52 | +from ai.backend.common.bgtask import ProgressReporter |
50 | 53 | from ai.backend.common.cgroup import get_cgroup_mount_point
|
51 | 54 | from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
|
52 | 55 | from ai.backend.common.events import EventProducer, KernelLifecycleEventReason
|
@@ -1491,6 +1494,169 @@ async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) ->
|
1491 | 1494 | async with closing_async(Docker()) as docker:
|
1492 | 1495 | await docker.images.pull(image_ref.canonical, auth=auth_config)
|
1493 | 1496 |
|
| 1497 | + async def pull_image_in_background( |
| 1498 | + self, reporter: ProgressReporter, image_ref: ImageRef, registry_conf: ImageRegistry |
| 1499 | + ) -> None: |
| 1500 | + auth_config = None |
| 1501 | + reg_user = registry_conf.get("username") |
| 1502 | + reg_passwd = registry_conf.get("password") |
| 1503 | + if reg_user and reg_passwd: |
| 1504 | + encoded_creds = base64.b64encode(f"{reg_user}:{reg_passwd}".encode("utf-8")).decode( |
| 1505 | + "ascii" |
| 1506 | + ) |
| 1507 | + auth_config = { |
| 1508 | + "auth": encoded_creds, |
| 1509 | + } |
| 1510 | + log.info("pulling image {} from registry", image_ref.canonical) |
| 1511 | + |
| 1512 | + class PullingProgressDetail(TypedDict): |
| 1513 | + current: int |
| 1514 | + total: int |
| 1515 | + |
| 1516 | + class EmptyPullingProgressDetail(TypedDict): |
| 1517 | + pass |
| 1518 | + |
| 1519 | + class PullResponse(TypedDict): |
| 1520 | + status: str |
| 1521 | + progressDetail: NotRequired[PullingProgressDetail | EmptyPullingProgressDetail] |
| 1522 | + progress: NotRequired[str] # ' 25.48MB/29.16MB', ' 1.399kB/1.399kB' |
| 1523 | + id: NotRequired[str] |
| 1524 | + |
| 1525 | + class ErrorDetail(TypedDict): |
| 1526 | + message: str |
| 1527 | + |
| 1528 | + class PullErrorResponse(TypedDict): |
| 1529 | + errorDetail: ErrorDetail |
| 1530 | + error: str |
| 1531 | + |
| 1532 | + bgtask_mgr = self.background_task_manager |
| 1533 | + layer_to_reporter_task_id_map: dict[str, UUID] = {} |
| 1534 | + image_pull_config = await self.etcd.get_prefix("agent/image-pull") |
| 1535 | + |
| 1536 | + if image_pull_config: |
| 1537 | + try: |
| 1538 | + use_subreporter = bool(cast(str | None, image_pull_config.get("use-subreporter"))) |
| 1539 | + except (ValueError, TypeError): |
| 1540 | + use_subreporter = False |
| 1541 | + try: |
| 1542 | + raw_val = cast(str | None, image_pull_config.get("subreporter-cool-down-sec")) |
| 1543 | + if raw_val is None: |
| 1544 | + subreporter_cool_down_sec = 1.0 |
| 1545 | + else: |
| 1546 | + subreporter_cool_down_sec = float(raw_val) |
| 1547 | + except (ValueError, TypeError): |
| 1548 | + subreporter_cool_down_sec = 1.0 |
| 1549 | + else: |
| 1550 | + use_subreporter = False |
| 1551 | + subreporter_cool_down_sec = None |
| 1552 | + |
| 1553 | + async def update_to_subreporter( |
| 1554 | + resp: PullResponse, |
| 1555 | + message: str | None = None, |
| 1556 | + current: int | None = None, |
| 1557 | + total: int | None = None, |
| 1558 | + force: bool = False, |
| 1559 | + ) -> None: |
| 1560 | + id_ = resp.get("id") |
| 1561 | + if id_ is None: |
| 1562 | + return None |
| 1563 | + if id_ not in layer_to_reporter_task_id_map: |
| 1564 | + task_id = uuid4() |
| 1565 | + subreporter = ProgressReporter( |
| 1566 | + bgtask_mgr.event_producer, task_id, cool_down_seconds=subreporter_cool_down_sec |
| 1567 | + ) |
| 1568 | + else: |
| 1569 | + task_id = layer_to_reporter_task_id_map[id_] |
| 1570 | + subreporter = reporter.subreporters[task_id] |
| 1571 | + reporter.register_subreporter(subreporter) |
| 1572 | + if current is not None: |
| 1573 | + subreporter.current_progress = current |
| 1574 | + if total is not None: |
| 1575 | + subreporter.total_progress = total |
| 1576 | + await subreporter.update(message=message, force=force) |
| 1577 | + |
| 1578 | + def register_layer_id(resp: PullResponse) -> None: |
| 1579 | + if (id_ := resp.get("id")) is not None and id_ not in layer_to_reporter_task_id_map: |
| 1580 | + dummy_id = uuid4() |
| 1581 | + layer_to_reporter_task_id_map[id_] = dummy_id |
| 1582 | + |
| 1583 | + async def handle_response(resp: PullResponse) -> None: |
| 1584 | + match resp["status"]: |
| 1585 | + case ( |
| 1586 | + "Pulling fs layer" |
| 1587 | + | "Waiting" |
| 1588 | + | "Verifying Checksum" |
| 1589 | + | "Download complete" |
| 1590 | + | "Extracting" as status |
| 1591 | + ): |
| 1592 | + if use_subreporter: |
| 1593 | + await update_to_subreporter(resp, message=status) |
| 1594 | + else: |
| 1595 | + register_layer_id(resp) |
| 1596 | + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) |
| 1597 | + await reporter.update() |
| 1598 | + case "Downloading" as status: |
| 1599 | + if use_subreporter: |
| 1600 | + detail = resp["progressDetail"] |
| 1601 | + current = cast(int | None, detail.get("current")) |
| 1602 | + total = cast(int | None, detail.get("total")) |
| 1603 | + await update_to_subreporter( |
| 1604 | + resp, message=status, current=current, total=total |
| 1605 | + ) |
| 1606 | + else: |
| 1607 | + register_layer_id(resp) |
| 1608 | + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) |
| 1609 | + await reporter.update() |
| 1610 | + case "Pull complete" as status: |
| 1611 | + if use_subreporter: |
| 1612 | + await update_to_subreporter(resp, message=status, force=True) |
| 1613 | + else: |
| 1614 | + register_layer_id(resp) |
| 1615 | + reporter.total_progress = len(layer_to_reporter_task_id_map.keys()) |
| 1616 | + await reporter.update(1, force=True) |
| 1617 | + case "Already exists": |
| 1618 | + reporter.total_progress += 1 |
| 1619 | + await reporter.update(1, force=True) |
| 1620 | + case status if status.startswith("Pulling from"): |
| 1621 | + # Pulling has started. |
| 1622 | + # Value of 'id' field in response dict does not represent layer id. |
| 1623 | + await reporter.update(message=status) |
| 1624 | + case status if status.startswith("Digest:") or status.startswith("Status:"): |
| 1625 | + # Only 'status' field exists in response dict. |
| 1626 | + await reporter.update(message=status) |
| 1627 | + case _: |
| 1628 | + await reporter.update(message=resp["status"]) |
| 1629 | + |
| 1630 | + async def handle_err_response(resp: PullErrorResponse) -> None: |
| 1631 | + await reporter.update(message=resp.get("error"), force=True) |
| 1632 | + |
| 1633 | + async with closing_async(Docker()) as docker: |
| 1634 | + async for resp in docker.images.pull( |
| 1635 | + image_ref.canonical, auth=auth_config, stream=True |
| 1636 | + ): |
| 1637 | + match resp: |
| 1638 | + case dict() if resp.get("status"): |
| 1639 | + _resp = PullResponse(status=resp["status"]) |
| 1640 | + if detail := resp.get("progressDetail"): |
| 1641 | + _resp["progressDetail"] = detail |
| 1642 | + if progress := resp.get("progress"): |
| 1643 | + _resp["progress"] = progress |
| 1644 | + if id := resp.get("id"): |
| 1645 | + _resp["id"] = id |
| 1646 | + await handle_response(_resp) |
| 1647 | + case dict() if resp.get("error"): |
| 1648 | + await handle_err_response( |
| 1649 | + PullErrorResponse( |
| 1650 | + error=resp["error"], |
| 1651 | + errorDetail=resp["errorDetail"], |
| 1652 | + ) |
| 1653 | + ) |
| 1654 | + case _: |
| 1655 | + log.warning( |
| 1656 | + f"Unable to deserialize pulling response. skip. (r:{str(resp)})" |
| 1657 | + ) |
| 1658 | + continue |
| 1659 | + |
1494 | 1660 | async def check_image(
|
1495 | 1661 | self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
|
1496 | 1662 | ) -> bool:
|
|
0 commit comments