Skip to content

Commit

Permalink
refactor: impl vfolder RBAC APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed May 13, 2024
1 parent 8241fd0 commit 22d8e97
Show file tree
Hide file tree
Showing 6 changed files with 821 additions and 5 deletions.
136 changes: 135 additions & 1 deletion src/ai/backend/manager/models/acl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Mapping, Sequence
import enum
import uuid
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, List, Sequence, TypeVar

import graphene
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession

from ai.backend.common.types import VFolderHostPermission

from .group import AssocGroupUserRow, GroupRow, UserRoleInProject
from .user import UserRole

if TYPE_CHECKING:
from .gql import GraphQueryContext

Expand All @@ -16,6 +26,130 @@
)


class AbstractACLPermission(enum.StrEnum):
pass


ACLPermissionType = TypeVar("ACLPermissionType", bound=AbstractACLPermission)


@dataclass
class RequesterContext:
db_conn: AsyncConnection

domain_name: str
user_id: uuid.UUID
user_role: UserRole

project_ctx: Mapping[uuid.UUID, UserRoleInProject] | None = None

async def get_or_init_project_ctx(self) -> Mapping[uuid.UUID, UserRoleInProject]:
if self.project_ctx is None:
if self.user_role in (UserRole.SUPERADMIN, UserRole.ADMIN):
role_in_project = UserRoleInProject.ADMIN
else:
role_in_project = UserRoleInProject.USER
stmt = (
sa.select(AssocGroupUserRow)
.select_from(sa.join(AssocGroupUserRow, GroupRow))
.where(
(AssocGroupUserRow.user_id == self.user_id)
& (GroupRow.domain_name == self.domain_name)
)
)
async with AsyncSession(self.db_conn) as db_session:
self.project_ctx = {
row.group_id: role_in_project for row in await db_session.scalars(stmt)
}
return self.project_ctx


class RequestedScope:
pass


@dataclass(frozen=True)
class RequestedDomainScope(RequestedScope):
domain_name: str


@dataclass(frozen=True)
class RequestedProjectScope(RequestedScope):
project_id: uuid.UUID


@dataclass(frozen=True)
class RequestedUserScope(RequestedScope):
user_id: uuid.UUID


ACLObjectType = TypeVar("ACLObjectType")


@dataclass
class AbstractScopePermissionMap(Generic[ACLPermissionType, ACLObjectType], metaclass=ABCMeta):
user: Mapping[uuid.UUID, frozenset[ACLPermissionType]]
project: Mapping[uuid.UUID, frozenset[ACLPermissionType]]
domain: Mapping[str, frozenset[ACLPermissionType]]

@abstractmethod
async def determine_permission(self, acl_obj: ACLObjectType) -> frozenset[ACLPermissionType]:
pass


ScopePermissionMapType = TypeVar("ScopePermissionMapType", bound=AbstractScopePermissionMap)


class AbstractScopePermissionMapBuilder(Generic[ScopePermissionMapType], metaclass=ABCMeta):
@classmethod
async def build(
cls,
db_session: AsyncSession,
ctx: RequesterContext,
requested_scope: RequestedScope,
) -> ScopePermissionMapType:
match requested_scope:
case RequestedUserScope(user_id=user_id):
return await cls._build_in_user_scope(db_session, ctx, user_id)
case RequestedProjectScope(project_id=project_id):
return await cls._build_in_project_scope(db_session, ctx, project_id)
case RequestedDomainScope(domain_name=domain_name):
return await cls._build_in_domain_scope(db_session, ctx, domain_name)
case _:
pass
raise RuntimeError(f"invalid request scope {requested_scope}")

@classmethod
@abstractmethod
async def _build_in_user_scope(
cls,
db_session: AsyncSession,
ctx: RequesterContext,
user_id: uuid.UUID,
) -> ScopePermissionMapType:
pass

@classmethod
@abstractmethod
async def _build_in_project_scope(
cls,
db_session: AsyncSession,
ctx: RequesterContext,
project_id: uuid.UUID,
) -> ScopePermissionMapType:
pass

@classmethod
@abstractmethod
async def _build_in_domain_scope(
cls,
db_session: AsyncSession,
ctx: RequesterContext,
domain_name: str,
) -> ScopePermissionMapType:
pass


def get_all_vfolder_host_permissions() -> List[str]:
return [perm.value for perm in VFolderHostPermission]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add_id_to_vfolder_permissions
Revision ID: 2be0ce116c35
Revises: 37410c773b8c
Create Date: 2024-05-11 00:53:28.387238
"""

import sqlalchemy as sa
from alembic import op

from ai.backend.manager.models.base import GUID

# revision identifiers, used by Alembic.
revision = "2be0ce116c35"
down_revision = "37410c773b8c"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"vfolder_permissions",
sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
)
op.create_primary_key("pk_vfolder_permissions", "vfolder_permissions", ["id"])


def downgrade():
op.drop_constraint("pk_vfolder_permissions", "vfolder_permissions", type_="primary")
op.drop_column("vfolder_permissions", "id")
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""add_vfolders_domain_name
Revision ID: 37410c773b8c
Revises: dddf9be580f5
Create Date: 2024-05-06 20:51:54.658829
"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import text

# revision identifiers, used by Alembic.
revision = "37410c773b8c"
down_revision = "dddf9be580f5"
branch_labels = None
depends_on = None


def upgrade():
conn = op.get_bind()
op.add_column("vfolders", sa.Column("domain_name", sa.String(length=64), nullable=True))

conn.execute(
text(
"""\
UPDATE vfolders
SET domain_name = COALESCE(
(SELECT domain_name FROM users WHERE vfolders.user = users.uuid),
(SELECT domain_name FROM groups WHERE vfolders.group = groups.id)
)
WHERE domain_name IS NULL;
"""
)
)

op.alter_column("vfolders", column_name="domain_name", nullable=False)
op.create_index(op.f("ix_vfolders_domain_name"), "vfolders", ["domain_name"], unique=False)


def downgrade():
op.drop_index(op.f("ix_vfolders_domain_name"), table_name="vfolders")
op.drop_column("vfolders", "domain_name")
7 changes: 7 additions & 0 deletions src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@
MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB
_rx_slug = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$")


class UserRoleInProject(enum.StrEnum):
ADMIN = "admin" # User is associated as admin. TODO: impl project admin
USER = "user" # User is associated as user


association_groups_users = sa.Table(
"association_groups_users",
mapper_registry.metadata,
Expand Down Expand Up @@ -197,6 +203,7 @@ class GroupRow(Base):
users = relationship("AssocGroupUserRow", back_populates="group")
resource_policy_row = relationship("ProjectResourcePolicyRow", back_populates="projects")
kernels = relationship("KernelRow", back_populates="group_row")
vfolders = relationship("VFolderRow", back_populates="project_row")


def _build_group_query(cond: sa.sql.BinaryExpression, domain_name: str) -> sa.sql.Select:
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class UserRow(Base):

main_keypair = relationship("KeyPairRow", foreign_keys=users.c.main_access_key)

vfolders = relationship("VFolderRow", back_populates="user_row")


class UserGroup(graphene.ObjectType):
id = graphene.UUID()
Expand Down
Loading

0 comments on commit 22d8e97

Please sign in to comment.