diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 63cfee21573..91ae7adbb00 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -68,6 +68,7 @@ RedisConnectionInfo, ResourceSlot, SessionId, + SessionTypes, aobject, ) from ai.backend.logging import BraceStyleAdapter @@ -403,26 +404,52 @@ def _pipeline(r: Redis) -> RedisPipeline: raise asyncio.CancelledError() raise - async def _load_scheduler( - self, - db_sess: SASession, - sgroup_name: str, - ) -> tuple[AbstractScheduler, AbstractAgentSelector]: + async def _get_scaling_group_data( + self, db_sess: SASession, sgroup_name: str + ) -> tuple[str, ScalingGroupOpts]: query = sa.select(ScalingGroupRow.scheduler, ScalingGroupRow.scheduler_opts).where( ScalingGroupRow.name == sgroup_name ) result = await db_sess.execute(query) row = result.first() - scheduler_name = row.scheduler - sgroup_opts: ScalingGroupOpts = row.scheduler_opts + if row is None: + raise ValueError(f"Scaling group '{sgroup_name}' not found.") + return row.scheduler, row.scheduler_opts + + async def _load_scheduler(self, db_sess: SASession, sgroup_name: str) -> AbstractScheduler: + scheduler_name, sgroup_opts = await self._get_scaling_group_data(db_sess, sgroup_name) + + global_scheduler_opts = {} + if self.shared_config["plugins"]["scheduler"]: + global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get( + scheduler_name, {} + ) + scheduler_config = {**global_scheduler_opts, **sgroup_opts.config} + + return load_scheduler(scheduler_name, sgroup_opts, scheduler_config) + + async def _load_agent_selector( + self, + db_sess: SASession, + sgroup_name: str, + pending_session: SessionRow, # TODO: id and session_type? 
+ ) -> AbstractAgentSelector: + session_type = pending_session.session_type + + _scheduler_name, sgroup_opts = await self._get_scaling_group_data(db_sess, sgroup_name) + match sgroup_opts.agent_selection_strategy: - # The names correspond to the entrypoint names (backendai_agentselector_v10). case AgentSelectionStrategy.LEGACY: agselector_name = "legacy" case AgentSelectionStrategy.ROUNDROBIN: agselector_name = "roundrobin" case AgentSelectionStrategy.CONCENTRATED: - agselector_name = "concentrated" + # TODO: If there are no services with the same model, it should operate as "concentrated". + if session_type == SessionTypes.INFERENCE: + # TODO: Roundrobin? + agselector_name = "dispersed" + else: + agselector_name = "concentrated" case AgentSelectionStrategy.DISPERSED: agselector_name = "dispersed" case _ as unknown: @@ -430,13 +457,7 @@ async def _load_scheduler( f"Unknown agent selection strategy: {unknown!r}. Possible values: {[*AgentSelectionStrategy.__members__.keys()]}" ) - global_scheduler_opts = {} global_agselector_opts = {} - if self.shared_config["plugins"]["scheduler"]: - global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get( - scheduler_name, {} - ) - scheduler_config = {**global_scheduler_opts, **sgroup_opts.config} if self.shared_config["plugins"]["agent-selector"]: global_agselector_opts = self.shared_config["plugins"]["agent-selector"].get( agselector_name, {} @@ -446,19 +467,13 @@ async def _load_scheduler( "agent-selection-resource-priority" ] - scheduler = load_scheduler( - scheduler_name, - sgroup_opts, - scheduler_config, - ) - agent_selector = load_agent_selector( + return load_agent_selector( agselector_name, sgroup_opts, agselector_config, agent_selection_resource_priority, self.shared_config, ) - return scheduler, agent_selector async def _schedule_in_sgroup( self, @@ -468,7 +483,8 @@ async def _schedule_in_sgroup( # Part 0: Load the scheduler and the agent selector. 
+ async with self.db.begin_readonly_session() as db_sess: - scheduler, agent_selector = await self._load_scheduler(db_sess, sgroup_name) + # Load the scheduler -> list the pending sessions -> load the agent selector (which agent selector is loaded depends on the pending session's type) + scheduler = await self._load_scheduler(db_sess, sgroup_name) existing_sessions, pending_sessions, cancelled_sessions = await _list_managed_sessions( db_sess, sgroup_name, scheduler.sgroup_opts.pending_timeout ) @@ -507,6 +523,10 @@ async def _schedule_in_sgroup( raise RuntimeError("should not reach here") pending_sess = pending_sessions.pop(picked_idx) log_fmt = "schedule(s:{}, prio:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): " + + async with self.db.begin_readonly_session() as db_sess: + agent_selector = await self._load_agent_selector(db_sess, sgroup_name, pending_sess) + log_args = ( pending_sess.id, pending_sess.priority,