diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py
index 6a2c6414eab..4757d0e8040 100644
--- a/src/ai/backend/manager/api/admin.py
+++ b/src/ai/backend/manager/api/admin.py
@@ -81,6 +81,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
         redis_stat=root_ctx.h.redis_stat,
         redis_image=root_ctx.h.redis_image,
         redis_live=root_ctx.h.redis_live,
+        concurrency_tracker=root_ctx.g.concurrency_tracker,
         manager_status=manager_status,
         known_slot_types=known_slot_types,
         background_task_manager=root_ctx.g.background_task_manager,
diff --git a/src/ai/backend/manager/api/context.py b/src/ai/backend/manager/api/context.py
index 244cccd16a9..86c5c24ab1a 100644
--- a/src/ai/backend/manager/api/context.py
+++ b/src/ai/backend/manager/api/context.py
@@ -13,6 +13,7 @@
     from ..config import LocalConfig, SharedConfig
     from ..idle import IdleCheckerHost
+    from ..models.resource_policy import ConcurrencyTracker
     from ..models.storage import StorageSessionManager
     from ..models.utils import ExtendedAsyncSAEngine
     from ..plugin.webapp import WebappPluginContext
@@ -51,6 +52,7 @@ class GlobalObjectContext:
     idle_checker_host: IdleCheckerHost
     storage_manager: StorageSessionManager
     background_task_manager: BackgroundTaskManager
+    concurrency_tracker: ConcurrencyTracker
     webapp_plugin_ctx: WebappPluginContext
     hook_plugin_ctx: HookPluginContext
     error_monitor: ErrorPluginContext
diff --git a/src/ai/backend/manager/api/resource.py b/src/ai/backend/manager/api/resource.py
index 0c0c97a5cd1..5efea579cd6 100644
--- a/src/ai/backend/manager/api/resource.py
+++ b/src/ai/backend/manager/api/resource.py
@@ -307,7 +307,7 @@ async def recalculate_usage(request: web.Request) -> web.Response:
     """
     log.info("RECALCULATE_USAGE ()")
     root_ctx: RootContext = request.app["_root.context"]
-    await root_ctx.registry.recalc_resource_usage()
+    await root_ctx.registry.recalc_resource_usage_by_fullscan()
     return web.json_response({}, status=200)
 
 
diff --git a/src/ai/backend/manager/models/agent.py b/src/ai/backend/manager/models/agent.py
index 905d6cc0cd9..ee5805cb131 100644
--- a/src/ai/backend/manager/models/agent.py
+++ b/src/ai/backend/manager/models/agent.py
@@ -121,6 +121,7 @@ async def recalc_agent_resource_occupancy(db_session: SASession, agent_id: Agent
             & (KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
         )
         .options(load_only(KernelRow.occupied_slots))
+        .with_for_update()
     )
     kernel_rows = cast(list[KernelRow], (await db_session.scalars(_stmt)).all())
     occupied_slots = ResourceSlot()
@@ -145,6 +146,7 @@ async def recalc_agent_resource_occupancy_using_orm(
                 KernelRow, KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
             ),
         )
+        .with_for_update()
     )
     occupied_slots = ResourceSlot()
     agent_row = cast(AgentRow, await db_session.scalar(agent_query))
diff --git a/src/ai/backend/manager/models/exceptions.py b/src/ai/backend/manager/models/exceptions.py
new file mode 100644
index 00000000000..009a0ea47c1
--- /dev/null
+++ b/src/ai/backend/manager/models/exceptions.py
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+
+class ResourceError(Exception):
+    """Base class for resource-accounting errors raised from the models layer."""
+
+
+class ResourceLimitExceeded(ResourceError):
+    """Raised when admitting a session would exceed the keypair's concurrency limit."""
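Review note: the new exception hierarchy keeps resource-accounting failures independent of the web layer. Below is a minimal sketch of the intended consumption pattern; the `tracker` object and the names `access_key`, `session_id`, and `limit` are illustrative placeholders, not part of this patch.

```python
# Sketch only: `tracker` is assumed to expose the coroutine methods added in
# resource_policy.py later in this diff.
from ai.backend.manager.models.exceptions import ResourceError, ResourceLimitExceeded


async def try_admit(tracker, access_key, session_id, limit: int) -> bool:
    try:
        # Raises ResourceLimitExceeded when the keypair is at its concurrency limit.
        await tracker.add_compute_sessions(access_key, [session_id], limit=limit)
    except ResourceLimitExceeded:
        return False  # expected control flow, not an error
    except ResourceError:
        raise  # any other resource-accounting failure should propagate
    return True
```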
diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py
index 0c728741496..8cba39ca7de 100644
--- a/src/ai/backend/manager/models/gql.py
+++ b/src/ai/backend/manager/models/gql.py
@@ -133,6 +133,7 @@
 from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
 from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
 from .resource_policy import (
+    ConcurrencyTracker,
     CreateKeyPairResourcePolicy,
     CreateProjectResourcePolicy,
     CreateUserResourcePolicy,
@@ -210,6 +211,7 @@ class GraphQueryContext:
     redis_stat: RedisConnectionInfo
     redis_live: RedisConnectionInfo
     redis_image: RedisConnectionInfo
+    concurrency_tracker: ConcurrencyTracker
     manager_status: ManagerStatus
     known_slot_types: Mapping[SlotName, SlotTypes]
     background_task_manager: BackgroundTaskManager
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py
index e54f4827fd2..423072d4e98 100644
--- a/src/ai/backend/manager/models/kernel.py
+++ b/src/ai/backend/manager/models/kernel.py
@@ -39,7 +39,6 @@
     BinarySize,
     ClusterMode,
     KernelId,
-    RedisConnectionInfo,
     ResourceSlot,
     SessionId,
     SessionResult,
@@ -105,7 +104,6 @@
     "RESOURCE_USAGE_KERNEL_STATUSES",
     "DEAD_KERNEL_STATUSES",
     "LIVE_STATUS",
-    "recalc_concurrency_used",
 )
 
 log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.models.kernel"))
@@ -1560,51 +1558,3 @@ class Meta:
         interfaces = (PaginatedList,)
 
     items = graphene.List(LegacyComputeSession, required=True)
-
-
-async def recalc_concurrency_used(
-    db_sess: SASession,
-    redis_stat: RedisConnectionInfo,
-    access_key: AccessKey,
-) -> None:
-    concurrency_used: int
-    from .session import PRIVATE_SESSION_TYPES
-
-    async with db_sess.begin_nested():
-        result = await db_sess.execute(
-            sa.select(sa.func.count())
-            .select_from(KernelRow)
-            .where(
-                (KernelRow.access_key == access_key)
-                & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
-                & (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
-            ),
-        )
-        concurrency_used = result.scalar()
-        result = await db_sess.execute(
-            sa.select(sa.func.count())
-            .select_from(KernelRow)
-            .where(
-                (KernelRow.access_key == access_key)
-                & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
-                & (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
-            ),
-        )
-        sftp_concurrency_used = result.scalar()
-        assert isinstance(concurrency_used, int)
-        assert isinstance(sftp_concurrency_used, int)
-
-        await redis_helper.execute(
-            redis_stat,
-            lambda r: r.set(
-                f"keypair.concurrency_used.{access_key}",
-                concurrency_used,
-            ),
-        )
-        await redis_helper.execute(
-            redis_stat,
-            lambda r: r.set(
-                f"keypair.sftp_concurrency_used.{access_key}",
-                sftp_concurrency_used,
-            ),
-        )
diff --git a/src/ai/backend/manager/models/keypair.py b/src/ai/backend/manager/models/keypair.py
index af7044e631b..f9083d961b6 100644
--- a/src/ai/backend/manager/models/keypair.py
+++ b/src/ai/backend/manager/models/keypair.py
@@ -260,14 +260,7 @@ async def resolve_compute_sessions(
 
     async def resolve_concurrency_used(self, info: graphene.ResolveInfo) -> int:
         ctx: GraphQueryContext = info.context
-        kp_key = "keypair.concurrency_used"
-        concurrency_used = await redis_helper.execute(
-            ctx.redis_stat,
-            lambda r: r.get(f"{kp_key}.{self.access_key}"),
-        )
-        if concurrency_used is not None:
-            return int(concurrency_used)
-        return 0
+        return await ctx.concurrency_tracker.count_compute_sessions(AccessKey(self.access_key))
 
     async def resolve_last_used(self, info: graphene.ResolveInfo) -> datetime | None:
         ctx: GraphQueryContext = info.context
@@ -654,10 +647,7 @@ async def mutate(
         delete_query = sa.delete(keypairs).where(keypairs.c.access_key == access_key)
         result = await simple_db_mutate(cls, ctx, delete_query)
         if result.ok:
-            await redis_helper.execute(
-                ctx.redis_stat,
-                lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
-            )
+            await ctx.concurrency_tracker.clear(access_key)
         return result
diff --git a/src/ai/backend/manager/models/resource_policy.py b/src/ai/backend/manager/models/resource_policy.py
index 95bd4d40751..f13adb8a40a 100644
--- a/src/ai/backend/manager/models/resource_policy.py
+++ b/src/ai/backend/manager/models/resource_policy.py
@@ -2,16 +2,34 @@
 
 import logging
 import uuid
-from typing import TYPE_CHECKING, Any, Dict, Sequence
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, field
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Final,
+    TypeVar,
+)
 
 import graphene
 import sqlalchemy as sa
 from graphene.types.datetime import DateTime as GQLDateTime
+from redis.asyncio import Redis
+from redis.asyncio.client import Pipeline
 from sqlalchemy.engine.row import Row
+from sqlalchemy.ext.asyncio import AsyncSession as SASession
 from sqlalchemy.orm import relationship, selectinload
 
-from ai.backend.common.types import DefaultForUnspecified, ResourceSlot
+from ai.backend.common import redis_helper
+from ai.backend.common.types import (
+    AccessKey,
+    DefaultForUnspecified,
+    RedisConnectionInfo,
+    ResourceSlot,
+    SessionId,
+)
 from ai.backend.logging import BraceStyleAdapter
+from ai.backend.manager.models.exceptions import ResourceLimitExceeded
 from ai.backend.manager.models.utils import execute_with_retry
 
 from .base import (
@@ -406,7 +424,7 @@ async def mutate(
         name: str,
         props: ModifyKeyPairResourcePolicyInput,
     ) -> ModifyKeyPairResourcePolicy:
-        data: Dict[str, Any] = {}
+        data: dict[str, Any] = {}
         set_if_set(
             props,
             data,
@@ -642,7 +660,7 @@ async def mutate(
         name: str,
         props: ModifyUserResourcePolicyInput,
     ) -> ModifyUserResourcePolicy:
-        data: Dict[str, Any] = {}
+        data: dict[str, Any] = {}
         set_if_set(props, data, "max_vfolder_count")
         set_if_set(props, data, "max_quota_scope_size")
         set_if_set(props, data, "max_session_count_per_model_session")
@@ -838,7 +856,7 @@ async def mutate(
         name: str,
         props: ModifyProjectResourcePolicyInput,
     ) -> ModifyProjectResourcePolicy:
-        data: Dict[str, Any] = {}
+        data: dict[str, Any] = {}
         set_if_set(props, data, "max_vfolder_count")
         set_if_set(props, data, "max_quota_scope_size")
         update_query = (
@@ -869,3 +887,257 @@ async def mutate(
             ProjectResourcePolicyRow.name == name
         )
         return await simple_db_mutate(cls, info.context, delete_query)
+
+
+COMPUTE_CONCURRENCY_USED_KEY_PREFIX: Final = "keypair.compute_concurrency_used.set."
+SYSTEM_CONCURRENCY_USED_KEY_PREFIX: Final = "keypair.system_concurrency_used.set."  # incl. sftp
+
+# TODO: accept multiple session IDs at once
+_check_keypair_concurrency_script = """
+local key = KEYS[1]
+local limit = tonumber(ARGV[1])
+local session_id = ARGV[2]
+local result = {}
+redis.call('ZADD', key, 1, session_id)
+local count = tonumber(redis.call('ZCARD', key))
+if limit > 0 and count > limit then
+    -- roll back the tentative addition when the limit would be exceeded
+    redis.call('ZREM', key, session_id)
+    result[1] = 0
+    result[2] = count - 1
+    return result
+end
+result[1] = 1
+result[2] = count
+return result
+"""
+
+T_IntOrNone = TypeVar("T_IntOrNone", bound=int | None)
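Review note: the check-and-add has to run as a single server-side Lua script because the client-side equivalent is racy. The sketch below (assuming a plain `redis.asyncio` client; it is not part of this patch) shows the two-round-trip version the script avoids: two schedulers could both observe a count below the limit and both add their session.

```python
from redis.asyncio import Redis


async def racy_check_and_add(r: Redis, key: str, session_id: str, limit: int) -> bool:
    # NOT safe: another scheduler may ZADD between the ZCARD and the ZADD below,
    # which is exactly the race the server-side script eliminates.
    count = await r.zcard(key)
    if limit > 0 and count >= limit:
        return False
    await r.zadd(key, {session_id: 1})
    return True
```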
+
+
+class ConcurrencyTracker:
+    """Tracks per-keypair concurrency by storing active session IDs in Redis sets."""
+
+    def __init__(self, redis_stat: RedisConnectionInfo) -> None:
+        self.redis_stat = redis_stat
+
+    async def clear(self, access_key: AccessKey) -> None:
+        await redis_helper.execute(
+            self.redis_stat,
+            lambda r: r.delete(
+                f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+                f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+            ),
+        )
+
+    async def add_compute_sessions(
+        self,
+        access_key: AccessKey,
+        session_ids: list[SessionId],
+        *,
+        limit: T_IntOrNone = None,
+    ) -> T_IntOrNone:
+        if limit is not None:
+            ok, count = await redis_helper.execute_script(
+                self.redis_stat,
+                "check_keypair_concurrency_used",
+                _check_keypair_concurrency_script,
+                [f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{access_key}"],
+                [limit, str(session_ids[0])],
+            )
+            if ok == 0:
+                raise ResourceLimitExceeded(
+                    f"The maximum limit of concurrent compute sessions ({limit}) has been exceeded."
+                )
+            return count
+        else:
+            await redis_helper.execute(
+                self.redis_stat,
+                lambda r: r.zadd(
+                    f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+                    {str(session_id): 1 for session_id in session_ids},
+                ),
+            )
+            return None
+
+    async def add_system_sessions(
+        self,
+        access_key: AccessKey,
+        session_ids: list[SessionId],
+        *,
+        limit: T_IntOrNone = None,
+    ) -> T_IntOrNone:
+        if limit is not None:
+            ok, count = await redis_helper.execute_script(
+                self.redis_stat,
+                "check_keypair_concurrency_used",
+                _check_keypair_concurrency_script,
+                [f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{access_key}"],
+                [limit, str(session_ids[0])],
+            )
+            if ok == 0:
+                raise ResourceLimitExceeded(
+                    f"The maximum limit of concurrent system sessions ({limit}) has been exceeded."
+                )
+            return count
+        else:
+            await redis_helper.execute(
+                self.redis_stat,
+                lambda r: r.zadd(
+                    f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+                    {str(session_id): 1 for session_id in session_ids},
+                ),
+            )
+            return None
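Review note: for reference, this is the intended lifecycle of the tracker as wired by the rest of this patch. The names `redis_stat`, `ak`, and `sid` are placeholders for illustration.

```python
from ai.backend.manager.models.resource_policy import ConcurrencyTracker


async def lifecycle_demo(redis_stat, ak, sid) -> None:
    tracker = ConcurrencyTracker(redis_stat)
    # Scheduler admission (predicates.py): raises ResourceLimitExceeded when full.
    await tracker.add_compute_sessions(ak, [sid], limit=10)
    # GraphQL resolver (keypair.py): the count is now just the set cardinality.
    assert await tracker.count_compute_sessions(ak) >= 1
    # Session destruction (registry.py): removal by ID is idempotent, so repeated
    # decrements can no longer drive the value negative as `INCRBY -1` could.
    await tracker.remove_compute_sessions(ak, [sid])
    await tracker.remove_compute_sessions(ak, [sid])  # no-op, still consistent
```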
+
+    async def count_compute_sessions(self, access_key: AccessKey) -> int:
+        return await redis_helper.execute(
+            self.redis_stat,
+            lambda r: r.zcard(f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{access_key}"),
+        )
+
+    async def count_system_sessions(self, access_key: AccessKey) -> int:
+        return await redis_helper.execute(
+            self.redis_stat,
+            lambda r: r.zcard(f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{access_key}"),
+        )
+
+    async def remove_compute_sessions(
+        self, access_key: AccessKey, session_ids: list[SessionId]
+    ) -> None:
+        await redis_helper.execute(
+            self.redis_stat,
+            lambda r: r.zrem(
+                f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+                *[str(session_id) for session_id in session_ids],
+            ),
+        )
+
+    async def remove_system_sessions(
+        self, access_key: AccessKey, session_ids: list[SessionId]
+    ) -> None:
+        await redis_helper.execute(
+            self.redis_stat,
+            lambda r: r.zrem(
+                f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{access_key}",
+                *[str(session_id) for session_id in session_ids],
+            ),
+        )
+
+    async def recalc_concurrency_used(
+        self,
+        db_sess: SASession,
+        access_key: AccessKey,
+    ) -> None:
+        from .kernel import USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
+        from .session import (
+            PRIVATE_SESSION_TYPES,
+            SessionRow,
+        )
+
+        session_query = sa.select(SessionRow).where(
+            (SessionRow.access_key == access_key)
+            & (SessionRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+            & (SessionRow.session_type.not_in(PRIVATE_SESSION_TYPES))
+        )
+        active_compute_sessions = (await db_sess.scalars(session_query)).all()
+
+        session_query = sa.select(SessionRow).where(
+            (SessionRow.access_key == access_key)
+            & (SessionRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+            & (SessionRow.session_type.in_(PRIVATE_SESSION_TYPES))
+        )
+        active_system_sessions = (await db_sess.scalars(session_query)).all()
+
+        # Drop stale entries first so that already-terminated sessions do not linger.
+        await self.clear(access_key)
+        if active_compute_sessions:
+            await self.add_compute_sessions(access_key, [s.id for s in active_compute_sessions])
+        if active_system_sessions:
+            await self.add_system_sessions(access_key, [s.id for s in active_system_sessions])
+
+    async def recalc_concurrency_used_fullscan(
+        self,
+        db_sess: SASession,
+    ) -> None:
+        from .kernel import USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
+        from .session import (
+            PRIVATE_SESSION_TYPES,
+            SessionRow,
+        )
+
+        session_query = (
+            sa.select(
+                SessionRow.access_key,
+                sa.func.array_agg(SessionRow.id).label("session_ids"),
+            )
+            .where(
+                (SessionRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+                & (SessionRow.session_type.not_in(PRIVATE_SESSION_TYPES))
+            )
+            .group_by(SessionRow.access_key)
+        )
+        active_compute_sessions = (await db_sess.execute(session_query)).all()
+
+        session_query = (
+            sa.select(
+                SessionRow.access_key,
+                sa.func.array_agg(SessionRow.id).label("session_ids"),
+            )
+            .where(
+                (SessionRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+                & (SessionRow.session_type.in_(PRIVATE_SESSION_TYPES))
+            )
+            .group_by(SessionRow.access_key)
+        )
+        active_system_sessions = (await db_sess.execute(session_query)).all()
+
+        async def _clear_stale_keys(r: Redis) -> None:
+            # Access keys that no longer have any active session would otherwise
+            # keep their previous (stale) sets forever.
+            async for key in r.scan_iter(f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}*"):
+                await r.delete(key)
+            async for key in r.scan_iter(f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}*"):
+                await r.delete(key)
+
+        async def _redis_pipe_for_active_compute_sessions(r: Redis) -> Pipeline:
+            pipe = r.pipeline()
+            for item in active_compute_sessions:
+                pipe.zadd(
+                    f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{item.access_key}",
+                    {str(session_id): 1 for session_id in item.session_ids},
+                )
+            return pipe
+
+        async def _redis_pipe_for_active_system_sessions(r: Redis) -> Pipeline:
+            pipe = r.pipeline()
+            for item in active_system_sessions:
+                pipe.zadd(
+                    f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{item.access_key}",
+                    {str(session_id): 1 for session_id in item.session_ids},
+                )
+            return pipe
+
+        await redis_helper.execute(self.redis_stat, _clear_stale_keys)
+        await redis_helper.execute(self.redis_stat, _redis_pipe_for_active_compute_sessions)
+        await redis_helper.execute(self.redis_stat, _redis_pipe_for_active_system_sessions)
+
+
+@dataclass
+class ConcurrencyUsed:
+    access_key: AccessKey
+    compute_session_ids: set[SessionId] = field(default_factory=set)
+    system_session_ids: set[SessionId] = field(default_factory=set)
+
+    @property
+    def compute_concurrency_used_key(self) -> str:
+        return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
+
+    @property
+    def system_concurrency_used_key(self) -> str:
+        return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
+
+    def to_cnt_map(self) -> Mapping[str, int]:
+        return {
+            self.compute_concurrency_used_key: len(self.compute_session_ids),
+            self.system_concurrency_used_key: len(self.system_session_ids),
+        }
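Review note: `ConcurrencyUsed` moves here from session.py, but after this refactor it appears to lose its last in-tree caller (registry.py drops its import later in this diff). If it is retained for compatibility, this is how its count map was consumed before; the Redis client `r` mentioned in the comment is illustrative.

```python
import uuid

from ai.backend.common.types import AccessKey, SessionId
from ai.backend.manager.models.resource_policy import ConcurrencyUsed

usage = ConcurrencyUsed(access_key=AccessKey("AKIAEXAMPLE"))
usage.compute_session_ids.add(SessionId(uuid.uuid4()))
# to_cnt_map() flattens the sets into {redis_key: count} pairs; the old
# recalc_resource_usage() fed this dict to a single MSET call, e.g.:
#     await r.mset(usage.to_cnt_map())
print(usage.to_cnt_map())
```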
diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py
index b8cf91be580..f4edbfb8882 100644
--- a/src/ai/backend/manager/models/session.py
+++ b/src/ai/backend/manager/models/session.py
@@ -6,7 +6,6 @@
 import textwrap
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import asynccontextmanager as actxmgr
-from dataclasses import dataclass, field
 from datetime import datetime
 from decimal import Decimal
 from typing import (
@@ -605,31 +604,6 @@ async def _match_sessions_by_name(
     return result.scalars().all()
 
 
-COMPUTE_CONCURRENCY_USED_KEY_PREFIX = "keypair.concurrency_used."
-SYSTEM_CONCURRENCY_USED_KEY_PREFIX = "keypair.sftp_concurrency_used."
-
-
-@dataclass
-class ConcurrencyUsed:
-    access_key: AccessKey
-    compute_session_ids: set[SessionId] = field(default_factory=set)
-    system_session_ids: set[SessionId] = field(default_factory=set)
-
-    @property
-    def compute_concurrency_used_key(self) -> str:
-        return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
-
-    @property
-    def system_concurrency_used_key(self) -> str:
-        return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
-
-    def to_cnt_map(self) -> Mapping[str, int]:
-        return {
-            self.compute_concurrency_used_key: len(self.compute_session_ids),
-            self.system_concurrency_used_key: len(self.system_session_ids),
-        }
-
-
 class SessionOp(enum.StrEnum):
     CREATE = "create_session"
     DESTROY = "destroy_session"
diff --git a/src/ai/backend/manager/models/user.py b/src/ai/backend/manager/models/user.py
index 2c7b207f52e..4c1ba3045d7 100644
--- a/src/ai/backend/manager/models/user.py
+++ b/src/ai/backend/manager/models/user.py
@@ -22,8 +22,7 @@
 from sqlalchemy.sql.expression import bindparam
 from sqlalchemy.types import VARCHAR, TypeDecorator
 
-from ai.backend.common import redis_helper
-from ai.backend.common.types import RedisConnectionInfo, VFolderID
+from ai.backend.common.types import VFolderID
 from ai.backend.logging import BraceStyleAdapter
 
 from ..api.exceptions import VFolderOperationFailed
@@ -48,6 +47,7 @@
 
 if TYPE_CHECKING:
     from .gql import GraphQueryContext
+    from .resource_policy import ConcurrencyTracker
     from .storage import StorageSessionManager
 
 log = BraceStyleAdapter(logging.getLogger(__spec__.name))
@@ -1001,7 +1001,7 @@ async def _pre_func(conn: SAConnection) -> None:
         await cls.delete_kernels(conn, user_uuid)
         await cls.delete_sessions(conn, user_uuid)
         await cls.delete_vfolders(graph_ctx.db, user_uuid, graph_ctx.storage_manager)
-        await cls.delete_keypairs(conn, graph_ctx.redis_stat, user_uuid)
+        await cls.delete_keypairs(conn, graph_ctx.concurrency_tracker, user_uuid)
 
         delete_query = sa.delete(users).where(users.c.email == email)
         return await simple_db_mutate(cls, graph_ctx, delete_query, pre_func=_pre_func)
@@ -1292,7 +1292,7 @@ async def delete_sessions(
     async def delete_keypairs(
         cls,
         conn: SAConnection,
-        redis_conn: RedisConnectionInfo,
+        concurrency_tracker: ConcurrencyTracker,
         user_uuid: UUID,
     ) -> int:
         """
@@ -1303,21 +1303,14 @@
         :param user_uuid: user's UUID to delete keypairs
         :return: number of deleted rows
         """
-        from . import keypairs
+        from .keypair import keypairs
 
         ak_rows = await conn.execute(
             sa.select([keypairs.c.access_key]).where(keypairs.c.user == user_uuid),
         )
         if (row := ak_rows.first()) and (access_key := row.access_key):
             # Clear concurrency used only when there is at least one keypair.
-            await redis_helper.execute(
-                redis_conn,
-                lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
-            )
-            await redis_helper.execute(
-                redis_conn,
-                lambda r: r.delete(f"keypair.sftp_concurrency_used.{access_key}"),
-            )
+            await concurrency_tracker.clear(access_key)
         result = await conn.execute(
             sa.delete(keypairs).where(keypairs.c.user == user_uuid),
         )
diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py
index 68ee27ec41b..616af56eda8 100644
--- a/src/ai/backend/manager/registry.py
+++ b/src/ai/backend/manager/registry.py
@@ -155,7 +155,6 @@
     ALLOWED_IMAGE_ROLES_FOR_SESSION_TYPE,
     PRIVATE_SESSION_TYPES,
     USER_RESOURCE_OCCUPYING_KERNEL_STATUSES,
-    USER_RESOURCE_OCCUPYING_SESSION_STATUSES,
     AgentRow,
     AgentStatus,
     EndpointLifecycle,
@@ -182,18 +181,14 @@
     query_allowed_sgroups,
     query_bootstrap_script,
     recalc_agent_resource_occupancy,
-    recalc_concurrency_used,
     scaling_groups,
     verify_vfolder_name,
 )
 from .models.container_registry import ContainerRegistryRow
 from .models.image import bulk_get_image_configs
 from .models.session import (
-    COMPUTE_CONCURRENCY_USED_KEY_PREFIX,
     SESSION_KERNEL_STATUS_MAPPING,
     SESSION_PRIORITY_DEFUALT,
-    SYSTEM_CONCURRENCY_USED_KEY_PREFIX,
-    ConcurrencyUsed,
     SessionLifecycleManager,
 )
 from .models.utils import (
@@ -202,6 +197,7 @@
     is_db_retry_error,
     reenter_txn,
     reenter_txn_session,
+    retry_txn,
     sql_json_merge,
 )
 from .models.vfolder import VFolderOperationStatus, update_vfolder_status
@@ -260,6 +256,7 @@ def __init__(
         self.event_dispatcher = global_context.event_dispatcher
         self.event_producer = global_context.event_producer
         self.storage_manager = global_context.storage_manager
+        self.concurrency_tracker = global_context.concurrency_tracker
         self.hook_plugin_ctx = global_context.hook_plugin_ctx
         self._kernel_actual_allocated_resources = {}
         self.debug = debug
@@ -2062,148 +2059,118 @@ async def _update_agent_resource() -> None:
 
         await execute_with_retry(_update_agent_resource)
 
-    async def recalc_resource_usage(self, do_fullscan: bool = False) -> None:
-        async def _recalc() -> Mapping[AccessKey, ConcurrencyUsed]:
-            occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = defaultdict(
-                lambda: ResourceSlot({"cpu": 0, "mem": 0})
-            )
-            access_key_to_concurrency_used: dict[AccessKey, ConcurrencyUsed] = {}
-
-            async with self.db.begin_session() as db_sess:
-                # Query running containers and calculate concurrency_used per AK and
-                # occupied_slots per agent.
-                session_query = (
-                    sa.select(SessionRow)
-                    .where(
-                        (
-                            SessionRow.status.in_({
-                                *AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES,
-                                *USER_RESOURCE_OCCUPYING_SESSION_STATUSES,
-                            })
-                        )
-                    )
-                    .options(
-                        load_only(
-                            SessionRow.id,
-                            SessionRow.access_key,
-                            SessionRow.status,
-                            SessionRow.session_type,
-                        ),
-                        selectinload(SessionRow.kernels).options(
-                            load_only(KernelRow.agent, KernelRow.occupied_slots)
-                        ),
-                    )
-                )
-                async for session_row in await db_sess.stream_scalars(session_query):
-                    session_row = cast(SessionRow, session_row)
-                    for kernel in session_row.kernels:
-                        session_status = cast(SessionStatus, session_row.status)
-                        if session_status in AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES:
-                            occupied_slots_per_agent[kernel.agent] += ResourceSlot(
-                                kernel.occupied_slots
-                            )
-                        if session_row.status in USER_RESOURCE_OCCUPYING_SESSION_STATUSES:
-                            access_key = cast(AccessKey, session_row.access_key)
-                            if access_key not in access_key_to_concurrency_used:
-                                access_key_to_concurrency_used[access_key] = ConcurrencyUsed(
-                                    access_key
-                                )
-                            if session_row.session_type in PRIVATE_SESSION_TYPES:
-                                access_key_to_concurrency_used[access_key].system_session_ids.add(
-                                    session_row.id
-                                )
-                            else:
-                                access_key_to_concurrency_used[access_key].compute_session_ids.add(
-                                    session_row.id
-                                )
-
-                if len(occupied_slots_per_agent) > 0:
-                    # Update occupied_slots for agents with running containers.
-                    await db_sess.execute(
-                        (
-                            sa.update(AgentRow)
-                            .where(AgentRow.id == sa.bindparam("agent_id"))
-                            .values(occupied_slots=sa.bindparam("occupied_slots"))
-                        ),
-                        [
-                            {"agent_id": aid, "occupied_slots": slots}
-                            for aid, slots in occupied_slots_per_agent.items()
-                        ],
-                    )
-                    await db_sess.execute(
-                        (
-                            sa.update(AgentRow)
-                            .values(occupied_slots=ResourceSlot({}))
-                            .where(AgentRow.status == AgentStatus.ALIVE)
-                            .where(sa.not_(AgentRow.id.in_(occupied_slots_per_agent.keys())))
-                        )
-                    )
-                else:
-                    query = (
-                        sa.update(AgentRow)
-                        .values(occupied_slots=ResourceSlot({}))
-                        .where(AgentRow.status == AgentStatus.ALIVE)
-                    )
-                    await db_sess.execute(query)
-            return access_key_to_concurrency_used
-
-        access_key_to_concurrency_used = await execute_with_retry(_recalc)
-
-        # Update keypair resource usage for keypairs with running containers.
-        async def _update(r: Redis):
-            updates: dict[str, int] = {}
-            for concurrency in access_key_to_concurrency_used.values():
-                updates |= concurrency.to_cnt_map()
-            if updates:
-                await r.mset(cast(MSetType, updates))
-
-        async def _update_by_fullscan(r: Redis):
-            updates = {}
-            keys = await r.keys(f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}*")
-            for stat_key in keys:
-                if isinstance(stat_key, bytes):
-                    _stat_key = stat_key.decode("utf-8")
-                else:
-                    _stat_key = cast(str, stat_key)
-                ak = _stat_key.replace(COMPUTE_CONCURRENCY_USED_KEY_PREFIX, "")
-                concurrent_sessions = access_key_to_concurrency_used.get(AccessKey(ak))
-                usage = (
-                    len(concurrent_sessions.compute_session_ids)
-                    if concurrent_sessions is not None
-                    else 0
-                )
-                updates[_stat_key] = usage
-            keys = await r.keys(f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}*")
-            for stat_key in keys:
-                if isinstance(stat_key, bytes):
-                    _stat_key = stat_key.decode("utf-8")
-                else:
-                    _stat_key = cast(str, stat_key)
-                ak = _stat_key.replace(SYSTEM_CONCURRENCY_USED_KEY_PREFIX, "")
-                concurrent_sessions = access_key_to_concurrency_used.get(AccessKey(ak))
-                usage = (
-                    len(concurrent_sessions.system_concurrency_used_key)
-                    if concurrent_sessions is not None
-                    else 0
-                )
-                updates[_stat_key] = usage
-            if updates:
-                await r.mset(cast(MSetType, updates))
-
-        # Do full scan if the entire system does not have ANY sessions/sftp-sessions
-        # to set all concurrency_used to 0
-        _do_fullscan = do_fullscan or not access_key_to_concurrency_used
-        if _do_fullscan:
-            await redis_helper.execute(
-                self.redis_stat,
-                _update_by_fullscan,
-            )
-        else:
-            await redis_helper.execute(
-                self.redis_stat,
-                _update,
-            )
+    async def _recalc_agent_resources_fullscan(
+        self,
+        db_sess: AsyncSession,
+    ) -> Mapping[AgentId, ResourceSlot]:
+        # Initialize the per-agent resources with the empty resource slot.
+        agent_query = sa.select(AgentRow.id).with_for_update()
+        agent_ids = [*(await db_sess.scalars(agent_query))]
+        occupied_slots_per_agent: dict[AgentId, ResourceSlot] = {
+            agent_id: ResourceSlot({"cpu": 0, "mem": 0}) for agent_id in agent_ids
+        }
+        # Do a full-scan of all resource-occupying sessions.
+        session_query = (
+            sa.select(SessionRow)
+            .where(SessionRow.status.in_(AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES))
+            .options(
+                load_only(
+                    SessionRow.id,
+                    SessionRow.access_key,
+                    SessionRow.status,
+                    SessionRow.session_type,
+                ),
+                selectinload(SessionRow.kernels).options(
+                    load_only(KernelRow.agent, KernelRow.occupied_slots)
+                ),
+            )
+        )
+        # Re-add per-kernel resource occupancy.
+        session: SessionRow
+        kernel: KernelRow
+        async for session in await db_sess.stream_scalars(session_query):
+            for kernel in session.kernels:
+                occupied_slots_per_agent[kernel.agent] += ResourceSlot(kernel.occupied_slots)
+        return occupied_slots_per_agent
+
+    async def _recalc_agent_resources_from_session(
+        self,
+        db_sess: AsyncSession,
+        session_id: SessionId,
+    ) -> Mapping[AgentId, ResourceSlot]:
+        # This method returns the updates only for the agents impacted by the given session.
+        # First, get the agents impacted by the given session.
+        # NOTE: PostgreSQL rejects SELECT DISTINCT ... FOR UPDATE, so the agent IDs
+        # are deduplicated on the Python side instead.
+        agent_query = (
+            sa.select(AgentRow.id)
+            .select_from(sa.join(AgentRow, KernelRow, KernelRow.agent == AgentRow.id))
+            .where(KernelRow.session_id == session_id)
+            .with_for_update()
+        )
+        agent_ids = {*(await db_sess.scalars(agent_query))}
+        # Initialize with zero slots so that agents whose occupying kernels have all
+        # terminated are reset instead of being silently skipped.
+        occupied_slots_per_agent: dict[AgentId, ResourceSlot] = {
+            agent_id: ResourceSlot({"cpu": 0, "mem": 0}) for agent_id in agent_ids
+        }
+        # Get all resource-occupying kernels on the impacted agents.
+        # The kernels not belonging to the given session are also included to
+        # recalculate the agent resource.
+        kernel_query = (
+            sa.select(KernelRow)
+            .where(
+                KernelRow.agent.in_(agent_ids)
+                & KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
+            )
+            .options(
+                load_only(
+                    KernelRow.id,
+                    KernelRow.agent,
+                    KernelRow.occupied_slots,
+                )
+            )
+        )
+        # Recalculate occupied resource slots from the agents and their kernels.
+        kernel: KernelRow
+        async for kernel in await db_sess.stream_scalars(kernel_query):
+            occupied_slots_per_agent[kernel.agent] += ResourceSlot(kernel.occupied_slots)
+        return occupied_slots_per_agent
+
+    async def _apply_agent_occupied_slots(
+        self,
+        db_sess: AsyncSession,
+        occupied_slots_per_agent: Mapping[AgentId, ResourceSlot],
+    ) -> None:
+        if occupied_slots_per_agent:
+            await db_sess.execute(
+                (
+                    sa.update(AgentRow)
+                    .where(AgentRow.id == sa.bindparam("agent_id"))
+                    .values(occupied_slots=sa.bindparam("occupied_slots"))
+                ),
+                [
+                    {"agent_id": aid, "occupied_slots": slots}
+                    for aid, slots in occupied_slots_per_agent.items()
+                ],
+            )
+
+    async def recalc_resource_usage_by_session(self, session_id: SessionId) -> None:
+        async for attempt in retry_txn():
+            with attempt:
+                async with self.db.begin_session() as db_sess:
+                    updates = await self._recalc_agent_resources_from_session(db_sess, session_id)
+                    await self._apply_agent_occupied_slots(db_sess, updates)
+
+    async def recalc_resource_usage_by_fullscan(self) -> None:
+        async for attempt in retry_txn():
+            with attempt:
+                async with self.db.begin_session() as db_sess:
+                    updates = await self._recalc_agent_resources_fullscan(db_sess)
+                    await self._apply_agent_occupied_slots(db_sess, updates)
+                    await self.concurrency_tracker.recalc_concurrency_used_fullscan(db_sess)
+
     async def destroy_session_lowlevel(
         self,
         session_id: SessionId,
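Review note: both public entry points share the same retry-wrapped transaction shape, so a small helper could factor it out. This is a sketch only; it assumes `retry_txn()` yields context-manager attempts as its usage above implies, and the `db` parameter stands in for the manager's DB engine.

```python
from collections.abc import Awaitable, Callable

from sqlalchemy.ext.asyncio import AsyncSession

from ai.backend.manager.models.utils import retry_txn


async def run_in_retried_txn(db, body: Callable[[AsyncSession], Awaitable[None]]) -> None:
    # Re-runs `body` in a fresh transaction when a retryable DB error occurs.
    async for attempt in retry_txn():
        with attempt:
            async with db.begin_session() as db_sess:
                await body(db_sess)
```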
@@ -2316,7 +2278,7 @@ async def _destroy(db_session: AsyncSession) -> SessionRow:
 
         async with self.db.connect() as db_conn:
             await execute_with_txn_retry(_destroy, self.db.begin_session, db_conn)
-        await self.recalc_resource_usage()
+        await self.recalc_resource_usage_by_session(session_id)
 
         async with handle_session_exception(
             self.db,
@@ -2358,16 +2320,9 @@ async def _destroy(db_session: AsyncSession) -> SessionRow:
 
         async def _decrease_concurrency_used(access_key: AccessKey, is_private: bool) -> None:
             if is_private:
-                kp_key = "keypair.sftp_concurrency_used"
+                await self.concurrency_tracker.remove_system_sessions(access_key, [session_id])
             else:
-                kp_key = "keypair.concurrency_used"
-            await redis_helper.execute(
-                self.redis_stat,
-                lambda r: r.incrby(
-                    f"{kp_key}.{access_key}",
-                    -1,
-                ),
-            )
+                await self.concurrency_tracker.remove_compute_sessions(access_key, [session_id])
 
         match target_session.status:
             case SessionStatus.PENDING:
@@ -2622,7 +2577,7 @@ async def _destroy_kernels_in_agent(
             (session_id, session.name, session.access_key),
         )
         if forced:
-            await self.recalc_resource_usage()
+            await self.recalc_resource_usage_by_session(session_id)
         return main_stat
 
     async def clean_session(
@@ -3486,15 +3441,12 @@ async def _get_and_transit(
         if result is None:
             return
 
-        access_key = cast(AccessKey, result.access_key)
         agent = cast(AgentId, result.agent)
 
         async def _recalc(db_session: AsyncSession) -> None:
-            log.debug(
-                "recalculate concurrency used in kernel termination (ak: {})",
-                access_key,
-            )
-            await recalc_concurrency_used(db_session, self.redis_stat, access_key)
+            # Concurrency is now tracked as Redis set membership keyed by session ID
+            # and is already updated at session destruction, so kernel termination
+            # does not need to recalculate concurrency_used here.
             log.debug(
                 "recalculate agent resource occupancy in kernel termination (agent: {})",
                 agent,
diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py
index 2489369ef50..be72453a70e 100644
--- a/src/ai/backend/manager/scheduler/dispatcher.py
+++ b/src/ai/backend/manager/scheduler/dispatcher.py
@@ -93,7 +93,6 @@
     SessionStatus,
     list_schedulable_agents_by_sgroup,
     recalc_agent_resource_occupancy,
-    recalc_concurrency_used,
 )
 from ..models.utils import ExtendedAsyncSAEngine as SAEngine
 from ..models.utils import (
@@ -1671,7 +1670,7 @@ async def _mark_session_cancelled() -> None:
                     session.id,
                     destroyed_kernels,
                 )
-                await self.registry.recalc_resource_usage()
+                await self.registry.recalc_resource_usage_by_session(session.id)
             except Exception as destroy_err:
                 log.error(log_fmt + "cleanup-start-failure: error", *log_args, exc_info=destroy_err)
             finally:
@@ -1891,4 +1890,4 @@ async def _rollback_predicate_mutations(
     # may accumulate up multiple subtractions, resulting in
     # negative concurrency_occupied values.
-    log.debug("recalculate concurrency used in rollback predicates (ak: {})", session.access_key)
-    await recalc_concurrency_used(db_sess, sched_ctx.registry.redis_stat, session.access_key)
+    # The concurrency tracker now records session IDs as Redis set members,
+    # so rolling back a predicate no longer needs a recalculation here.
diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py
index 2564c6bf02b..4288f2ed72e 100644
--- a/src/ai/backend/manager/scheduler/predicates.py
+++ b/src/ai/backend/manager/scheduler/predicates.py
@@ -6,7 +6,6 @@
 from sqlalchemy.ext.asyncio import AsyncSession as SASession
 from sqlalchemy.orm import load_only, noload
 
-from ai.backend.common import redis_helper
 from ai.backend.common.types import ResourceSlot, SessionResult, SessionTypes
 from ai.backend.logging import BraceStyleAdapter
 
@@ -20,29 +19,12 @@
     SessionRow,
     UserRow,
 )
+from ..models.exceptions import ResourceLimitExceeded
 from ..models.session import SessionStatus
 from ..models.utils import execute_with_retry
 from .types import PredicateResult, SchedulingContext
 
 log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler"))
 
-_check_keypair_concurrency_script = """
-local key = KEYS[1]
-local limit = tonumber(ARGV[1])
-local result = {}
-redis.call('SETNX', key, 0)
-local count = tonumber(redis.call('GET', key))
-if limit > 0 and count >= limit then
-    result[1] = 0
-    result[2] = count
-    return result
-end
-redis.call('INCR', key)
-result[1] = 1
-result[2] = count + 1
-return result
-"""
-
 
 async def check_reserved_batch_session(
     db_sess: SASession,
@@ -83,18 +68,22 @@ async def _get_max_concurrent_sessions() -> int:
             return result.scalar()
 
     max_concurrent_sessions = await execute_with_retry(_get_max_concurrent_sessions)
-    if sess_ctx.is_private:
-        redis_key = f"keypair.sftp_concurrency_used.{sess_ctx.access_key}"
-    else:
-        redis_key = f"keypair.concurrency_used.{sess_ctx.access_key}"
-    ok, concurrency_used = await redis_helper.execute_script(
-        sched_ctx.registry.redis_stat,
-        "check_keypair_concurrency_used",
-        _check_keypair_concurrency_script,
-        [redis_key],
-        [max_concurrent_sessions],
-    )
-    if ok == 0:
+    try:
+        # NOTE: private (system/SFTP) sessions must go to the system-session set,
+        # mirroring the previous per-key branching.
+        if sess_ctx.is_private:
+            concurrency_used = await sched_ctx.registry.concurrency_tracker.add_system_sessions(
+                sess_ctx.access_key,
+                [sess_ctx.id],
+                limit=max_concurrent_sessions,
+            )
+        else:
+            concurrency_used = await sched_ctx.registry.concurrency_tracker.add_compute_sessions(
+                sess_ctx.access_key,
+                [sess_ctx.id],
+                limit=max_concurrent_sessions,
+            )
+    except ResourceLimitExceeded:
         return PredicateResult(
             False,
             f"You cannot run more than {max_concurrent_sessions} concurrent sessions",
diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py
index 6e216262017..d116903e62e 100644
--- a/src/ai/backend/manager/server.py
+++ b/src/ai/backend/manager/server.py
@@ -454,6 +454,14 @@ async def idle_checker_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
         await root_ctx.g.idle_checker_host.shutdown()
 
 
+@actxmgr
+async def resource_tracker_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
+    from .models.resource_policy import ConcurrencyTracker
+
+    root_ctx.g.concurrency_tracker = ConcurrencyTracker(root_ctx.h.redis_stat)
+    yield
+
+
 @actxmgr
 async def storage_manager_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
     from .models.storage import StorageSessionManager
@@ -825,6 +833,7 @@ def build_root_app(
         event_dispatcher_ctx,
         idle_checker_ctx,
         storage_manager_ctx,
+        resource_tracker_ctx,
         hook_plugin_ctx,
         monitoring_ctx,
         agent_registry_ctx,
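Review note: the position of `resource_tracker_ctx` in the cleanup-context list matters, because `agent_registry_ctx` constructs `AgentRegistry`, whose `__init__` now reads `global_context.concurrency_tracker`. A minimal reproduction of the ordering constraint follows; all names are illustrative stand-ins, not the actual manager types.

```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class FakeRootCtx:
    pass


@asynccontextmanager
async def tracker_ctx(ctx: FakeRootCtx) -> AsyncIterator[None]:
    ctx.tracker = object()  # stands in for ConcurrencyTracker(redis_stat)
    yield  # nothing to tear down; the Redis handle is owned by another context


@asynccontextmanager
async def registry_ctx(ctx: FakeRootCtx) -> AsyncIterator[None]:
    # Would raise AttributeError if this context ran before tracker_ctx.
    _ = ctx.tracker
    yield
```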
diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py
index 379d11d624f..8a8141359ec 100644
--- a/tests/manager/conftest.py
+++ b/tests/manager/conftest.py
@@ -811,6 +811,7 @@ async def registry_ctx(mocker):
     mock_redis_live.hset = AsyncMock()
     mock_redis_image = MagicMock()
     mock_redis_stream = MagicMock()
+    mock_concurrency_tracker = MagicMock()
     mock_event_dispatcher = MagicMock()
     mock_event_producer = MagicMock()
     mock_event_producer.produce_event = AsyncMock()
@@ -827,6 +828,7 @@ async def registry_ctx(mocker):
     halfstack_ctx.redis_image = mock_redis_image
     halfstack_ctx.redis_stream = mock_redis_stream
     global_ctx = GlobalObjectContext()
+    global_ctx.concurrency_tracker = mock_concurrency_tracker
     global_ctx.event_dispatcher = mock_event_dispatcher
     global_ctx.event_producer = mock_event_producer
     global_ctx.storage_manager = None  # type: ignore
diff --git a/tests/manager/models/test_container_registries.py b/tests/manager/models/test_container_registries.py
index 69abc7efcdc..54030bc8124 100644
--- a/tests/manager/models/test_container_registries.py
+++ b/tests/manager/models/test_container_registries.py
@@ -38,6 +38,7 @@ def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQuery
         redis_stat=None,  # type: ignore
         redis_image=None,  # type: ignore
         redis_live=None,  # type: ignore
+        concurrency_tracker=None,  # type: ignore
         manager_status=None,  # type: ignore
         known_slot_types=None,  # type: ignore
         background_task_manager=None,  # type: ignore
diff --git a/tests/manager/models/test_container_registry_nodes.py b/tests/manager/models/test_container_registry_nodes.py
index 18ea6247d3d..295f98a4b30 100644
--- a/tests/manager/models/test_container_registry_nodes.py
+++ b/tests/manager/models/test_container_registry_nodes.py
@@ -58,6 +58,7 @@ def mock_shared_config_api_getitem(key):
         redis_stat=None,  # type: ignore
         redis_image=None,  # type: ignore
         redis_live=None,  # type: ignore
+        concurrency_tracker=None,  # type: ignore
         manager_status=None,  # type: ignore
         known_slot_types=None,  # type: ignore
         background_task_manager=None,  # type: ignore
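Review note: `ConcurrencyTracker`'s methods are all coroutines, while the fixture registers a plain `MagicMock`. Any code path under test that awaits the tracker will fail on a non-awaitable return value, so an `AsyncMock` would be the safer default (sketch below).

```python
from unittest.mock import AsyncMock

# Drop-in replacement for the fixture line above: every attribute access yields
# an awaitable mock, so e.g. `await tracker.remove_compute_sessions(...)` works.
mock_concurrency_tracker = AsyncMock()
```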