Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-442): Distribute model service replica containers whenever possible #3693

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
RedisConnectionInfo,
ResourceSlot,
SessionId,
SessionTypes,
aobject,
)
from ai.backend.logging import BraceStyleAdapter
Expand Down Expand Up @@ -403,40 +404,60 @@ def _pipeline(r: Redis) -> RedisPipeline:
raise asyncio.CancelledError()
raise

async def _load_scheduler(
self,
db_sess: SASession,
sgroup_name: str,
) -> tuple[AbstractScheduler, AbstractAgentSelector]:
async def _get_scaling_group_data(
    self, db_sess: SASession, sgroup_name: str
) -> tuple[str, ScalingGroupOpts]:
    """Fetch the scheduler name and scheduler options of a scaling group.

    Args:
        db_sess: Database session used to execute the lookup query.
        sgroup_name: Name of the scaling group to look up.

    Returns:
        The ``(scheduler, scheduler_opts)`` pair read from the
        ``ScalingGroupRow`` whose ``name`` equals ``sgroup_name``.

    Raises:
        ValueError: If no scaling group named ``sgroup_name`` is found.
    """
    query = sa.select(ScalingGroupRow.scheduler, ScalingGroupRow.scheduler_opts).where(
        ScalingGroupRow.name == sgroup_name
    )
    result = await db_sess.execute(query)
    row = result.first()
    # Guard against a missing row *before* touching its attributes; otherwise
    # an unknown scaling group would surface as an AttributeError on None
    # instead of the descriptive ValueError below.
    if row is None:
        raise ValueError(f"Scaling group '{sgroup_name}' not found.")
    return row.scheduler, row.scheduler_opts

async def _load_scheduler(self, db_sess: SASession, sgroup_name: str) -> AbstractScheduler:
    """Build the scheduler configured for the given scaling group.

    The scheduler name and the scaling-group options are read through
    ``_get_scaling_group_data()``.  The effective scheduler config is the
    per-scheduler entry of ``shared_config["plugins"]["scheduler"]`` (if that
    section is non-empty) overlaid with ``sgroup_opts.config``, so keys set
    on the scaling group take precedence over the global ones.

    Args:
        db_sess: Database session used to look up the scaling group.
        sgroup_name: Name of the scaling group whose scheduler is loaded.

    Returns:
        The object returned by ``load_scheduler()`` for this scaling group.
    """
    scheduler_name, sgroup_opts = await self._get_scaling_group_data(db_sess, sgroup_name)

    # Fall back to an empty mapping when no scheduler plugin section is set.
    plugin_section = self.shared_config["plugins"]["scheduler"]
    base_opts = plugin_section.get(scheduler_name, {}) if plugin_section else {}

    # Later entries win: scaling-group config overrides the global options.
    merged_config = {**base_opts, **sgroup_opts.config}
    return load_scheduler(scheduler_name, sgroup_opts, merged_config)

async def _load_agent_selector(
self,
db_sess: SASession,
sgroup_name: str,
pending_session: SessionRow, # TODO: id and session_type?
) -> AbstractAgentSelector:
session_type = pending_session.session_type

_scheduler_name, sgroup_opts = await self._get_scaling_group_data(db_sess, sgroup_name)

match sgroup_opts.agent_selection_strategy:
# The names correspond to the entrypoint names (backendai_agentselector_v10).
case AgentSelectionStrategy.LEGACY:
agselector_name = "legacy"
case AgentSelectionStrategy.ROUNDROBIN:
agselector_name = "roundrobin"
case AgentSelectionStrategy.CONCENTRATED:
agselector_name = "concentrated"
# TODO: If there are no services with the same model, it should operate as "concentrated".
if session_type == SessionTypes.INFERENCE:
# TODO: Roundrobin?
agselector_name = "dispersed"
else:
agselector_name = "concentrated"
case AgentSelectionStrategy.DISPERSED:
agselector_name = "dispersed"
case _ as unknown:
raise ValueError(
f"Unknown agent selection strategy: {unknown!r}. Possible values: {[*AgentSelectionStrategy.__members__.keys()]}"
)

global_scheduler_opts = {}
global_agselector_opts = {}
if self.shared_config["plugins"]["scheduler"]:
global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get(
scheduler_name, {}
)
scheduler_config = {**global_scheduler_opts, **sgroup_opts.config}
if self.shared_config["plugins"]["agent-selector"]:
global_agselector_opts = self.shared_config["plugins"]["agent-selector"].get(
agselector_name, {}
Expand All @@ -446,19 +467,13 @@ async def _load_scheduler(
"agent-selection-resource-priority"
]

scheduler = load_scheduler(
scheduler_name,
sgroup_opts,
scheduler_config,
)
agent_selector = load_agent_selector(
return load_agent_selector(
agselector_name,
sgroup_opts,
agselector_config,
agent_selection_resource_priority,
self.shared_config,
)
return scheduler, agent_selector

async def _schedule_in_sgroup(
self,
Expand All @@ -468,7 +483,8 @@ async def _schedule_in_sgroup(
# Part 0: Load the scheduler and the agent selector.

async with self.db.begin_readonly_session() as db_sess:
scheduler, agent_selector = await self._load_scheduler(db_sess, sgroup_name)
# Load the scheduler -> list the pending sessions -> load the agent selector (a different agent selector is loaded depending on the pending session's type)
scheduler = await self._load_scheduler(db_sess, sgroup_name)
existing_sessions, pending_sessions, cancelled_sessions = await _list_managed_sessions(
db_sess, sgroup_name, scheduler.sgroup_opts.pending_timeout
)
Expand Down Expand Up @@ -507,6 +523,10 @@ async def _schedule_in_sgroup(
raise RuntimeError("should not reach here")
pending_sess = pending_sessions.pop(picked_idx)
log_fmt = "schedule(s:{}, prio:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): "

async with self.db.begin_readonly_session() as db_sess:
agent_selector = await self._load_agent_selector(db_sess, sgroup_name, pending_sess)

log_args = (
pending_sess.id,
pending_sess.priority,
Expand Down
Loading