From f505b106a0be5ccd8d3df4a8919a01b6c6cc9622 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Mon, 19 Aug 2024 19:44:28 +0900 Subject: [PATCH] chore: Add sub-progress reporter --- src/ai/backend/common/bgtask.py | 98 +++++++++++++++++++++++++++------ src/ai/backend/common/events.py | 5 +- 2 files changed, 84 insertions(+), 19 deletions(-) diff --git a/src/ai/backend/common/bgtask.py b/src/ai/backend/common/bgtask.py index 243b3c2d2e8..e2d42396323 100644 --- a/src/ai/backend/common/bgtask.py +++ b/src/ai/backend/common/bgtask.py @@ -7,7 +7,10 @@ import uuid import weakref from collections import defaultdict +from collections.abc import Mapping +from datetime import datetime from typing import ( + Annotated, Any, AsyncIterator, Awaitable, @@ -19,10 +22,18 @@ Type, TypeAlias, Union, + cast, ) from aiohttp import web from aiohttp_sse import sse_response +from dateutil.tz import tzutc +from pydantic import ( + BaseModel, + Field, + PlainSerializer, + field_serializer, +) from redis.asyncio import Redis from redis.asyncio.client import Pipeline @@ -49,11 +60,35 @@ MAX_BGTASK_ARCHIVE_PERIOD: Final = 86400 # 24 hours +NumSerializedToStr = Annotated[ + int | float, PlainSerializer(lambda x: str(x), return_type=str, when_used="json") +] + + +class ProgressModel(BaseModel): + current: NumSerializedToStr = Field() + total: NumSerializedToStr = Field() + msg: str = Field(default="") + last_update: NumSerializedToStr = Field() + last_update_datetime: datetime = Field() + subreporter_task_ids: list[uuid.UUID] = Field() + + @field_serializer("subreporter_task_ids", when_used="json") + def stringify_task_ids(self, subreporter_task_ids: list[uuid.UUID], _info: Any) -> str: + return ",".join([str(val) for val in subreporter_task_ids]) + + @field_serializer("last_update_datetime", when_used="json") + def stringify_dt(self, last_update_datetime: datetime, _info: Any) -> str: + return last_update_datetime.isoformat() + + class ProgressReporter: event_producer: Final[EventProducer] task_id: Final[uuid.UUID] total_progress: Union[int, float] current_progress: Union[int, float] + subreporters: dict[uuid.UUID, ProgressReporter] + cool_down_seconds: float | None def __init__( self, @@ -61,21 +96,28 @@ def __init__( task_id: uuid.UUID, current_progress: int = 0, total_progress: int = 0, + subreporters: dict[uuid.UUID, ProgressReporter] | None = None, + cool_down_seconds: float | None = None, ) -> None: self.event_producer = event_dispatcher self.task_id = task_id self.current_progress = current_progress self.total_progress = total_progress + self.subreporters = subreporters if subreporters is not None else {} + self.cool_down_seconds = cool_down_seconds - async def update( - self, - increment: Union[int, float] = 0, - message: str | None = None, - ) -> None: - self.current_progress += increment - # keep the state as local variables because they might be changed - # due to interleaving at await statements below. - current, total = self.current_progress, self.total_progress + self._report_time = time.time() + + def register_subreporter(self, reporter: ProgressReporter) -> None: + if reporter.task_id not in self.subreporters: + self.subreporters[reporter.task_id] = reporter + + async def _update(self, data: ProgressModel, force: bool = False) -> None: + now = time.time() + if not force and ( + self.cool_down_seconds is not None and now - self._report_time < self.cool_down_seconds + ): + return redis_producer = self.event_producer.redis_client async def _pipe_builder(r: Redis) -> Pipeline: @@ -83,12 +125,7 @@ async def _pipe_builder(r: Redis) -> Pipeline: tracker_key = f"bgtask.{self.task_id}" await pipe.hset( tracker_key, - mapping={ - "current": str(current), - "total": str(total), - "msg": message or "", - "last_update": str(time.time()), - }, + mapping=cast(Mapping[str | bytes, str], data.model_dump(mode="json")), ) await pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD) return pipe @@ -97,11 +134,33 @@ async def _pipe_builder(r: Redis) -> Pipeline: await self.event_producer.produce_event( BgtaskUpdatedEvent( self.task_id, - message=message, - current_progress=current, - total_progress=total, + message=data.msg, + current_progress=data.current, + total_progress=data.total, + subreporter_task_ids=data.subreporter_task_ids, ), ) + self._report_time = now + + async def update( + self, + increment: Union[int, float] = 0, + message: str | None = None, + force: bool = False, + ) -> None: + now = time.time() + current_dt = datetime.now(tzutc()) + self.current_progress += increment + + data = ProgressModel( + current=self.current_progress, + total=self.total_progress, + msg=message or "", + last_update=now, + last_update_datetime=current_dt, + subreporter_task_ids=list(self.subreporters.keys()), + ) + await self._update(data, force=force) BackgroundTask = Callable[[ProgressReporter], Awaitable[str | None]] @@ -158,6 +217,9 @@ async def push_bgtask_events( case BgtaskUpdatedEvent(): body["current_progress"] = event.current_progress body["total_progress"] = event.total_progress + body["subreporter_task_id"] = [ + str(id) for id in event.subreporter_task_ids + ] await resp.send(json.dumps(body), event=event.name, retry=5) case BgtaskDoneEvent(): if extra_data: diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index d561dd7bf3b..e37d4c50f95 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -610,6 +610,7 @@ class BgtaskUpdatedEvent(AbstractEvent): task_id: uuid.UUID = attrs.field() current_progress: float = attrs.field() total_progress: float = attrs.field() + subreporter_task_ids: list[uuid.UUID] = attrs.field() message: Optional[str] = attrs.field(default=None) def serialize(self) -> tuple: @@ -617,6 +618,7 @@ def serialize(self) -> tuple: str(self.task_id), self.current_progress, self.total_progress, + tuple(str(v) for v in self.subreporter_task_ids), self.message, ) @@ -626,7 +628,8 @@ def deserialize(cls, value: tuple): uuid.UUID(value[0]), value[1], value[2], - value[3], + list(uuid.UUID(v) for v in value[3]), + value[4], )