Skip to content

Commit

Permalink
refactor: impl vfolder RBAC APIs (#2137)
Browse files Browse the repository at this point in the history
implement basic functions and APIs for RBAC design

**Checklist:** (if applicable)

- [x] Milestone metadata specifying the target backport version
- [x] Update of end-to-end CLI integration tests in `ai.backend.test`
- [x] Documentation
  - Contents in the `docs` directory
  - docstrings in public interfaces and type annotations

<!-- readthedocs-preview sorna start -->
----
📚 Documentation preview 📚: https://sorna--2137.org.readthedocs.build/en/2137/

<!-- readthedocs-preview sorna end -->

<!-- readthedocs-preview sorna-ko start -->
----
📚 Documentation preview 📚: https://sorna-ko--2137.org.readthedocs.build/ko/2137/

<!-- readthedocs-preview sorna-ko end -->
  • Loading branch information
fregataa committed Jul 1, 2024
1 parent 1065ed0 commit b583f39
Show file tree
Hide file tree
Showing 9 changed files with 859 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,9 @@ type VirtualFolder implements Item {
group: UUID
group_name: String
creator: String

"""Added in 24.09.0."""
domain_name: String

Check notice on line 532 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'domain_name' was added to object type 'VirtualFolder'

Field 'domain_name' was added to object type 'VirtualFolder'
unmanaged_path: String
usage_mode: String
permission: String
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/manager/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import image as _image
from . import kernel as _kernel
from . import keypair as _keypair
from . import rbac as _rbac
from . import resource_policy as _rpolicy
from . import resource_preset as _rpreset
from . import resource_usage as _rusage
Expand All @@ -33,6 +34,7 @@
*_user.__all__,
*_vfolder.__all__,
*_dotfile.__all__,
*_rbac.__all__,
*_rusage.__all__,
*_rpolicy.__all__,
*_rpreset.__all__,
Expand All @@ -54,6 +56,7 @@
from .image import * # noqa
from .kernel import * # noqa
from .keypair import * # noqa
from .rbac import * # noqa
from .resource_policy import * # noqa
from .resource_preset import * # noqa
from .resource_usage import * # noqa
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/manager/models/acl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Mapping, Sequence
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, List, Sequence

import graphene

Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@
MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB
_rx_slug = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$")


class UserRoleInProject(enum.StrEnum):
ADMIN = enum.auto() # TODO: impl project admin
USER = enum.auto() # User is associated as user


association_groups_users = sa.Table(
"association_groups_users",
mapper_registry.metadata,
Expand Down Expand Up @@ -197,7 +203,7 @@ class GroupRow(Base):
users = relationship("AssocGroupUserRow", back_populates="group")
resource_policy_row = relationship("ProjectResourcePolicyRow", back_populates="projects")
kernels = relationship("KernelRow", back_populates="group_row")
vfolder_row = relationship("VFolderRow", back_populates="group_row")
vfolder_rows = relationship("VFolderRow", back_populates="group_row")


def _build_group_query(cond: sa.sql.BinaryExpression, domain_name: str) -> sa.sql.Select:
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/models/rbac/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources()
315 changes: 315 additions & 0 deletions src/ai/backend/manager/models/rbac/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
from __future__ import annotations

import enum
import uuid
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only

from ..group import AssocGroupUserRow, GroupRow, UserRoleInProject
from ..user import UserRole

if TYPE_CHECKING:
from ..utils import ExtendedAsyncSAEngine


__all__: Sequence[str] = (
"BasePermission",
"ClientContext",
"DomainScope",
"ProjectScope",
"UserScope",
"StorageHost",
"ImageRegistry",
"ScalingGroup",
"AbstractPermissionContext",
"AbstractPermissionContextBuilder",
)


class BasePermission(enum.StrEnum):
pass


PermissionType = TypeVar("PermissionType", bound=BasePermission)


class Bypass(enum.Enum):
TOKEN = enum.auto()


bypass = Bypass.TOKEN

ProjectContext = Mapping[uuid.UUID, UserRoleInProject]


@dataclass
class ClientContext:
db: ExtendedAsyncSAEngine

domain_name: str
user_id: uuid.UUID
user_role: UserRole

_project_ctx: ProjectContext | None = field(init=False, default=None)
_domain_project_ctx: Mapping[str, ProjectContext] | None = field(init=False, default=None)

async def get_or_init_project_ctx_in_domain(
self, db_session: AsyncSession, domain_name: str
) -> ProjectContext | None:
_project_ctx = await self._get_or_init_project_ctx(db_session)
if _project_ctx is bypass:
# client is superadmin or monitor
if self._domain_project_ctx is None:
self._domain_project_ctx = {}
if domain_name not in self._domain_project_ctx:
stmt = (
sa.select(GroupRow)
.where(GroupRow.domain_name == domain_name)
.options(load_only(GroupRow.id))
)
self._domain_project_ctx = {
**self._domain_project_ctx,
domain_name: {
row.id: UserRoleInProject.ADMIN for row in await db_session.scalars(stmt)
},
}
else:
# client is domain admin or user
self._domain_project_ctx = {self.domain_name: _project_ctx}
return self._domain_project_ctx.get(domain_name)

async def get_user_role_in_project(
self, db_session: AsyncSession, project_id: uuid.UUID
) -> UserRoleInProject | None:
_project_ctx = await self._get_or_init_project_ctx(db_session)
if _project_ctx is bypass:
return UserRoleInProject.ADMIN
else:
return _project_ctx.get(project_id)

async def _get_or_init_project_ctx(self, db_session: AsyncSession) -> ProjectContext | Bypass:
match self.user_role:
case UserRole.SUPERADMIN | UserRole.MONITOR:
# Superadmins and monitors can access to ALL projects in the system.
# Let's not fetch all project data from DB.
return bypass
case UserRole.ADMIN:
if self._project_ctx is None:
stmt = (
sa.select(GroupRow)
.where(GroupRow.domain_name == self.domain_name)
.options(load_only(GroupRow.id))
)
self._project_ctx = {
row.id: UserRoleInProject.ADMIN for row in await db_session.scalars(stmt)
}
return self._project_ctx
case UserRole.USER:
if self._project_ctx is None:
stmt = (
sa.select(AssocGroupUserRow)
.select_from(sa.join(AssocGroupUserRow, GroupRow))
.where(
(AssocGroupUserRow.user_id == self.user_id)
& (GroupRow.domain_name == self.domain_name)
)
)
self._project_ctx = {
row.group_id: UserRoleInProject.USER
for row in await db_session.scalars(stmt)
}
return self._project_ctx


class BaseScope(metaclass=ABCMeta):
@abstractmethod
def __str__(self) -> str:
pass


@dataclass(frozen=True)
class DomainScope(BaseScope):
domain_name: str

def __str__(self) -> str:
return f"Domain(name: {self.domain_name})"


@dataclass(frozen=True)
class ProjectScope(BaseScope):
project_id: uuid.UUID

def __str__(self) -> str:
return f"Project(id: {self.project_id})"


@dataclass(frozen=True)
class UserScope(BaseScope):
user_id: uuid.UUID

def __str__(self) -> str:
return f"User(id: {self.user_id})"


# Extra scope is to address some scopes that contain specific object types
# such as registries for images, scaling groups for agents, storage hosts for vfolders etc.
class ExtraScope:
pass


@dataclass(frozen=True)
class StorageHost(ExtraScope):
name: str


@dataclass(frozen=True)
class ImageRegistry(ExtraScope):
name: str


@dataclass(frozen=True)
class ScalingGroup(ExtraScope):
name: str


ObjectType = TypeVar("ObjectType")
ObjectIDType = TypeVar("ObjectIDType")


@dataclass
class AbstractPermissionContext(
Generic[PermissionType, ObjectType, ObjectIDType], metaclass=ABCMeta
):
"""
Define permissions under given User, Project or Domain scopes.
Each field of this class represents a mapping of ["accessible scope id", "permissions under the scope"].
For example, `project` field has a mapping of ["accessible project id", "permissions under the project"].
{
"PROJECT_A_ID": {"READ", "WRITE", "DELETE"}
"PROJECT_B_ID": {"READ"}
}
`additional` and `overriding` fields have a mapping of ["object id", "permissions applied to the object"].
`additional` field is used to add permissions to specific objects. It can be used for admins.
`overriding` field is used to address exceptional cases such as permission overriding or cover other scopes(scaling groups or storage hosts etc).
"""

user_id_to_permission_map: Mapping[uuid.UUID, frozenset[PermissionType]] = field(
default_factory=dict
)
project_id_to_permission_map: Mapping[uuid.UUID, frozenset[PermissionType]] = field(
default_factory=dict
)
domain_name_to_permission_map: Mapping[str, frozenset[PermissionType]] = field(
default_factory=dict
)

object_id_to_additional_permission_map: Mapping[ObjectIDType, frozenset[PermissionType]] = (
field(default_factory=dict)
)
object_id_to_overriding_permission_map: Mapping[ObjectIDType, frozenset[PermissionType]] = (
field(default_factory=dict)
)

def filter_by_permission(self, permission_to_include: PermissionType) -> None:
self.user_id_to_permission_map = {
uid: permissions
for uid, permissions in self.user_id_to_permission_map.items()
if permission_to_include in permissions
}
self.project_id_to_permission_map = {
pid: permissions
for pid, permissions in self.project_id_to_permission_map.items()
if permission_to_include in permissions
}
self.domain_name_to_permission_map = {
dname: permissions
for dname, permissions in self.domain_name_to_permission_map.items()
if permission_to_include in permissions
}
self.object_id_to_additional_permission_map = {
obj_id: permissions
for obj_id, permissions in self.object_id_to_additional_permission_map.items()
if permission_to_include in permissions
}
self.object_id_to_overriding_permission_map = {
obj_id: permissions
for obj_id, permissions in self.object_id_to_overriding_permission_map.items()
if permission_to_include in permissions
}

@abstractmethod
async def build_query(self) -> sa.sql.Select | None:
pass

@abstractmethod
async def calculate_final_permission(self, acl_obj: ObjectType) -> frozenset[PermissionType]:
"""
Calculate the final permissions applied to the given object based on the fields in this class.
"""
pass


PermissionContextType = TypeVar("PermissionContextType", bound=AbstractPermissionContext)


class AbstractPermissionContextBuilder(
Generic[PermissionType, PermissionContextType], metaclass=ABCMeta
):
@classmethod
async def build(
cls,
db_session: AsyncSession,
ctx: ClientContext,
target_scope: BaseScope,
*,
permission: PermissionType | None = None,
) -> PermissionContextType:
match target_scope:
case UserScope(user_id=user_id):
result = await cls._build_in_user_scope(db_session, ctx, user_id)
case ProjectScope(project_id=project_id):
result = await cls._build_in_project_scope(db_session, ctx, project_id)
case DomainScope(domain_name=domain_name):
result = await cls._build_in_domain_scope(db_session, ctx, domain_name)
case _:
raise RuntimeError(f"invalid scope `{target_scope}`")
if permission is not None:
result.filter_by_permission(permission)
return result

@classmethod
@abstractmethod
async def _build_in_user_scope(
cls,
db_session: AsyncSession,
ctx: ClientContext,
user_id: uuid.UUID,
) -> PermissionContextType:
pass

@classmethod
@abstractmethod
async def _build_in_project_scope(
cls,
db_session: AsyncSession,
ctx: ClientContext,
project_id: uuid.UUID,
) -> PermissionContextType:
pass

@classmethod
@abstractmethod
async def _build_in_domain_scope(
cls,
db_session: AsyncSession,
ctx: ClientContext,
domain_name: str,
) -> PermissionContextType:
pass
6 changes: 6 additions & 0 deletions src/ai/backend/manager/models/rbac/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class RBACException(Exception):
pass


class NotEnoughPermission(RBACException):
pass
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class UserRow(Base):

main_keypair = relationship("KeyPairRow", foreign_keys=users.c.main_access_key)

vfolder_row = relationship("VFolderRow", back_populates="user_row")
vfolder_rows = relationship("VFolderRow", back_populates="user_row")


class UserGroup(graphene.ObjectType):
Expand Down
Loading

0 comments on commit b583f39

Please sign in to comment.