Skip to content

Commit b583f39

Browse files
committed
refactor: impl vfolder RBAC APIs (#2137)
implement basic functions and APIs for RBAC design **Checklist:** (if applicable) - [x] Milestone metadata specifying the target backport version - [x] Update of end-to-end CLI integration tests in `ai.backend.test` - [x] Documentation - Contents in the `docs` directory - docstrings in public interfaces and type annotations <!-- readthedocs-preview sorna start --> ---- 📚 Documentation preview 📚: https://sorna--2137.org.readthedocs.build/en/2137/ <!-- readthedocs-preview sorna end --> <!-- readthedocs-preview sorna-ko start --> ---- 📚 Documentation preview 📚: https://sorna-ko--2137.org.readthedocs.build/ko/2137/ <!-- readthedocs-preview sorna-ko end -->
1 parent 1065ed0 commit b583f39

File tree

9 files changed

+859
-9
lines changed

9 files changed

+859
-9
lines changed

src/ai/backend/manager/api/schema.graphql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,9 @@ type VirtualFolder implements Item {
527527
group: UUID
528528
group_name: String
529529
creator: String
530+
531+
"""Added in 24.09.0."""
532+
domain_name: String
530533
unmanaged_path: String
531534
usage_mode: String
532535
permission: String

src/ai/backend/manager/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import image as _image
99
from . import kernel as _kernel
1010
from . import keypair as _keypair
11+
from . import rbac as _rbac
1112
from . import resource_policy as _rpolicy
1213
from . import resource_preset as _rpreset
1314
from . import resource_usage as _rusage
@@ -33,6 +34,7 @@
3334
*_user.__all__,
3435
*_vfolder.__all__,
3536
*_dotfile.__all__,
37+
*_rbac.__all__,
3638
*_rusage.__all__,
3739
*_rpolicy.__all__,
3840
*_rpreset.__all__,
@@ -54,6 +56,7 @@
5456
from .image import * # noqa
5557
from .kernel import * # noqa
5658
from .keypair import * # noqa
59+
from .rbac import * # noqa
5760
from .resource_policy import * # noqa
5861
from .resource_preset import * # noqa
5962
from .resource_usage import * # noqa

src/ai/backend/manager/models/acl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence
3+
from collections.abc import Mapping
4+
from typing import TYPE_CHECKING, Any, List, Sequence
45

56
import graphene
67

src/ai/backend/manager/models/group.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@
9494
MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB
9595
_rx_slug = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$")
9696

97+
98+
class UserRoleInProject(enum.StrEnum):
99+
ADMIN = enum.auto() # TODO: impl project admin
100+
USER = enum.auto() # User is associated as user
101+
102+
97103
association_groups_users = sa.Table(
98104
"association_groups_users",
99105
mapper_registry.metadata,
@@ -197,7 +203,7 @@ class GroupRow(Base):
197203
users = relationship("AssocGroupUserRow", back_populates="group")
198204
resource_policy_row = relationship("ProjectResourcePolicyRow", back_populates="projects")
199205
kernels = relationship("KernelRow", back_populates="group_row")
200-
vfolder_row = relationship("VFolderRow", back_populates="group_row")
206+
vfolder_rows = relationship("VFolderRow", back_populates="group_row")
201207

202208

203209
def _build_group_query(cond: sa.sql.BinaryExpression, domain_name: str) -> sa.sql.Select:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python_sources()
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
from __future__ import annotations
2+
3+
import enum
4+
import uuid
5+
from abc import ABCMeta, abstractmethod
6+
from collections.abc import Mapping
7+
from dataclasses import dataclass, field
8+
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar
9+
10+
import sqlalchemy as sa
11+
from sqlalchemy.ext.asyncio import AsyncSession
12+
from sqlalchemy.orm import load_only
13+
14+
from ..group import AssocGroupUserRow, GroupRow, UserRoleInProject
15+
from ..user import UserRole
16+
17+
if TYPE_CHECKING:
18+
from ..utils import ExtendedAsyncSAEngine
19+
20+
21+
__all__: Sequence[str] = (
22+
"BasePermission",
23+
"ClientContext",
24+
"DomainScope",
25+
"ProjectScope",
26+
"UserScope",
27+
"StorageHost",
28+
"ImageRegistry",
29+
"ScalingGroup",
30+
"AbstractPermissionContext",
31+
"AbstractPermissionContextBuilder",
32+
)
33+
34+
35+
class BasePermission(enum.StrEnum):
36+
pass
37+
38+
39+
PermissionType = TypeVar("PermissionType", bound=BasePermission)
40+
41+
42+
class Bypass(enum.Enum):
43+
TOKEN = enum.auto()
44+
45+
46+
bypass = Bypass.TOKEN
47+
48+
ProjectContext = Mapping[uuid.UUID, UserRoleInProject]
49+
50+
51+
@dataclass
52+
class ClientContext:
53+
db: ExtendedAsyncSAEngine
54+
55+
domain_name: str
56+
user_id: uuid.UUID
57+
user_role: UserRole
58+
59+
_project_ctx: ProjectContext | None = field(init=False, default=None)
60+
_domain_project_ctx: Mapping[str, ProjectContext] | None = field(init=False, default=None)
61+
62+
async def get_or_init_project_ctx_in_domain(
63+
self, db_session: AsyncSession, domain_name: str
64+
) -> ProjectContext | None:
65+
_project_ctx = await self._get_or_init_project_ctx(db_session)
66+
if _project_ctx is bypass:
67+
# client is superadmin or monitor
68+
if self._domain_project_ctx is None:
69+
self._domain_project_ctx = {}
70+
if domain_name not in self._domain_project_ctx:
71+
stmt = (
72+
sa.select(GroupRow)
73+
.where(GroupRow.domain_name == domain_name)
74+
.options(load_only(GroupRow.id))
75+
)
76+
self._domain_project_ctx = {
77+
**self._domain_project_ctx,
78+
domain_name: {
79+
row.id: UserRoleInProject.ADMIN for row in await db_session.scalars(stmt)
80+
},
81+
}
82+
else:
83+
# client is domain admin or user
84+
self._domain_project_ctx = {self.domain_name: _project_ctx}
85+
return self._domain_project_ctx.get(domain_name)
86+
87+
async def get_user_role_in_project(
88+
self, db_session: AsyncSession, project_id: uuid.UUID
89+
) -> UserRoleInProject | None:
90+
_project_ctx = await self._get_or_init_project_ctx(db_session)
91+
if _project_ctx is bypass:
92+
return UserRoleInProject.ADMIN
93+
else:
94+
return _project_ctx.get(project_id)
95+
96+
async def _get_or_init_project_ctx(self, db_session: AsyncSession) -> ProjectContext | Bypass:
97+
match self.user_role:
98+
case UserRole.SUPERADMIN | UserRole.MONITOR:
99+
# Superadmins and monitors can access to ALL projects in the system.
100+
# Let's not fetch all project data from DB.
101+
return bypass
102+
case UserRole.ADMIN:
103+
if self._project_ctx is None:
104+
stmt = (
105+
sa.select(GroupRow)
106+
.where(GroupRow.domain_name == self.domain_name)
107+
.options(load_only(GroupRow.id))
108+
)
109+
self._project_ctx = {
110+
row.id: UserRoleInProject.ADMIN for row in await db_session.scalars(stmt)
111+
}
112+
return self._project_ctx
113+
case UserRole.USER:
114+
if self._project_ctx is None:
115+
stmt = (
116+
sa.select(AssocGroupUserRow)
117+
.select_from(sa.join(AssocGroupUserRow, GroupRow))
118+
.where(
119+
(AssocGroupUserRow.user_id == self.user_id)
120+
& (GroupRow.domain_name == self.domain_name)
121+
)
122+
)
123+
self._project_ctx = {
124+
row.group_id: UserRoleInProject.USER
125+
for row in await db_session.scalars(stmt)
126+
}
127+
return self._project_ctx
128+
129+
130+
class BaseScope(metaclass=ABCMeta):
131+
@abstractmethod
132+
def __str__(self) -> str:
133+
pass
134+
135+
136+
@dataclass(frozen=True)
137+
class DomainScope(BaseScope):
138+
domain_name: str
139+
140+
def __str__(self) -> str:
141+
return f"Domain(name: {self.domain_name})"
142+
143+
144+
@dataclass(frozen=True)
145+
class ProjectScope(BaseScope):
146+
project_id: uuid.UUID
147+
148+
def __str__(self) -> str:
149+
return f"Project(id: {self.project_id})"
150+
151+
152+
@dataclass(frozen=True)
153+
class UserScope(BaseScope):
154+
user_id: uuid.UUID
155+
156+
def __str__(self) -> str:
157+
return f"User(id: {self.user_id})"
158+
159+
160+
# Extra scope is to address some scopes that contain specific object types
161+
# such as registries for images, scaling groups for agents, storage hosts for vfolders etc.
162+
class ExtraScope:
163+
pass
164+
165+
166+
@dataclass(frozen=True)
167+
class StorageHost(ExtraScope):
168+
name: str
169+
170+
171+
@dataclass(frozen=True)
172+
class ImageRegistry(ExtraScope):
173+
name: str
174+
175+
176+
@dataclass(frozen=True)
177+
class ScalingGroup(ExtraScope):
178+
name: str
179+
180+
181+
ObjectType = TypeVar("ObjectType")
182+
ObjectIDType = TypeVar("ObjectIDType")
183+
184+
185+
@dataclass
186+
class AbstractPermissionContext(
187+
Generic[PermissionType, ObjectType, ObjectIDType], metaclass=ABCMeta
188+
):
189+
"""
190+
Define permissions under given User, Project or Domain scopes.
191+
Each field of this class represents a mapping of ["accessible scope id", "permissions under the scope"].
192+
For example, `project` field has a mapping of ["accessible project id", "permissions under the project"].
193+
{
194+
"PROJECT_A_ID": {"READ", "WRITE", "DELETE"}
195+
"PROJECT_B_ID": {"READ"}
196+
}
197+
198+
`additional` and `overriding` fields have a mapping of ["object id", "permissions applied to the object"].
199+
`additional` field is used to add permissions to specific objects. It can be used for admins.
200+
`overriding` field is used to address exceptional cases such as permission overriding or cover other scopes(scaling groups or storage hosts etc).
201+
"""
202+
203+
user_id_to_permission_map: Mapping[uuid.UUID, frozenset[PermissionType]] = field(
204+
default_factory=dict
205+
)
206+
project_id_to_permission_map: Mapping[uuid.UUID, frozenset[PermissionType]] = field(
207+
default_factory=dict
208+
)
209+
domain_name_to_permission_map: Mapping[str, frozenset[PermissionType]] = field(
210+
default_factory=dict
211+
)
212+
213+
object_id_to_additional_permission_map: Mapping[ObjectIDType, frozenset[PermissionType]] = (
214+
field(default_factory=dict)
215+
)
216+
object_id_to_overriding_permission_map: Mapping[ObjectIDType, frozenset[PermissionType]] = (
217+
field(default_factory=dict)
218+
)
219+
220+
def filter_by_permission(self, permission_to_include: PermissionType) -> None:
221+
self.user_id_to_permission_map = {
222+
uid: permissions
223+
for uid, permissions in self.user_id_to_permission_map.items()
224+
if permission_to_include in permissions
225+
}
226+
self.project_id_to_permission_map = {
227+
pid: permissions
228+
for pid, permissions in self.project_id_to_permission_map.items()
229+
if permission_to_include in permissions
230+
}
231+
self.domain_name_to_permission_map = {
232+
dname: permissions
233+
for dname, permissions in self.domain_name_to_permission_map.items()
234+
if permission_to_include in permissions
235+
}
236+
self.object_id_to_additional_permission_map = {
237+
obj_id: permissions
238+
for obj_id, permissions in self.object_id_to_additional_permission_map.items()
239+
if permission_to_include in permissions
240+
}
241+
self.object_id_to_overriding_permission_map = {
242+
obj_id: permissions
243+
for obj_id, permissions in self.object_id_to_overriding_permission_map.items()
244+
if permission_to_include in permissions
245+
}
246+
247+
@abstractmethod
248+
async def build_query(self) -> sa.sql.Select | None:
249+
pass
250+
251+
@abstractmethod
252+
async def calculate_final_permission(self, acl_obj: ObjectType) -> frozenset[PermissionType]:
253+
"""
254+
Calculate the final permissions applied to the given object based on the fields in this class.
255+
"""
256+
pass
257+
258+
259+
PermissionContextType = TypeVar("PermissionContextType", bound=AbstractPermissionContext)
260+
261+
262+
class AbstractPermissionContextBuilder(
263+
Generic[PermissionType, PermissionContextType], metaclass=ABCMeta
264+
):
265+
@classmethod
266+
async def build(
267+
cls,
268+
db_session: AsyncSession,
269+
ctx: ClientContext,
270+
target_scope: BaseScope,
271+
*,
272+
permission: PermissionType | None = None,
273+
) -> PermissionContextType:
274+
match target_scope:
275+
case UserScope(user_id=user_id):
276+
result = await cls._build_in_user_scope(db_session, ctx, user_id)
277+
case ProjectScope(project_id=project_id):
278+
result = await cls._build_in_project_scope(db_session, ctx, project_id)
279+
case DomainScope(domain_name=domain_name):
280+
result = await cls._build_in_domain_scope(db_session, ctx, domain_name)
281+
case _:
282+
raise RuntimeError(f"invalid scope `{target_scope}`")
283+
if permission is not None:
284+
result.filter_by_permission(permission)
285+
return result
286+
287+
@classmethod
288+
@abstractmethod
289+
async def _build_in_user_scope(
290+
cls,
291+
db_session: AsyncSession,
292+
ctx: ClientContext,
293+
user_id: uuid.UUID,
294+
) -> PermissionContextType:
295+
pass
296+
297+
@classmethod
298+
@abstractmethod
299+
async def _build_in_project_scope(
300+
cls,
301+
db_session: AsyncSession,
302+
ctx: ClientContext,
303+
project_id: uuid.UUID,
304+
) -> PermissionContextType:
305+
pass
306+
307+
@classmethod
308+
@abstractmethod
309+
async def _build_in_domain_scope(
310+
cls,
311+
db_session: AsyncSession,
312+
ctx: ClientContext,
313+
domain_name: str,
314+
) -> PermissionContextType:
315+
pass
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class RBACException(Exception):
2+
pass
3+
4+
5+
class NotEnoughPermission(RBACException):
6+
pass

src/ai/backend/manager/models/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class UserRow(Base):
186186

187187
main_keypair = relationship("KeyPairRow", foreign_keys=users.c.main_access_key)
188188

189-
vfolder_row = relationship("VFolderRow", back_populates="user_row")
189+
vfolder_rows = relationship("VFolderRow", back_populates="user_row")
190190

191191

192192
class UserGroup(graphene.ObjectType):

0 commit comments

Comments
 (0)