Skip to content

Commit 24ef1cc

Browse files
committed
fix: Rewrite concurrency tracking mechanism to use Redis sets
1 parent 46c011a commit 24ef1cc

File tree

18 files changed

+432
-300
lines changed

18 files changed

+432
-300
lines changed

src/ai/backend/manager/api/admin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
8181
redis_stat=root_ctx.h.redis_stat,
8282
redis_image=root_ctx.h.redis_image,
8383
redis_live=root_ctx.h.redis_live,
84+
concurrency_tracker=root_ctx.g.concurrency_tracker,
8485
manager_status=manager_status,
8586
known_slot_types=known_slot_types,
8687
background_task_manager=root_ctx.g.background_task_manager,

src/ai/backend/manager/api/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from ..config import LocalConfig, SharedConfig
1515
from ..idle import IdleCheckerHost
16+
from ..models.resource_policy import ConcurrencyTracker
1617
from ..models.storage import StorageSessionManager
1718
from ..models.utils import ExtendedAsyncSAEngine
1819
from ..plugin.webapp import WebappPluginContext
@@ -51,6 +52,7 @@ class GlobalObjectContext:
5152
idle_checker_host: IdleCheckerHost
5253
storage_manager: StorageSessionManager
5354
background_task_manager: BackgroundTaskManager
55+
concurrency_tracker: ConcurrencyTracker
5456
webapp_plugin_ctx: WebappPluginContext
5557
hook_plugin_ctx: HookPluginContext
5658
error_monitor: ErrorPluginContext

src/ai/backend/manager/api/resource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ async def recalculate_usage(request: web.Request) -> web.Response:
307307
"""
308308
log.info("RECALCULATE_USAGE ()")
309309
root_ctx: RootContext = request.app["_root.context"]
310-
await root_ctx.registry.recalc_resource_usage()
310+
await root_ctx.registry.recalc_resource_usage_by_fullscan()
311311
return web.json_response({}, status=200)
312312

313313

src/ai/backend/manager/models/agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ async def recalc_agent_resource_occupancy(db_session: SASession, agent_id: Agent
121121
& (KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
122122
)
123123
.options(load_only(KernelRow.occupied_slots))
124+
.with_for_update()
124125
)
125126
kernel_rows = cast(list[KernelRow], (await db_session.scalars(_stmt)).all())
126127
occupied_slots = ResourceSlot()
@@ -145,6 +146,7 @@ async def recalc_agent_resource_occupancy_using_orm(
145146
KernelRow, KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
146147
),
147148
)
149+
.with_for_update()
148150
)
149151
occupied_slots = ResourceSlot()
150152
agent_row = cast(AgentRow, await db_session.scalar(agent_query))
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
4+
class ResourceError(Exception):
5+
pass
6+
7+
8+
class ResourceLimitExceeded(ResourceError):
9+
pass

src/ai/backend/manager/models/gql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
134134
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
135135
from .resource_policy import (
136+
ConcurrencyTracker,
136137
CreateKeyPairResourcePolicy,
137138
CreateProjectResourcePolicy,
138139
CreateUserResourcePolicy,
@@ -210,6 +211,7 @@ class GraphQueryContext:
210211
redis_stat: RedisConnectionInfo
211212
redis_live: RedisConnectionInfo
212213
redis_image: RedisConnectionInfo
214+
concurrency_tracker: ConcurrencyTracker
213215
manager_status: ManagerStatus
214216
known_slot_types: Mapping[SlotName, SlotTypes]
215217
background_task_manager: BackgroundTaskManager

src/ai/backend/manager/models/kernel.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
BinarySize,
4040
ClusterMode,
4141
KernelId,
42-
RedisConnectionInfo,
4342
ResourceSlot,
4443
SessionId,
4544
SessionResult,
@@ -105,7 +104,6 @@
105104
"RESOURCE_USAGE_KERNEL_STATUSES",
106105
"DEAD_KERNEL_STATUSES",
107106
"LIVE_STATUS",
108-
"recalc_concurrency_used",
109107
)
110108

111109
log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.models.kernel"))
@@ -1560,51 +1558,3 @@ class Meta:
15601558
interfaces = (PaginatedList,)
15611559

15621560
items = graphene.List(LegacyComputeSession, required=True)
1563-
1564-
1565-
async def recalc_concurrency_used(
1566-
db_sess: SASession,
1567-
redis_stat: RedisConnectionInfo,
1568-
access_key: AccessKey,
1569-
) -> None:
1570-
concurrency_used: int
1571-
from .session import PRIVATE_SESSION_TYPES
1572-
1573-
async with db_sess.begin_nested():
1574-
result = await db_sess.execute(
1575-
sa.select(sa.func.count())
1576-
.select_from(KernelRow)
1577-
.where(
1578-
(KernelRow.access_key == access_key)
1579-
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
1580-
& (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
1581-
),
1582-
)
1583-
concurrency_used = result.scalar()
1584-
result = await db_sess.execute(
1585-
sa.select(sa.func.count())
1586-
.select_from(KernelRow)
1587-
.where(
1588-
(KernelRow.access_key == access_key)
1589-
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
1590-
& (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
1591-
),
1592-
)
1593-
sftp_concurrency_used = result.scalar()
1594-
assert isinstance(concurrency_used, int)
1595-
assert isinstance(sftp_concurrency_used, int)
1596-
1597-
await redis_helper.execute(
1598-
redis_stat,
1599-
lambda r: r.set(
1600-
f"keypair.concurrency_used.{access_key}",
1601-
concurrency_used,
1602-
),
1603-
)
1604-
await redis_helper.execute(
1605-
redis_stat,
1606-
lambda r: r.set(
1607-
f"keypair.sftp_concurrency_used.{access_key}",
1608-
sftp_concurrency_used,
1609-
),
1610-
)

src/ai/backend/manager/models/keypair.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,7 @@ async def resolve_compute_sessions(
260260

261261
async def resolve_concurrency_used(self, info: graphene.ResolveInfo) -> int:
262262
ctx: GraphQueryContext = info.context
263-
kp_key = "keypair.concurrency_used"
264-
concurrency_used = await redis_helper.execute(
265-
ctx.redis_stat,
266-
lambda r: r.get(f"{kp_key}.{self.access_key}"),
267-
)
268-
if concurrency_used is not None:
269-
return int(concurrency_used)
270-
return 0
263+
return await ctx.concurrency_tracker.count_compute_sessions(AccessKey(self.access_key))
271264

272265
async def resolve_last_used(self, info: graphene.ResolveInfo) -> datetime | None:
273266
ctx: GraphQueryContext = info.context
@@ -654,10 +647,7 @@ async def mutate(
654647
delete_query = sa.delete(keypairs).where(keypairs.c.access_key == access_key)
655648
result = await simple_db_mutate(cls, ctx, delete_query)
656649
if result.ok:
657-
await redis_helper.execute(
658-
ctx.redis_stat,
659-
lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
660-
)
650+
await ctx.concurrency_tracker.clear(access_key)
661651
return result
662652

663653

0 commit comments

Comments
 (0)