feat(BA-442): Distribute model service replica containers whenever possible #3693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions changes/3693.feature.md
@@ -0,0 +1 @@
Add `enforce_spreading_endpoint_replica` scheduling option to the `ConcentratedAgentSelector`, which prioritizes spreading an endpoint's replicas across agents (availability) over the amount of available resource slots when selecting an agent for inference sessions.
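
For illustration, a minimal usage sketch of the new flag (hedged: it assumes the remaining `ScalingGroupOpts` fields keep their defaults, which this diff does not show):

from ai.backend.manager.models.scaling_group import ScalingGroupOpts

# Hedged sketch: only the new flag is set explicitly; all other fields are assumed to have defaults.
opts = ScalingGroupOpts(enforce_spreading_endpoint_replica=True)
assert opts.to_json()["enforce_spreading_endpoint_replica"] is True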
5 changes: 5 additions & 0 deletions src/ai/backend/manager/models/scaling_group.py
@@ -107,13 +107,17 @@ class ScalingGroupOpts(JSONSerializableMixin):
agent_selection_strategy: AgentSelectionStrategy = AgentSelectionStrategy.DISPERSED
agent_selector_config: Mapping[str, Any] = attr.field(factory=dict)

# Only used in the ConcentratedAgentSelector
enforce_spreading_endpoint_replica: bool = False

def to_json(self) -> dict[str, Any]:
return {
"allowed_session_types": [item.value for item in self.allowed_session_types],
"pending_timeout": self.pending_timeout.total_seconds(),
"config": self.config,
"agent_selection_strategy": self.agent_selection_strategy,
"agent_selector_config": self.agent_selector_config,
"enforce_spreading_endpoint_replica": self.enforce_spreading_endpoint_replica,
}

@classmethod
@@ -133,6 +137,7 @@ def as_trafaret(cls) -> t.Trafaret:
AgentSelectionStrategy
),
t.Key("agent_selector_config", default={}): agent_selector_config_iv,
t.Key("enforce_spreading_endpoint_replica", default=False): t.ToBool,
}).allow_extra("*")


26 changes: 23 additions & 3 deletions src/ai/backend/manager/scheduler/agent_selector.py
@@ -12,6 +12,7 @@
ArchName,
ResourceSlot,
)
from ai.backend.logging import BraceStyleAdapter

from ..models import AgentRow, KernelRow, SessionRow
from .types import (
@@ -26,7 +27,7 @@
sort_requested_slots_by_priority,
)

log = logging.Logger(__spec__.name)
log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore


def get_num_extras(agent: AgentRow, requested_slots: ResourceSlot) -> int:
@@ -161,6 +162,15 @@ async def select_agent(


class ConcentratedAgentSelector(BaseAgentSelector[NullAgentSelectorState]):
@property
@override
def config_iv(self) -> t.Dict:
return t.Dict({
# Only used when the "enforce_spreading_endpoint_replica" flag is True.
t.Key("kernel_counts_at_same_endpoint", optional=True, default=None): t.Null
| t.Mapping(t.String, t.Int()),
}).allow_extra("*")

@override
@classmethod
def get_state_cls(cls) -> type[NullAgentSelectorState]:
@@ -179,16 +189,26 @@ async def select_agent(
resource_priorities = sort_requested_slots_by_priority(
requested_slots, self.agent_selection_resource_priority
)

# When not using enforce_spreading_endpoint_replica, treat all agent kernel counts as 0.
kernel_counts_at_same_endpoint = (
self.config.get("kernel_counts_at_same_endpoint", {})
if self.sgroup_opts.enforce_spreading_endpoint_replica
else {}
)

chosen_agent = min(
agents,
key=lambda agent: [
key=lambda agent: (
kernel_counts_at_same_endpoint.get(agent.id, 0),
get_num_extras(agent, requested_slots),
*[
(agent.available_slots - agent.occupied_slots).get(key, sys.maxsize)
for key in resource_priorities
],
],
),
)

return chosen_agent.id


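To make the new ordering concrete, here is a standalone sketch (not the real selector; the agents, slot names, and numbers are made up) of how the comparison key ranks agents when `enforce_spreading_endpoint_replica` is active: the count of kernels belonging to the same endpoint dominates, then the number of extra (unrequested) slot types, then the remaining capacity of the prioritized resources:

# Standalone illustration with invented numbers, mirroring the key used in
# ConcentratedAgentSelector.select_agent() when the flag is enabled.
agents = {
    "agent-a": {"endpoint_kernels": 2, "num_extras": 0, "free_cpu": 8},
    "agent-b": {"endpoint_kernels": 0, "num_extras": 1, "free_cpu": 64},
    "agent-c": {"endpoint_kernels": 0, "num_extras": 0, "free_cpu": 32},
}

chosen = min(
    agents,
    key=lambda name: (
        agents[name]["endpoint_kernels"],  # spread same-endpoint replicas first
        agents[name]["num_extras"],        # then prefer agents with fewer unrequested slot types
        agents[name]["free_cpu"],          # then concentrate onto the agent with the least free capacity
    ),
)
assert chosen == "agent-c"  # no same-endpoint kernels, no extras, least free CPU among the remaining ties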
145 changes: 116 additions & 29 deletions src/ai/backend/manager/scheduler/dispatcher.py
@@ -4,6 +4,7 @@
import itertools
import json
import logging
import uuid
from collections import defaultdict
from collections.abc import (
Awaitable,
@@ -12,12 +13,14 @@
Sequence,
)
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta, timezone
from decimal import Decimal
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Optional,
Union,
cast,
@@ -69,9 +72,11 @@
RedisConnectionInfo,
ResourceSlot,
SessionId,
SessionTypes,
aobject,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.kernel import USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
from ai.backend.manager.models.session import _build_session_fetch_query
from ai.backend.manager.types import DistributedLockFactory
from ai.backend.plugin.entrypoint import scan_entrypoints
@@ -194,6 +199,59 @@ def create_agent_selector(
raise ImportError("Cannot load the agent-selector plugin", name)


async def get_kernel_count_per_agent_at_endpoint(
db_sess: SASession,
endpoint_id: uuid.UUID,
filter_by_statuses: Iterable[KernelStatus],
) -> dict[AgentId, int]:
"""
Query the agents to which the kernels of each session belong,
and calculate the number of kernels for each agent at a specific endpoint.
"""

routing_rows: list[RoutingRow] = (
await db_sess.scalars(
sa.select(RoutingRow)
.options(selectinload(RoutingRow.session_row).options(selectinload(SessionRow.kernels)))
.where(
RoutingRow.endpoint == endpoint_id,
)
)
).all()

kernel_count_per_agent: dict[AgentId, int] = {}

for routing_row in routing_rows:
session_row: SessionRow = routing_row.session_row
kernels: list[KernelRow] = session_row.kernels

for kernel in kernels:
if kernel.status in filter_by_statuses:
if agent_id := kernel.agent:
kernel_count_per_agent[agent_id] = kernel_count_per_agent.get(agent_id, 0) + 1

log.debug(
'kernel counts at endpoint {0}: "{1}"',
endpoint_id,
repr(kernel_count_per_agent),
)

return kernel_count_per_agent


@dataclass
class LoadSchedulerArgs:
scheduler_name: str
sgroup_opts: ScalingGroupOpts


@dataclass
class LoadAgentSelectorArgs:
sgroup_opts: ScalingGroupOpts
pending_session_id: uuid.UUID
pending_session_type: SessionTypes


class SchedulerDispatcher(aobject):
config: LocalConfig
shared_config: SharedConfig
@@ -408,25 +466,45 @@ def _pipeline(r: Redis) -> RedisPipeline:
raise asyncio.CancelledError()
raise

async def _load_scheduler(
self,
db_sess: SASession,
sgroup_name: str,
) -> tuple[AbstractScheduler, AbstractAgentSelector]:
query = sa.select(ScalingGroupRow.scheduler, ScalingGroupRow.scheduler_opts).where(
ScalingGroupRow.name == sgroup_name
)
result = await db_sess.execute(query)
row = result.first()
scheduler_name = row.scheduler
sgroup_opts: ScalingGroupOpts = row.scheduler_opts
def _load_scheduler(self, args: LoadSchedulerArgs) -> AbstractScheduler:
global_scheduler_opts = {}
if self.shared_config["plugins"]["scheduler"]:
global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get(
args.scheduler_name, {}
)
scheduler_config = {**global_scheduler_opts, **args.sgroup_opts.config}

return load_scheduler(args.scheduler_name, args.sgroup_opts, scheduler_config)

async def _load_agent_selector(self, args: LoadAgentSelectorArgs) -> AbstractAgentSelector:
sgroup_opts = args.sgroup_opts

# TODO: Remove "dynamic_config" after refactoring.
dynamic_config: dict[str, Any] = {}

match sgroup_opts.agent_selection_strategy:
# The names correspond to the entrypoint names (backendai_agentselector_v10).
case AgentSelectionStrategy.LEGACY:
agselector_name = "legacy"
case AgentSelectionStrategy.ROUNDROBIN:
agselector_name = "roundrobin"
case AgentSelectionStrategy.CONCENTRATED:
async with self.db.begin_readonly_session() as db_sess:
if (
sgroup_opts.enforce_spreading_endpoint_replica
and args.pending_session_type == SessionTypes.INFERENCE
):
endpoint_id = await db_sess.scalar(
sa.select(RoutingRow.endpoint).where(
RoutingRow.session == args.pending_session_id
)
)

dynamic_config[
"kernel_counts_at_same_endpoint"
] = await get_kernel_count_per_agent_at_endpoint(
db_sess, endpoint_id, USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
)

agselector_name = "concentrated"
case AgentSelectionStrategy.DISPERSED:
agselector_name = "dispersed"
@@ -435,45 +513,46 @@ async def _load_scheduler(
f"Unknown agent selection strategy: {unknown!r}. Possible values: {[*AgentSelectionStrategy.__members__.keys()]}"
)

global_scheduler_opts = {}
global_agselector_opts = {}
if self.shared_config["plugins"]["scheduler"]:
global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get(
scheduler_name, {}
)
scheduler_config = {**global_scheduler_opts, **sgroup_opts.config}
if self.shared_config["plugins"]["agent-selector"]:
global_agselector_opts = self.shared_config["plugins"]["agent-selector"].get(
agselector_name, {}
)
agselector_config = {**global_agselector_opts, **sgroup_opts.agent_selector_config}
agselector_config = {
**global_agselector_opts,
**sgroup_opts.agent_selector_config,
**dynamic_config,
}

agent_selection_resource_priority = self.local_config["manager"][
"agent-selection-resource-priority"
]

scheduler = load_scheduler(
scheduler_name,
sgroup_opts,
scheduler_config,
)
agent_selector = load_agent_selector(
return load_agent_selector(
agselector_name,
sgroup_opts,
agselector_config,
agent_selection_resource_priority,
self.shared_config,
)
return scheduler, agent_selector

async def _schedule_in_sgroup(
self,
sched_ctx: SchedulingContext,
sgroup_name: str,
) -> None:
# Part 0: Load the scheduler; the agent selector is loaded per pending session below.

async with self.db.begin_readonly_session() as db_sess:
scheduler, agent_selector = await self._load_scheduler(db_sess, sgroup_name)
result = await db_sess.execute(
sa.select(ScalingGroupRow.scheduler, ScalingGroupRow.scheduler_opts).where(
ScalingGroupRow.name == sgroup_name
)
)
row = result.first()
if row is None:
raise ValueError(f'Scaling group "{sgroup_name}" not found!')
scheduler_name, sgroup_opts = row.scheduler, row.scheduler_opts
scheduler = self._load_scheduler(LoadSchedulerArgs(scheduler_name, sgroup_opts))
existing_sessions, pending_sessions, cancelled_sessions = await _list_managed_sessions(
db_sess, sgroup_name, scheduler.sgroup_opts.pending_timeout
)
@@ -512,6 +591,14 @@ async def _schedule_in_sgroup(
raise RuntimeError("should not reach here")
pending_sess = pending_sessions.pop(picked_idx)
log_fmt = "schedule(s:{}, prio:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): "
agent_selector = await self._load_agent_selector(
LoadAgentSelectorArgs(
sgroup_opts,
pending_sess.id,
pending_sess.session_type,
)
)

log_args = (
pending_sess.id,
pending_sess.priority,
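For reference, a hedged sketch of the dynamic config injected for an inference session when the flag is enabled (the agent IDs and counts are invented); as in `_load_agent_selector`, it is merged last so it overrides any statically configured value of the same key:

# Hypothetical values; the real mapping comes from get_kernel_count_per_agent_at_endpoint().
global_agselector_opts = {}        # from shared_config["plugins"]["agent-selector"]
static_selector_config = {}        # from sgroup_opts.agent_selector_config
dynamic_config = {
    "kernel_counts_at_same_endpoint": {
        "i-agent-1": 3,  # three kernels of this endpoint already run on this agent
        "i-agent-2": 1,
    },
}
# Same merge order as in _load_agent_selector(): the dynamic part is merged last and wins.
agselector_config = {
    **global_agselector_opts,
    **static_selector_config,
    **dynamic_config,
}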
4 changes: 3 additions & 1 deletion src/ai/backend/manager/scheduler/types.py
@@ -188,7 +188,9 @@ class AbstractAgentSelector(Generic[T_ResourceGroupState], ABC):
"""

sgroup_opts: ScalingGroupOpts # sgroup-specific config
config: Mapping[str, Any] # agent-selector-specific config
config: Mapping[
str, Any
]  # agent-selector-specific config. Do not use this; it will be removed after refactoring.
agent_selection_resource_priority: list[str]
state_store: AbstractResourceGroupStateStore[T_ResourceGroupState]

3 changes: 2 additions & 1 deletion tests/manager/scheduler_utils.py
@@ -152,6 +152,7 @@ def create_mock_session(
status_data: dict[str, Any] | None = None,
kernel_opts: Sequence[KernelOpt] | None = None,
priority: int = SESSION_PRIORITY_DEFAULT,
session_type: SessionTypes = SessionTypes.BATCH,
) -> SessionRow:
"""Create a simple single-kernel pending session."""
if kernel_opts is None:
@@ -175,7 +176,7 @@
id=session_id,
creation_id=secrets.token_hex(8),
name=f"session-{secrets.token_hex(4)}",
session_type=SessionTypes.BATCH,
session_type=session_type,
status=status,
status_data=status_data,
cluster_mode="single-node",