fix: Rewrite concurrency tracking mechanism to use Redis sets
achimnol committed Nov 27, 2024
1 parent 46c011a commit 24ef1cc
Showing 18 changed files with 432 additions and 300 deletions.
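The new ConcurrencyTracker itself is added in src/ai/backend/manager/models/resource_policy.py, which is among the 18 changed files but not expanded in this excerpt; the hunks below only show its call sites (count_compute_sessions() and clear()) and the wiring that injects it into the global and GraphQL contexts. As a rough mental model of what a Redis-set-based tracker can look like, here is a minimal sketch, not the commit's actual implementation: the set key layout and the track_session_start/track_session_end helpers are assumptions, and only the two methods used below are taken from the diff.

# Illustrative sketch of a Redis-set-based concurrency tracker (not the code
# added by this commit). Each access key maps to a Redis set whose members are
# the IDs of its running compute sessions, so the current concurrency is the
# set cardinality instead of a manually maintained integer counter.
from __future__ import annotations

from redis.asyncio import Redis

from ai.backend.common.types import AccessKey, SessionId


class ConcurrencyTracker:
    def __init__(self, redis: Redis) -> None:
        self._redis = redis

    def _key(self, access_key: AccessKey) -> str:
        # Hypothetical key layout; the real key naming may differ.
        return f"keypair.concurrency_used.{access_key}"

    async def track_session_start(self, access_key: AccessKey, session_id: SessionId) -> None:
        # SADD is idempotent, so reporting the same session twice cannot
        # inflate the count the way a naive INCR-based counter can.
        await self._redis.sadd(self._key(access_key), str(session_id))

    async def track_session_end(self, access_key: AccessKey, session_id: SessionId) -> None:
        await self._redis.srem(self._key(access_key), str(session_id))

    async def count_compute_sessions(self, access_key: AccessKey) -> int:
        # Called by the keypair resolver below in place of GET on a counter key.
        return int(await self._redis.scard(self._key(access_key)))

    async def clear(self, access_key: AccessKey) -> None:
        # Called by the key-pair delete mutation below.
        await self._redis.delete(self._key(access_key))

Because set membership, rather than an incremented and decremented integer, is the source of truth, a missed decrement can no longer leave the count permanently skewed; removing the stale member (or deleting the key) restores a correct value.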
1 change: 1 addition & 0 deletions src/ai/backend/manager/api/admin.py
@@ -81,6 +81,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
redis_stat=root_ctx.h.redis_stat,
redis_image=root_ctx.h.redis_image,
redis_live=root_ctx.h.redis_live,
concurrency_tracker=root_ctx.g.concurrency_tracker,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.g.background_task_manager,
2 changes: 2 additions & 0 deletions src/ai/backend/manager/api/context.py
@@ -13,6 +13,7 @@

from ..config import LocalConfig, SharedConfig
from ..idle import IdleCheckerHost
from ..models.resource_policy import ConcurrencyTracker
from ..models.storage import StorageSessionManager
from ..models.utils import ExtendedAsyncSAEngine
from ..plugin.webapp import WebappPluginContext
@@ -51,6 +52,7 @@ class GlobalObjectContext:
idle_checker_host: IdleCheckerHost
storage_manager: StorageSessionManager
background_task_manager: BackgroundTaskManager
concurrency_tracker: ConcurrencyTracker
webapp_plugin_ctx: WebappPluginContext
hook_plugin_ctx: HookPluginContext
error_monitor: ErrorPluginContext
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/resource.py
@@ -307,7 +307,7 @@ async def recalculate_usage(request: web.Request) -> web.Response:
"""
log.info("RECALCULATE_USAGE ()")
root_ctx: RootContext = request.app["_root.context"]
await root_ctx.registry.recalc_resource_usage()
await root_ctx.registry.recalc_resource_usage_by_fullscan()
return web.json_response({}, status=200)


2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/agent.py
@@ -121,6 +121,7 @@ async def recalc_agent_resource_occupancy(db_session: SASession, agent_id: Agent
& (KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
)
.options(load_only(KernelRow.occupied_slots))
.with_for_update()
)
kernel_rows = cast(list[KernelRow], (await db_session.scalars(_stmt)).all())
occupied_slots = ResourceSlot()
@@ -145,6 +146,7 @@ async def recalc_agent_resource_occupancy_using_orm(
KernelRow, KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
),
)
.with_for_update()
)
occupied_slots = ResourceSlot()
agent_row = cast(AgentRow, await db_session.scalar(agent_query))
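The two .with_for_update() additions above make each recalculation pass take row-level locks on the kernel rows (and the agent row) it reads, so a concurrent recalculation or scheduling pass over the same agent waits for the transaction to commit instead of computing occupancy from rows that are about to change. A generic sketch of the SELECT ... FOR UPDATE pattern, with made-up table and column names rather than the project's models:

# Generic SELECT ... FOR UPDATE sketch; table and column names are illustrative.
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

kernels = sa.table(
    "kernels",
    sa.column("agent"),
    sa.column("occupied_slots"),
)


async def read_kernels_locked(db: AsyncSession, agent_id: str) -> list:
    stmt = (
        sa.select(kernels.c.occupied_slots)
        .where(kernels.c.agent == agent_id)
        .with_for_update()  # row locks are held until the enclosing transaction ends
    )
    # A concurrent transaction selecting the same rows FOR UPDATE blocks here
    # until this transaction commits or rolls back.
    return list((await db.execute(stmt)).scalars())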
9 changes: 9 additions & 0 deletions src/ai/backend/manager/models/exceptions.py
@@ -0,0 +1,9 @@
from __future__ import annotations


class ResourceError(Exception):
pass


class ResourceLimitExceeded(ResourceError):
pass
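The new exceptions module gives the manager a dedicated error hierarchy for resource accounting. This diff only adds the classes; the places that raise them live in files not expanded in this excerpt, so the following is a hypothetical illustration of how a limit check might signal a violation:

# Hypothetical usage of the new exception types; the check_concurrency helper
# and its message are made up for illustration.
from ai.backend.manager.models.exceptions import ResourceLimitExceeded


def check_concurrency(current: int, limit: int) -> None:
    if current >= limit:
        raise ResourceLimitExceeded(
            f"Concurrency limit exceeded: {current} active sessions, limit is {limit}"
        )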
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/gql.py
@@ -133,6 +133,7 @@
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
from .resource_policy import (
ConcurrencyTracker,
CreateKeyPairResourcePolicy,
CreateProjectResourcePolicy,
CreateUserResourcePolicy,
@@ -210,6 +211,7 @@ class GraphQueryContext:
redis_stat: RedisConnectionInfo
redis_live: RedisConnectionInfo
redis_image: RedisConnectionInfo
concurrency_tracker: ConcurrencyTracker
manager_status: ManagerStatus
known_slot_types: Mapping[SlotName, SlotTypes]
background_task_manager: BackgroundTaskManager
50 changes: 0 additions & 50 deletions src/ai/backend/manager/models/kernel.py
@@ -39,7 +39,6 @@
BinarySize,
ClusterMode,
KernelId,
RedisConnectionInfo,
ResourceSlot,
SessionId,
SessionResult,
@@ -105,7 +104,6 @@
"RESOURCE_USAGE_KERNEL_STATUSES",
"DEAD_KERNEL_STATUSES",
"LIVE_STATUS",
"recalc_concurrency_used",
)

log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.models.kernel"))
@@ -1560,51 +1558,3 @@ class Meta:
interfaces = (PaginatedList,)

items = graphene.List(LegacyComputeSession, required=True)


async def recalc_concurrency_used(
db_sess: SASession,
redis_stat: RedisConnectionInfo,
access_key: AccessKey,
) -> None:
concurrency_used: int
from .session import PRIVATE_SESSION_TYPES

async with db_sess.begin_nested():
result = await db_sess.execute(
sa.select(sa.func.count())
.select_from(KernelRow)
.where(
(KernelRow.access_key == access_key)
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
& (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
),
)
concurrency_used = result.scalar()
result = await db_sess.execute(
sa.select(sa.func.count())
.select_from(KernelRow)
.where(
(KernelRow.access_key == access_key)
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
& (KernelRow.session_type.in_(PRIVATE_SESSION_TYPES))
),
)
sftp_concurrency_used = result.scalar()
assert isinstance(concurrency_used, int)
assert isinstance(sftp_concurrency_used, int)

await redis_helper.execute(
redis_stat,
lambda r: r.set(
f"keypair.concurrency_used.{access_key}",
concurrency_used,
),
)
await redis_helper.execute(
redis_stat,
lambda r: r.set(
f"keypair.sftp_concurrency_used.{access_key}",
sftp_concurrency_used,
),
)
14 changes: 2 additions & 12 deletions src/ai/backend/manager/models/keypair.py
@@ -260,14 +260,7 @@ async def resolve_compute_sessions(

async def resolve_concurrency_used(self, info: graphene.ResolveInfo) -> int:
ctx: GraphQueryContext = info.context
kp_key = "keypair.concurrency_used"
concurrency_used = await redis_helper.execute(
ctx.redis_stat,
lambda r: r.get(f"{kp_key}.{self.access_key}"),
)
if concurrency_used is not None:
return int(concurrency_used)
return 0
return await ctx.concurrency_tracker.count_compute_sessions(AccessKey(self.access_key))

async def resolve_last_used(self, info: graphene.ResolveInfo) -> datetime | None:
ctx: GraphQueryContext = info.context
@@ -654,10 +647,7 @@ async def mutate(
delete_query = sa.delete(keypairs).where(keypairs.c.access_key == access_key)
result = await simple_db_mutate(cls, ctx, delete_query)
if result.ok:
await redis_helper.execute(
ctx.redis_stat,
lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
)
await ctx.concurrency_tracker.clear(access_key)
return result

