Skip to content

Commit f55366d

Browse files
authored
refactor(BA-5816): migrate session_template off association_groups_users (#11284)
1 parent d50fbab commit f55366d

21 files changed

Lines changed: 375 additions & 186 deletions

File tree

changes/11284.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Migrate session and cluster template code paths off the legacy `association_groups_users` table; project membership is now validated against `association_scopes_entities` and the REST handlers resolve `(domain, group_name) → project_id` upstream via a new `GroupService` entry point.

src/ai/backend/manager/api/rest/cluster_template/handler.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@
3838
ListClusterTemplatesResponse,
3939
UpdateClusterTemplateResponse,
4040
)
41+
from ai.backend.common.identifier.project import ProjectID
4142
from ai.backend.common.json import load_json
4243
from ai.backend.logging import BraceStyleAdapter
4344
from ai.backend.manager.dto.context import RequestCtx, UserContext
4445
from ai.backend.manager.errors.api import InvalidAPIParameters
46+
from ai.backend.manager.services.group.actions.resolve_project_id_by_name import (
47+
ResolveProjectIdByNameAction,
48+
)
4549
from ai.backend.manager.services.template.actions.create_cluster_template import (
4650
CreateClusterTemplateAction,
4751
)
@@ -59,6 +63,7 @@
5963
)
6064

6165
if TYPE_CHECKING:
66+
from ai.backend.manager.services.group.processors import GroupProcessors
6267
from ai.backend.manager.services.template.processors import TemplateProcessors
6368

6469
log: Final = BraceStyleAdapter(logging.getLogger(__spec__.name))
@@ -67,8 +72,24 @@
6772
class ClusterTemplateHandler:
6873
"""Cluster template API handler with constructor-injected dependencies."""
6974

70-
def __init__(self, *, template: TemplateProcessors) -> None:
75+
def __init__(
76+
self,
77+
*,
78+
template: TemplateProcessors,
79+
group: GroupProcessors,
80+
) -> None:
7181
self._template = template
82+
self._group = group
83+
84+
async def _resolve_project_id(self, domain_name: str, project_name: str) -> ProjectID:
85+
result = await self._group.resolve_project_id_by_name.wait_for_complete(
86+
ResolveProjectIdByNameAction(domain_name=domain_name, project_name=project_name)
87+
)
88+
if result.project_id is None:
89+
raise InvalidAPIParameters(
90+
f"No active group named {project_name!r} exists in domain {domain_name!r}"
91+
)
92+
return result.project_id
7293

7394
async def create(
7495
self,
@@ -93,9 +114,10 @@ async def create(
93114
except (yaml.YAMLError, yaml.MarkedYAMLError) as e:
94115
raise InvalidAPIParameters("Malformed payload") from e
95116

117+
project_id = await self._resolve_project_id(domain, params.group)
96118
action = CreateClusterTemplateAction(
97119
domain_name=domain,
98-
requesting_group=params.group,
120+
requesting_project=project_id,
99121
requester_uuid=ctx.user_uuid,
100122
requester_access_key=ctx.access_key,
101123
requester_role=req.request["user"]["role"],

src/ai/backend/manager/api/rest/session_template/handler.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@
3838
SessionTemplateListItemDTO,
3939
UpdateSessionTemplateResponse,
4040
)
41+
from ai.backend.common.identifier.project import ProjectID
4142
from ai.backend.common.json import load_json
4243
from ai.backend.logging import BraceStyleAdapter
4344
from ai.backend.manager.dto.context import RequestCtx, UserContext
4445
from ai.backend.manager.errors.api import InvalidAPIParameters
46+
from ai.backend.manager.services.group.actions.resolve_project_id_by_name import (
47+
ResolveProjectIdByNameAction,
48+
)
4549
from ai.backend.manager.services.template.actions.create_task_template import (
4650
CreateTaskTemplateAction,
4751
TaskTemplateItemInput,
@@ -60,6 +64,7 @@
6064
)
6165

6266
if TYPE_CHECKING:
67+
from ai.backend.manager.services.group.processors import GroupProcessors
6368
from ai.backend.manager.services.template.processors import TemplateProcessors
6469

6570
log: Final = BraceStyleAdapter(logging.getLogger(__spec__.name))
@@ -68,8 +73,24 @@
6873
class SessionTemplateHandler:
6974
"""Session template API handler with constructor-injected dependencies."""
7075

71-
def __init__(self, *, template: TemplateProcessors) -> None:
76+
def __init__(
77+
self,
78+
*,
79+
template: TemplateProcessors,
80+
group: GroupProcessors,
81+
) -> None:
7282
self._template = template
83+
self._group = group
84+
85+
async def _resolve_project_id(self, domain_name: str, project_name: str) -> ProjectID:
86+
result = await self._group.resolve_project_id_by_name.wait_for_complete(
87+
ResolveProjectIdByNameAction(domain_name=domain_name, project_name=project_name)
88+
)
89+
if result.project_id is None:
90+
raise InvalidAPIParameters(
91+
f"No active group named {project_name!r} exists in domain {domain_name!r}"
92+
)
93+
return result.project_id
7394

7495
async def create(
7596
self,
@@ -104,9 +125,10 @@ async def create(
104125
for st in payload
105126
]
106127

128+
project_id = await self._resolve_project_id(domain, params.group)
107129
action = CreateTaskTemplateAction(
108130
domain_name=domain,
109-
requesting_group=params.group,
131+
requesting_project=project_id,
110132
requester_uuid=ctx.user_uuid,
111133
requester_access_key=ctx.access_key,
112134
requester_role=req.request["user"]["role"],
@@ -220,10 +242,11 @@ async def update(
220242
for st in payload
221243
]
222244

245+
project_id = await self._resolve_project_id(domain, params.group)
223246
action = UpdateTaskTemplateAction(
224247
template_id=template_id,
225248
domain_name=domain,
226-
requesting_group=params.group,
249+
requesting_project=project_id,
227250
requester_uuid=ctx.user_uuid,
228251
requester_access_key=ctx.access_key,
229252
requester_role=req.request["user"]["role"],

src/ai/backend/manager/api/rest/tree.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,12 @@ def build_api_routes(
261261
)
262262

263263
# Template sub-registries
264-
cluster_template_handler = ClusterTemplateHandler(template=processors.template)
265-
session_template_handler = SessionTemplateHandler(template=processors.template)
264+
cluster_template_handler = ClusterTemplateHandler(
265+
template=processors.template, group=processors.group
266+
)
267+
session_template_handler = SessionTemplateHandler(
268+
template=processors.template, group=processors.group
269+
)
266270
cluster_template_reg = register_cluster_template_routes(cluster_template_handler, route_deps)
267271
session_template_reg = register_session_template_routes(session_template_handler, route_deps)
268272

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
11
from __future__ import annotations
22

33
import enum
4-
import uuid
5-
from collections.abc import Iterable, Mapping, Sequence
4+
from collections.abc import Mapping, Sequence
65
from typing import Any, cast
76

87
import sqlalchemy as sa
98
import trafaret as t
109
from sqlalchemy.dialects import postgresql as pgsql
11-
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
1210

1311
from ai.backend.common import validators as tx
1412
from ai.backend.common.types import SessionTypes
1513
from ai.backend.manager.defs import DEFAULT_ROLE
1614
from ai.backend.manager.exceptions import InvalidArgument
1715

1816
from .base import GUID, EnumType, IDColumn, metadata
19-
from .user import UserRole
2017
from .vfolder import verify_vfolder_name
2118

2219
__all__: Sequence[str] = (
2320
"TemplateType",
24-
"query_accessible_session_templates",
2521
"session_templates",
2622
)
2723

@@ -139,112 +135,3 @@ def check_cluster_template(raw_data: Mapping[str, Any]) -> Mapping[str, Any]:
139135
f"One and only one {DEFAULT_ROLE} node must be created per cluster",
140136
)
141137
return cast(Mapping[str, Any], data)
142-
143-
144-
async def query_accessible_session_templates(
145-
conn: SAConnection,
146-
user_uuid: uuid.UUID,
147-
template_type: TemplateType,
148-
*,
149-
user_role: UserRole | None = None,
150-
domain_name: str | None = None,
151-
allowed_types: Iterable[str] = ["user"],
152-
extra_conds: Any = None,
153-
) -> list[Mapping[str, Any]]:
154-
from .group import association_groups_users as agus
155-
from .group import groups
156-
from .user import users
157-
158-
entries: list[Mapping[str, Any]] = []
159-
if "user" in allowed_types:
160-
# Query user templates
161-
j = session_templates.join(users, session_templates.c.user_uuid == users.c.uuid)
162-
query = (
163-
sa.select(
164-
session_templates.c.name,
165-
session_templates.c.id,
166-
session_templates.c.created_at,
167-
session_templates.c.user_uuid,
168-
session_templates.c.group_id,
169-
users.c.email,
170-
)
171-
.select_from(j)
172-
.where(
173-
(session_templates.c.user_uuid == user_uuid)
174-
& session_templates.c.is_active
175-
& (session_templates.c.type == template_type),
176-
)
177-
)
178-
if extra_conds is not None:
179-
query = query.where(extra_conds)
180-
result = await conn.execute(query)
181-
for row in result:
182-
entries.append({
183-
"name": row.name,
184-
"id": row.id,
185-
"created_at": row.created_at,
186-
"is_owner": True,
187-
"user": str(row.user_uuid) if row.user_uuid else None,
188-
"group": str(row.group_id) if row.group_id else None,
189-
"user_email": row.email,
190-
"group_name": None,
191-
})
192-
if "group" in allowed_types:
193-
# Query group session_templates
194-
if user_role == UserRole.ADMIN:
195-
query = (
196-
sa.select(groups.c.id)
197-
.select_from(groups)
198-
.where(groups.c.domain_name == domain_name)
199-
)
200-
result = await conn.execute(query)
201-
grps = result.fetchall()
202-
group_ids = [g.id for g in grps]
203-
else:
204-
j = sa.join(agus, users, agus.c.user_id == users.c.uuid)
205-
query = sa.select(agus.c.group_id).select_from(j).where(agus.c.user_id == user_uuid)
206-
result = await conn.execute(query)
207-
grps = result.fetchall()
208-
group_ids = [g.group_id for g in grps]
209-
j = session_templates.join(groups, session_templates.c.group_id == groups.c.id)
210-
query = (
211-
sa.select(
212-
session_templates.c.name,
213-
session_templates.c.id,
214-
session_templates.c.created_at,
215-
session_templates.c.user_uuid,
216-
session_templates.c.group_id,
217-
groups.c.name,
218-
)
219-
.set_label_style(sa.LABEL_STYLE_TABLENAME_PLUS_COL)
220-
.select_from(j)
221-
.where(
222-
session_templates.c.group_id.in_(group_ids)
223-
& session_templates.c.is_active
224-
& (session_templates.c.type == template_type),
225-
)
226-
)
227-
if extra_conds is not None:
228-
query = query.where(extra_conds)
229-
if "user" in allowed_types:
230-
query = query.where(session_templates.c.user_uuid != user_uuid)
231-
result = await conn.execute(query)
232-
is_owner = user_role == UserRole.ADMIN
233-
for row in result:
234-
entries.append({
235-
"name": row.session_templates_name,
236-
"id": row.session_templates_id,
237-
"created_at": row.session_templates_created_at,
238-
"is_owner": is_owner,
239-
"user": (
240-
str(row.session_templates_user_uuid)
241-
if row.session_templates_user_uuid
242-
else None
243-
),
244-
"group": (
245-
str(row.session_templates_group_id) if row.session_templates_group_id else None
246-
),
247-
"user_email": None,
248-
"group_name": row.groups_name,
249-
})
250-
return entries

src/ai/backend/manager/repositories/group/db_source/db_source.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient
1919
from ai.backend.common.data.permission.types import RBACElementType
2020
from ai.backend.common.exception import InvalidAPIParameters
21+
from ai.backend.common.identifier.project import ProjectID
2122
from ai.backend.common.types import SlotName, VFolderID
2223
from ai.backend.common.utils import nmget
2324
from ai.backend.logging.utils import BraceStyleAdapter
@@ -921,6 +922,32 @@ async def get_project(self, project_id: UUID) -> GroupData:
921922
raise ProjectNotFound(f"Project {project_id} not found")
922923
return row.to_data()
923924

925+
async def project_id_by_name_in_domain(
926+
self, domain_name: str, project_name: str
927+
) -> ProjectID | None:
928+
"""Resolve an active project's UUID by its domain-scoped name.
929+
930+
LEGACY: Exists solely to support existing API handlers that only accept a
931+
group name as input (e.g. the REST v1 session/cluster template endpoints).
932+
New API handlers and any other new code MUST NOT use this — they should
933+
accept a project UUID directly.
934+
935+
Returns:
936+
The project UUID if found, or ``None`` if no matching active project exists.
937+
"""
938+
async with self._db.begin_readonly_session_read_committed() as db_sess:
939+
result = await db_sess.execute(
940+
sa.select(GroupRow.id).where(
941+
GroupRow.domain_name == domain_name,
942+
GroupRow.name == project_name,
943+
GroupRow.is_active.is_(True),
944+
)
945+
)
946+
project_id = result.scalar_one_or_none()
947+
if project_id is None:
948+
return None
949+
return ProjectID(project_id)
950+
924951
async def search_projects(
925952
self,
926953
querier: BatchQuerier,

src/ai/backend/manager/repositories/group/repository.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient
1111
from ai.backend.common.exception import BackendAIError
12+
from ai.backend.common.identifier.project import ProjectID
1213
from ai.backend.common.metrics.metric import DomainType, LayerType
1314
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
1415
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
@@ -174,6 +175,22 @@ async def get_project(self, project_id: UUID) -> GroupData:
174175
"""
175176
return await self._db_source.get_project(project_id)
176177

178+
@group_repository_resilience.apply()
179+
async def project_id_by_name_in_domain(
180+
self, domain_name: str, project_name: str
181+
) -> ProjectID | None:
182+
"""Resolve an active project's UUID by its domain-scoped name.
183+
184+
LEGACY: Exists solely to support existing API handlers that only accept a
185+
group name as input (e.g. the REST v1 session/cluster template endpoints).
186+
New API handlers and any other new code MUST NOT use this — they should
187+
accept a project UUID directly.
188+
189+
Returns:
190+
The project UUID if found, or ``None`` if no matching active project exists.
191+
"""
192+
return await self._db_source.project_id_by_name_in_domain(domain_name, project_name)
193+
177194
@group_repository_resilience.apply()
178195
async def search_projects(
179196
self,

0 commit comments

Comments
 (0)