diff --git a/changes/3669.feat.md b/changes/3669.feat.md new file mode 100644 index 00000000000..6ea21892f78 --- /dev/null +++ b/changes/3669.feat.md @@ -0,0 +1 @@ +Add Repository for VFolder in manager \ No newline at end of file diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index 03b8c65b892..248d84cf02f 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -3,7 +3,7 @@ "info": { "title": "Backend.AI Manager API", "description": "Backend.AI Manager REST API specification", - "version": "25.1.1", + "version": "25.2.0", "contact": { "name": "Lablup Inc.", "url": "https://docs.backend.ai", diff --git a/src/ai/backend/common/dto/manager/field.py b/src/ai/backend/common/dto/manager/field.py index 6cc56d79c29..a14ba581e45 100644 --- a/src/ai/backend/common/dto/manager/field.py +++ b/src/ai/backend/common/dto/manager/field.py @@ -6,7 +6,7 @@ from ai.backend.common.types import VFolderUsageMode -class VFolderPermissionField(enum.StrEnum): +class VFolderMountPermissionField(enum.StrEnum): READ_ONLY = "ro" READ_WRITE = "rw" RW_DELETE = "wd" @@ -38,7 +38,7 @@ class VFolderItemField(BaseModel): host: str usage_mode: VFolderUsageMode created_at: str - permission: VFolderPermissionField + permission: VFolderMountPermissionField max_size: int creator: str ownership_type: VFolderOwnershipTypeField @@ -47,8 +47,8 @@ class VFolderItemField(BaseModel): cloneable: bool status: VFolderOperationStatusField is_owner: bool - user_email: str - group_name: str + user_email: Optional[str] + group_name: Optional[str] type: str # legacy max_files: int cur_size: int diff --git a/src/ai/backend/common/dto/manager/request.py b/src/ai/backend/common/dto/manager/request.py index a9835adbc5d..e904bd61ca7 100644 --- a/src/ai/backend/common/dto/manager/request.py +++ b/src/ai/backend/common/dto/manager/request.py @@ -4,7 +4,7 @@ from pydantic import AliasChoices, BaseModel, Field from ai.backend.common import typed_validators as tv -from ai.backend.common.dto.manager.dto import VFolderPermissionDTO +from ai.backend.common.dto.manager.field import VFolderMountPermissionField from ai.backend.common.types import VFolderUsageMode @@ -17,7 +17,7 @@ class VFolderCreateReq(BaseModel): default=None, ) usage_mode: VFolderUsageMode = Field(default=VFolderUsageMode.GENERAL) - permission: VFolderPermissionDTO = Field(default=VFolderPermissionDTO.READ_WRITE) + permission: VFolderMountPermissionField = Field(default=VFolderMountPermissionField.READ_WRITE) unmanaged_path: Optional[str] = Field( validation_alias=AliasChoices("unmanaged_path", "unmanagedPath"), default=None, diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py new file mode 100644 index 00000000000..e39d7b0f09b --- /dev/null +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -0,0 +1,447 @@ +import uuid +from typing import Any, Awaitable, Callable, Iterable, Optional, ParamSpec, TypeVar + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession as SASession +from sqlalchemy.orm import selectinload + +from ai.backend.manager.api.exceptions import GroupNotFound, UserNotFound +from ai.backend.manager.data.vfolder.dto import ( + UserIdentity, + VFolderItem, + VFolderMetadataToCreate, + VFolderResourceLimit, +) +from ai.backend.manager.models import ( + HARD_DELETED_VFOLDER_STATUSES, + ProjectType, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, + VFolderStatusSet, + vfolder_status_map, +) +from ai.backend.manager.models.group import AssocGroupUserRow +from ai.backend.manager.models.utils import ( + ExtendedAsyncSAEngine, + execute_with_txn_retry, +) +from ai.backend.manager.models.vfolder import ( + GroupRow, + UserRow, + VFolderInvitationRow, + VFolderPermissionRow, + VFolderRow, +) + +_P = ParamSpec("_P") +_TQueryResult = TypeVar("_TQueryResult") + + +class VFolderRepository: + _db: ExtendedAsyncSAEngine + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db = db + + async def get_group_type(self, group_id: uuid.UUID) -> ProjectType: + async with self._db.begin_session() as sess: + query = sa.select(GroupRow).where(GroupRow.id == group_id) + group_row = await sess.scalar(query) + + if group_row is None: + raise GroupNotFound(extra_data=group_id) + + return group_row.type + + async def get_user_container_id(self, user_id: uuid.UUID) -> Optional[int]: + async with self._db.begin_session() as sess: + query = sa.select(UserRow.container_uid).where(UserRow.uuid == user_id) + result = await sess.scalar(query) + return result + + async def get_created_vfolder_count( + self, owner_id: uuid.UUID, ownership_type: VFolderOwnershipType + ) -> int: + async with self._db.begin_session() as sess: + if ownership_type == VFolderOwnershipType.USER: + ownership_type_caluse = VFolderRow.user == owner_id + else: + ownership_type_caluse = VFolderRow.group == owner_id + + query = ( + sa.select(sa.func.count()) + .select_from(VFolderRow) + .where( + (ownership_type_caluse) + & (VFolderRow.status.not_in(HARD_DELETED_VFOLDER_STATUSES)) + ) + ) + result = await sess.scalar(query) + + return result + + async def get_user_vfolder_resource_limit( + self, user_identity: UserIdentity + ) -> VFolderResourceLimit: + async with self._db.begin_session() as sess: + query = ( + sa.select(UserRow) + .where(UserRow.uuid == user_identity.user_uuid) + .options(selectinload(UserRow.resource_policy_row)) + ) + user_row = await sess.scalar(query) + + if user_row is None: + raise UserNotFound(extra_data=user_identity.user_uuid) + + max_vfolder_count = user_row.resource_policy_row.max_vfolder_count + max_quota_scope_size = user_row.resource_policy_row.max_quota_scope_size + + return VFolderResourceLimit( + max_vfolder_count=max_vfolder_count, + max_quota_scope_size=max_quota_scope_size, + ) + + async def get_group_vfolder_resource_limit( + self, user_identity: UserIdentity, group_id: uuid.UUID + ) -> VFolderResourceLimit: + async with self._db.begin_session() as sess: + query = ( + sa.select(GroupRow) + .where( + (GroupRow.domain_name == user_identity.domain_name) & (GroupRow.id == group_id) + ) + .options(selectinload(GroupRow.resource_policy_row)) + ) + group_row = await sess.scalar(query) + + if group_row is None: + raise GroupNotFound(extra_data=group_id) + + max_vfolder_count = group_row.resource_policy_row.max_vfolder_count + max_quota_scope_size = group_row.resource_policy_row.max_quota_scope_size + + return VFolderResourceLimit( + max_vfolder_count=max_vfolder_count, + max_quota_scope_size=max_quota_scope_size, + ) + + async def persist_vfolder_metadata(self, metadata: VFolderMetadataToCreate) -> VFolderItem: + async with self._db.begin_session() as sess: + insert_query = sa.insert(VFolderRow).values(metadata.to_dict()).returning(VFolderRow.id) + vfolder_id = await sess.scalar(insert_query) + + query = sa.select(VFolderRow).where(VFolderRow.id == vfolder_id) + vfolder: VFolderRow = await sess.scalar(query) + vfolder_item = VFolderItem.from_orm(orm=vfolder, is_owner=True) + return vfolder_item + + async def create_vfolder_permission( + self, + user_id: uuid.UUID, + vfolder_id: uuid.UUID, + permission: VFolderPermission = VFolderPermission.OWNER_PERM, + ) -> None: + async with self._db.begin_session() as sess: + insert_value: dict[str, Any] = { + "user": user_id, + "vfolder": vfolder_id.hex, + "permission": permission, + } + + stmt = sa.insert(VFolderPermissionRow).values(insert_value) + await sess.execute(stmt) + + async def get_accessible_folders( + self, + user_identity: UserIdentity, + allowed_vfolder_types: list[str], + group_id: Optional[uuid.UUID] = None, + ) -> list[VFolderItem]: + all_entries: list[VFolderItem] = [] + async with self._db.begin_session() as sess: + if "user" in allowed_vfolder_types: + user_type_vfolders: list[ + VFolderItem + ] = await self._query_accessible_user_type_vfolders( + db_session=sess, user_identity=user_identity, group_id=group_id + ) + all_entries.extend(user_type_vfolders) + + if "group" in allowed_vfolder_types: + group_type_vfolders: list[ + VFolderItem + ] = await self._query_accessible_group_type_vfolders( + db_session=sess, user_identity=user_identity, group_id=group_id + ) + + all_entries.extend(group_type_vfolders) + + return all_entries + + async def _query_accessible_user_type_vfolders( + self, db_session: SASession, user_identity: UserIdentity, group_id: Optional[uuid.UUID] + ) -> list[VFolderItem]: + owned: list[VFolderItem] = await self._query_owned_vfolders( + db_session=db_session, user_identity=user_identity, group_id=group_id + ) + shared: list[VFolderItem] = await self._query_shared_vfolders( + db_session=db_session, user_identity=user_identity + ) + + return [*owned, *shared] + + async def _query_accessible_group_type_vfolders( + self, db_session: SASession, user_identity: UserIdentity, group_id: Optional[uuid.UUID] + ) -> list[VFolderItem]: + if group_id is not None: + return await self._query_specific_group_vfolders( + db_session=db_session, user_identity=user_identity, group_id=group_id + ) + + return await self._query_all_accessible_group_vfolders( + db_session=db_session, user_identity=user_identity + ) + + async def _query_specific_group_vfolders( + self, + db_session: SASession, + user_identity: UserIdentity, + group_id: uuid.UUID, + ) -> list[VFolderItem]: + if user_identity.is_admin: + # check if group belongs to admin's domain + domain_check_query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where( + (GroupRow.id == group_id) & (GroupRow.domain_name == user_identity.domain_name) + ) + ) + if await db_session.scalar(domain_check_query) is None: + raise GroupNotFound( + extra_msg=f"group {group_id} does not belong to domain {user_identity.domain_name}" + ) + + if user_identity.is_normal_user: + # check if user is in the group + membership_query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from(AssocGroupUserRow) + .where( + (AssocGroupUserRow.user_id == user_identity.user_uuid) + & (AssocGroupUserRow.group_id == group_id) + ) + ) + if await db_session.scalar(membership_query) is None: + raise GroupNotFound( + extra_msg=f"user {user_identity.user_uuid} is not a member of group {group_id}" + ) + + query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where( + ( + (VFolderRow.group == group_id) + | (VFolderRow.user.isnot(None)) + & VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE]) + ) + ) + ) + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + + entries = [ + VFolderItem.from_orm( + orm=vfolder, + is_owner=user_identity.has_privilege_role, + include_relations=True, + override_with_group_member_permission=True, + ) + for vfolder in vfolders + ] + + return entries + + async def _query_all_accessible_group_vfolders( + self, + db_session: SASession, + user_identity: UserIdentity, + ) -> list[VFolderItem]: + base_query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + + if user_identity.is_superadmin: + query = base_query + elif user_identity.is_admin: + query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where(GroupRow.domain_name == user_identity.domain_name) + ) + group_ids = await db_session.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) + else: + query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from( + AssocGroupUserRow.join(UserRow, AssocGroupUserRow.user_id == UserRow.uuid) + ) + .where(AssocGroupUserRow.user_id == user_identity.user_uuid) + ) + group_ids = await db_session.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) + + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm( + orm=vfolder, is_owner=user_identity.has_privilege_role, include_relations=True + ) + for vfolder in vfolders + ] + + return entries + + async def _query_owned_vfolders( + self, + db_session: SASession, + user_identity: UserIdentity, + group_id: Optional[uuid.UUID] = None, + ) -> list[VFolderItem]: + user_join = VFolderRow.join(UserRow, VFolderRow.user == UserRow.uuid) + + query = ( + sa.select(VFolderRow) + .select_from(user_join) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + # If group id is provided, filter user owned vfolders that are in certain group + if group_id is not None: + query = query.where((VFolderRow.group == group_id) | (VFolderRow.user.isnot(None))) + + if user_identity.is_normal_user: + query = query.where(VFolderRow.user == user_identity.user_uuid) + + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=True, include_relations=True) + for vfolder in vfolders + ] + + return entries + + async def _query_shared_vfolders( + self, + db_session: SASession, + user_identity: UserIdentity, + ) -> list[VFolderItem]: + shared_join = VFolderRow.join( + VFolderPermissionRow, + VFolderRow.id == VFolderPermissionRow.vfolder, + isouter=True, + ).join( + UserRow, + VFolderRow.user == UserRow.uuid, + isouter=True, + ) + + query = ( + sa.select(VFolderRow) + .select_from(shared_join) + .where( + (VFolderPermissionRow.user == user_identity.user_uuid) + & (VFolderRow.ownership_type == VFolderOwnershipType.USER) + & (VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + ) + + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=False, include_relations=True) + for vfolder in vfolders + ] + + return entries + + async def patch_vfolder_name(self, vfolder_id: uuid.UUID, new_name: str) -> None: + async with self._db.begin_session() as sess: + stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).values(name=new_name) + await sess.execute(stmt) + + async def _delete_vfolder_permission_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + stmt = sa.delete(VFolderPermissionRow).where( + VFolderPermissionRow.vfolder.in_(vfolder_row_ids) + ) + await db_session.execute(stmt) + + async def _delete_vfolder_invitation_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + stmt = sa.delete(VFolderInvitationRow).where( + VFolderInvitationRow.vfolder.in_(vfolder_row_ids) + ) + await db_session.execute(stmt) + + async def _delete_vfolder_relation_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + await self._delete_vfolder_invitation_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) + await self._delete_vfolder_permission_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) + + async def _update_vfolder_status( + self, + db_session: SASession, + vfolder_id: uuid.UUID, + vfolder_status: VFolderOperationStatus, + ) -> None: + stmt = ( + sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).values(status=vfolder_status) + ) + await db_session.execute(stmt) + + """ + NOTICE: _retry method must be used in top level function + """ + + async def _retry( + self, + func: Callable[[SASession], Awaitable[_TQueryResult]], + ) -> None: + async with self._db.connect() as conn: + await execute_with_txn_retry( + txn_func=func, begin_trx=self._db.begin_session, connection=conn + ) + + async def delete_vfolder_by_id( + self, + vfolder_id: uuid.UUID, + ) -> None: + vfolder_ids = [vfolder_id] + + async def _delete_and_update(db_session: SASession) -> None: + await self._delete_vfolder_relation_rows( + db_session=db_session, vfolder_row_ids=vfolder_ids + ) + await self._update_vfolder_status( + db_session=db_session, + vfolder_id=vfolder_id, + vfolder_status=VFolderOperationStatus.DELETE_PENDING, + ) + + await self._retry(func=_delete_and_update) diff --git a/src/ai/backend/manager/data/vfolder/dto.py b/src/ai/backend/manager/data/vfolder/dto.py index 14cf862c38c..7425633a726 100644 --- a/src/ai/backend/manager/data/vfolder/dto.py +++ b/src/ai/backend/manager/data/vfolder/dto.py @@ -1,5 +1,5 @@ import uuid -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Mapping, Optional, Self from ai.backend.common.dto.manager.context import KeypairCtx, UserIdentityCtx @@ -8,10 +8,12 @@ ) from ai.backend.common.dto.manager.request import VFolderCreateReq from ai.backend.common.types import VFolderUsageMode -from ai.backend.manager.models import ( +from ai.backend.manager.models.user import UserRole +from ai.backend.manager.models.vfolder import ( VFolderOperationStatus, VFolderOwnershipType, VFolderPermission, + VFolderRow, ) @@ -31,6 +33,22 @@ def from_ctx(cls, ctx: UserIdentityCtx) -> Self: domain_name=ctx.domain_name, ) + @property + def is_admin(self) -> bool: + return self.user_role == UserRole.ADMIN + + @property + def is_superadmin(self) -> bool: + return self.user_role == UserRole.SUPERADMIN + + @property + def has_privilege_role(self) -> bool: + return (self.user_role == UserRole.ADMIN) or (self.user_role == UserRole.SUPERADMIN) + + @property + def is_normal_user(self) -> bool: + return (self.user_role != UserRole.ADMIN) and (self.user_role != UserRole.SUPERADMIN) + @dataclass class Keypair: @@ -68,6 +86,26 @@ def from_request(cls, request: VFolderCreateReq) -> Self: ) +@dataclass +class VFolderMetadataToCreate: + name: str + domain_name: str + quota_scope_id: str + usage_mode: VFolderUsageMode + permission: VFolderPermission + host: str + creator: str + ownership_type: VFolderOwnershipType + cloneable: bool + user: str | None = None + group: str | None = None + unmanaged_path: str | None = None + status: VFolderOperationStatus = VFolderOperationStatus.READY + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + @dataclass class VFolderItem: id: str @@ -85,12 +123,58 @@ class VFolderItem: cloneable: bool status: VFolderOperationStatus is_owner: bool - user_email: str - group_name: str + user_email: Optional[str] + group_name: Optional[str] type: str # legacy max_files: int cur_size: int + @classmethod + def from_orm( + cls, + orm: VFolderRow, + is_owner, + include_relations=False, + override_with_group_member_permission=False, + ): + user_email = orm.user_row.email if include_relations and orm.user_row else None + group_name = orm.group_row.name if include_relations and orm.group_row else None + permission = ( + orm.permission_rows.permission + if override_with_group_member_permission + else orm.permission + ) + type = ( + VFolderOwnershipType.USER + if orm.ownership_type == VFolderOwnershipType.USER + else VFolderOwnershipType.GROUP + ) + user = str(orm.user) if type == VFolderOwnershipType.USER else None + group = str(orm.group) if type == VFolderOwnershipType.GROUP else None + + return cls( + id=orm.id.hex, + name=orm.name, + quota_scope_id=orm.quota_scope_id, + host=orm.host, + usage_mode=orm.usage_mode, + created_at=orm.created_at.isoformat(), + permission=permission, + max_size=orm.max_size, + creator=orm.creator, + ownership_type=orm.ownership_type, + user=user, + group=group, + cloneable=orm.cloneable, + status=orm.status, + is_owner=is_owner, + user_email=user_email, + group_name=group_name, + type=type, + max_files=orm.max_files, + cur_size=orm.cur_size, + ) + def to_field(self) -> VFolderItemField: return VFolderItemField( id=self.id, @@ -114,3 +198,9 @@ def to_field(self) -> VFolderItemField: max_files=self.max_files, cur_size=self.cur_size, ) + + +@dataclass +class VFolderResourceLimit: + max_vfolder_count: int + max_quota_scope_size: int diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py index 1a74190b6a2..b58dc894fbf 100644 --- a/src/ai/backend/manager/models/vfolder.py +++ b/src/ai/backend/manager/models/vfolder.py @@ -41,9 +41,9 @@ from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.defs import MODEL_VFOLDER_LENGTH_LIMIT from ai.backend.common.dto.manager.field import ( + VFolderMountPermissionField, VFolderOperationStatusField, VFolderOwnershipTypeField, - VFolderPermissionField, ) from ai.backend.common.types import ( MountPermission, @@ -177,8 +177,8 @@ class VFolderPermission(enum.StrEnum): RW_DELETE = "wd" OWNER_PERM = "wd" # resolved as RW_DELETE - def to_field(self) -> VFolderPermissionField: - return VFolderPermissionField(self) + def to_field(self) -> VFolderMountPermissionField: + return VFolderMountPermissionField(self) class VFolderPermissionValidator(t.Trafaret): diff --git a/tests/manager/api/vfolders/conftest.py b/tests/manager/api/vfolders/conftest.py index 9618f350832..cece4b0358b 100644 --- a/tests/manager/api/vfolders/conftest.py +++ b/tests/manager/api/vfolders/conftest.py @@ -1,9 +1,37 @@ import uuid +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable, Optional from unittest.mock import MagicMock import pytest +import sqlalchemy as sa from pydantic import BaseModel +from ai.backend.common.types import ( + QuotaScopeID, + QuotaScopeType, + VFolderHostPermission, + VFolderUsageMode, +) +from ai.backend.manager.data.vfolder.dto import UserIdentity +from ai.backend.manager.models import ( + DomainRow, + GroupRow, + ProjectResourcePolicyRow, + ProjectType, + UserResourcePolicyRow, + UserRole, + UserRow, + UserStatus, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, +) +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.models.vfolder import ( + VFolderRow, +) + @pytest.fixture def mock_authenticated_request(): @@ -33,3 +61,223 @@ class TestResponse(BaseModel): @pytest.fixture def mock_success_response() -> TestResponse: return TestResponse(test="response") + + +@pytest.fixture +async def create_user_resource_policy( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[[str], Awaitable[str]]: + async def _create_user_resource_policy( + name: str, + max_vfolder_count: int = 0, + max_quota_scope_size: int = -1, + max_session_count_per_model_session: int = 5, + max_customized_image_count: int = 3, + ) -> str: + async with database_engine.begin() as conn: + policy_data = { + "name": name, + "max_vfolder_count": max_vfolder_count, + "max_quota_scope_size": max_quota_scope_size, + "max_session_count_per_model_session": max_session_count_per_model_session, + "max_customized_image_count": max_customized_image_count, + } + await conn.execute( + sa.insert(UserResourcePolicyRow) + .values(policy_data) + .returning(UserResourcePolicyRow) + ) + return name + + return _create_user_resource_policy + + +@pytest.fixture +async def create_project_resource_policy( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[[str], Awaitable[str]]: + async def _create_project_resource_policy( + name: str, + max_vfolder_count: int = 0, + max_quota_scope_size: int = -1, + max_network_count: int = 0, + ) -> str: + async with database_engine.begin() as conn: + policy_data = { + "name": name, + "max_vfolder_count": max_vfolder_count, + "max_quota_scope_size": max_quota_scope_size, + "max_network_count": max_network_count, + } + await conn.execute( + sa.insert(ProjectResourcePolicyRow) + .values(policy_data) + .returning(ProjectResourcePolicyRow) + ) + return name + + return _create_project_resource_policy + + +@pytest.fixture +async def create_domain( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[str]]: + async def _create_domain(name: str = "test-domain") -> str: + async with database_engine.begin() as conn: + domain_name = name + domain_data: dict[str, Any] = { + "name": domain_name, + "description": f"Test Domain for {name}", + "is_active": True, + "total_resource_slots": {}, + "allowed_vfolder_hosts": { + "local": [VFolderHostPermission.CREATE], + }, + "allowed_docker_registries": [], + "integration_id": None, + } + await conn.execute(sa.insert(DomainRow).values(domain_data).returning(DomainRow)) + return domain_name + + return _create_domain + + +@pytest.fixture +async def create_user_with_role( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + """ + NOTICE: To use 'default' resource policy, you must use `database_fixture` concurrently in test function + """ + + async def _create_user( + domain_name: str, + role: UserRole, + name: str, + container_uid: Optional[int] = 1000, + resource_policy_name: str = "default", + ) -> uuid.UUID: + async with database_engine.begin() as conn: + user_id = uuid.uuid4() + username = name + user_data = { + "uuid": user_id, + "username": username, + "email": f"{username}@test.com", + "password": "sample_password", + "need_password_change": False, + "full_name": "Sample User", + "description": "Test user", + "status": UserStatus.ACTIVE, + "status_info": None, + "domain_name": domain_name, + "role": role, + "resource_policy": resource_policy_name, + "totp_activated": False, + "sudo_session_enabled": False, + "container_uid": container_uid, + } + await conn.execute(sa.insert(UserRow).values(user_data)) + await conn.execute(sa.select(UserRow).where(UserRow.uuid == user_id)) + return user_id + + return _create_user + + +@pytest.fixture +async def create_identity( + create_user_with_role: Callable, +) -> Callable[..., Awaitable[tuple[uuid.UUID, UserIdentity]]]: + """ + NOTICE: To use 'default' resource policy in create_user_with_role, you must use `database_fixture` concurrently in test function + """ + + async def _create_identity( + domain_name: str, role: UserRole, name: str, resource_policy_name: str = "default" + ) -> tuple[uuid.UUID, UserIdentity]: + user_id = await create_user_with_role(domain_name, role, name, resource_policy_name) + identity = UserIdentity( + user_uuid=user_id, user_role=role, domain_name=domain_name, user_email="test@email.com" + ) + return user_id, identity + + return _create_identity + + +@pytest.fixture +async def create_vfolder( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + async def _create_vfolder( + domain_name: str, + user_id: Optional[uuid.UUID], + group_id: Optional[uuid.UUID], + name: str, + ) -> uuid.UUID: + assert (user_id is not None) or (group_id is not None) + quota_scope_type = QuotaScopeType.USER if user_id else QuotaScopeType.PROJECT + scope_id = user_id if user_id is not None else group_id + assert scope_id is not None + quota_scope_id = str(QuotaScopeID(quota_scope_type, scope_id)) + async with database_engine.begin() as conn: + vfolder_id = uuid.uuid4() + vfolder_data = { + "id": vfolder_id, + "host": "local", + "name": name, + "domain_name": domain_name, + "user": user_id, + "group": group_id, + "quota_scope_id": quota_scope_id, + "usage_mode": VFolderUsageMode.GENERAL, + "permission": VFolderPermission.READ_WRITE, + "ownership_type": VFolderOwnershipType.USER + if user_id + else VFolderOwnershipType.GROUP, + "status": VFolderOperationStatus.READY, + "max_files": 1000, + "max_size": 1024 * 1024, + "created_at": datetime.now(timezone.utc), + "last_used": None, + "cloneable": True, + } + await conn.execute(sa.insert(VFolderRow).values(vfolder_data)) + await conn.execute(sa.select(VFolderRow).where(VFolderRow.id == vfolder_id)) + return vfolder_id + + return _create_vfolder + + +@pytest.fixture +async def create_group( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + """ + NOTICE: To use 'default' resource policy, you must use `database_fixture` concurrently in test function + """ + + async def _create_group( + domain_name: str, + name: str, + type: ProjectType = ProjectType.GENERAL, + resource_policy_name: str = "default", + ) -> uuid.UUID: + async with database_engine.begin() as conn: + group_id = uuid.uuid4() + group_data = { + "id": group_id, + "name": name, + "description": "Test group", + "is_active": True, + "domain_name": domain_name, + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "resource_policy": resource_policy_name, + "type": type, + } + await conn.execute(sa.insert(GroupRow).values(group_data)) + await conn.execute(sa.select(GroupRow).where(GroupRow.id == group_id)) + return group_id + + return _create_group diff --git a/tests/manager/api/vfolders/test_repositories.py b/tests/manager/api/vfolders/test_repositories.py new file mode 100644 index 00000000000..ee6695fc318 --- /dev/null +++ b/tests/manager/api/vfolders/test_repositories.py @@ -0,0 +1,496 @@ +import uuid +from datetime import datetime, timezone +from typing import Callable + +import pytest +import sqlalchemy as sa + +from ai.backend.common.types import ( + QuotaScopeID, + QuotaScopeType, + VFolderUsageMode, +) +from ai.backend.manager.api.exceptions import GroupNotFound, UserNotFound +from ai.backend.manager.api.vfolders.repositories import VFolderRepository +from ai.backend.manager.data.vfolder.dto import UserIdentity, VFolderItem, VFolderMetadataToCreate +from ai.backend.manager.models import ( + ProjectType, + UserRole, + VFolderInvitationState, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, +) +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.models.vfolder import VFolderInvitationRow, VFolderPermissionRow, VFolderRow + +""" +This test file also use fixtures created in `database_fixture` in tests/conftest.py, especially when using 'default' resource policy. +""" + + +@pytest.mark.parametrize("project_type", [ProjectType.GENERAL, ProjectType.MODEL_STORE]) +@pytest.mark.asyncio +async def test_get_group_type( + database_fixture, + create_domain: Callable, + create_group: Callable, + database_engine: ExtendedAsyncSAEngine, + project_type, +): + # Given + domain_name = await create_domain(name="test_add_permission") + group_id = await create_group( + domain_name=domain_name, name="test_get_group_type", type=project_type + ) + + # When + vfolder_repository = VFolderRepository(db=database_engine) + group_type = await vfolder_repository.get_group_type(group_id=group_id) + + # Then + assert isinstance(group_type, ProjectType) + assert group_type == project_type + + +@pytest.mark.asyncio +async def test_get_container_id( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_user_with_role: Callable, + create_domain: Callable, +): + # Given + domain_name = await create_domain(name="test_get_container_id") + expected_container_id = 1234 + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test_get_container_id", + container_uid=expected_container_id, + ) + + # When + repo = VFolderRepository(database_engine) + actual_container_id = await repo.get_user_container_id(user_id=user_id) + + # Then + assert actual_container_id == expected_container_id + + +@pytest.mark.asyncio +async def test_get_created_vfolder_count( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_vfolder: Callable, + create_group: Callable, +): + # Given + domain_name = await create_domain(name="test_get_created_vfolder_count") + user_id = await create_user_with_role( + domain_name=domain_name, role=UserRole.USER, name="test_get_created_vfolder_count" + ) + group_id = await create_group(domain_name=domain_name, name="test_get_created_vfolder_count") + + # Create 3 user vfolders and delete 1 + await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="user-vfolder-1" + ) + await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="user-vfolder-2" + ) + deleted_vfolder_id = await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="deleted-vfolder" + ) + async with database_engine.begin() as conn: + await conn.execute( + sa.update(VFolderRow) + .where(VFolderRow.id == deleted_vfolder_id) + .values(status=VFolderOperationStatus.DELETE_COMPLETE) + ) + + # Create 3 group vfolders + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-1" + ) + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-2" + ) + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-3" + ) + + # When & Then + vfolder_repository = VFolderRepository(db=database_engine) + user_count = await vfolder_repository.get_created_vfolder_count( + user_id, VFolderOwnershipType.USER + ) + assert user_count == 2 + + group_count = await vfolder_repository.get_created_vfolder_count( + group_id, VFolderOwnershipType.GROUP + ) + assert group_count == 3 + + +@pytest.mark.asyncio +async def test_persist_vfolder_metadata( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, +): + # Given + domain_name = await create_domain(name="test_persist_vfolder_metadata") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test_persist_vfolder_metadata", + ) + quota_scope_id = QuotaScopeID(QuotaScopeType.USER, user_id) + host = "local" + + vfolder_name = "test_persist_vfolder_metadata" + metadata = VFolderMetadataToCreate( + name=vfolder_name, + domain_name=domain_name, + quota_scope_id=str(quota_scope_id), + usage_mode=VFolderUsageMode.GENERAL, + permission=VFolderPermission.READ_WRITE, + host=host, + creator=str(user_id), + ownership_type=VFolderOwnershipType.USER, + cloneable=True, + ) + + # When + vfolder_repository = VFolderRepository(db=database_engine) + vfolder_item: VFolderItem = await vfolder_repository.persist_vfolder_metadata(metadata=metadata) + + # Then + assert vfolder_item.name == vfolder_name + assert vfolder_item.host == host + assert vfolder_item.ownership_type == VFolderOwnershipType.USER + assert vfolder_item.quota_scope_id == quota_scope_id + assert vfolder_item.permission == VFolderPermission.READ_WRITE + assert vfolder_item.creator == str(user_id) + assert vfolder_item.cloneable == True # noqa: E712 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "policy_name,vfolder_count,quota_size", + [ + ("test-user-policy-1", 10, 100000), + ("test-user-policy-2", 20, 200000), + ("test-user-policy-3", 0, -1), + ], +) +async def test_get_user_vfolder_resource_limit( + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_user_resource_policy: Callable, + policy_name: str, + vfolder_count: int, + quota_size: int, +): + # Given + domain_name = await create_domain(name=f"test-domain-{policy_name}") + policy_name = await create_user_resource_policy( + name=policy_name, + max_vfolder_count=vfolder_count, + max_quota_scope_size=quota_size, + ) + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name=f"test-user-{policy_name}", + resource_policy_name=policy_name, + ) + + # When + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + vfolder_repository = VFolderRepository(db=database_engine) + resource_limit = await vfolder_repository.get_user_vfolder_resource_limit(user_identity) + + # Then + assert resource_limit.max_vfolder_count == vfolder_count + assert resource_limit.max_quota_scope_size == quota_size + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "policy_name,vfolder_count,quota_size", + [ + ("test-project-policy-1", 20, 20000), + ("test-project-policy-2", 50, 50000), + ("test-project-policy-3", 0, -1), + ], +) +async def test_get_group_vfolder_resource_limit( + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_group: Callable, + create_project_resource_policy: Callable, + create_user_resource_policy: Callable, + policy_name: str, + vfolder_count: int, + quota_size: int, +): + # Given + domain_name = await create_domain(name=f"test-domain-{policy_name}") + + # Create user resource policy for the user + user_policy_name = f"user-{policy_name}" + await create_user_resource_policy( + name=user_policy_name, + max_vfolder_count=10, + max_quota_scope_size=100000, + ) + + # Create project resource policy for the group + project_policy_name = await create_project_resource_policy( + name=policy_name, + max_vfolder_count=vfolder_count, + max_quota_scope_size=quota_size, + ) + + # Create user with user resource policy + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name=f"test-user-{policy_name}", + resource_policy_name=user_policy_name, + ) + + # Create group with project resource policy + group_id = await create_group( + domain_name=domain_name, + name=f"test-group-{policy_name}", + resource_policy_name=project_policy_name, + ) + + # When + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + vfolder_repository = VFolderRepository(db=database_engine) + resource_limit = await vfolder_repository.get_group_vfolder_resource_limit( + user_identity=user_identity, + group_id=group_id, + ) + + # Then + assert resource_limit.max_vfolder_count == vfolder_count + assert resource_limit.max_quota_scope_size == quota_size + + +@pytest.mark.asyncio +async def test_get_user_vfolder_resource_limit_not_found( + database_fixture, + database_engine: ExtendedAsyncSAEngine, +): + # Given + non_exist_user_id = uuid.uuid4() + user_identity = UserIdentity( + user_uuid=non_exist_user_id, + domain_name="test-domain", + user_role=UserRole.USER, + user_email="test@example.com", + ) + + # When & Then + vfolder_repository = VFolderRepository(db=database_engine) + with pytest.raises(UserNotFound): + await vfolder_repository.get_user_vfolder_resource_limit(user_identity) + + +@pytest.mark.asyncio +async def test_get_group_vfolder_resource_limit_not_found( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, +): + # Given + domain_name = await create_domain("test-domain") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user", + ) + non_exist_group_id = uuid.uuid4() + + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + + # When & Then + vfolder_repository = VFolderRepository(db=database_engine) + with pytest.raises(GroupNotFound): + await vfolder_repository.get_group_vfolder_resource_limit( + user_identity=user_identity, + group_id=non_exist_group_id, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "old_name,new_name", + [ + ("test-folder", "renamed-folder"), + ("before_change", "after_change"), + ("test-1", "test-2"), + ], +) +async def test_patch_vfolder_name( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_vfolder: Callable, + old_name: str, + new_name: str, +): + # Given + domain_name = await create_domain(name="test-patch-vfolder-name") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user", + ) + vfolder_id = await create_vfolder( + domain_name=domain_name, + user_id=user_id, + group_id=None, + name=old_name, + ) + + # When + vfolder_repository = VFolderRepository(db=database_engine) + await vfolder_repository.patch_vfolder_name(vfolder_id, new_name) + + # Then + async with database_engine.begin() as conn: + result = await conn.execute(sa.select(VFolderRow.name).where(VFolderRow.id == vfolder_id)) + updated_name = result.scalar() + assert updated_name == new_name + + +@pytest.mark.asyncio +async def test_delete_vfolder_by_id( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_group: Callable, + create_vfolder: Callable, +): + # Given + domain_name = await create_domain(name="test_delete_vfolder_by_id") + user1_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user1", + ) + user1_email = "test-user1@test.com" + + user2_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user2", + ) + user2_email = "test-user2@test.com" + + group_id = await create_group( + domain_name=domain_name, + name="test-group", + ) + + # Create a vfolder owned by user1 + vfolder_id = await create_vfolder( + domain_name=domain_name, + user_id=user1_id, + group_id=group_id, + name="test-folder", + ) + + # Create permissions and invitations between users (invitation from user1 to user2) + async with database_engine.begin_session() as sess: + perm_data = { + "user": user2_id, + "vfolder": vfolder_id, + "permission": VFolderPermission.READ_WRITE, + } + await sess.execute(sa.insert(VFolderPermissionRow).values(perm_data)) + + invite_data = { + "id": uuid.uuid4(), + "vfolder": vfolder_id, + "inviter": user1_email, + "invitee": user2_email, + "created_at": datetime.now(timezone.utc), + "state": VFolderInvitationState.ACCEPTED, + } + await sess.execute(sa.insert(VFolderInvitationRow).values(invite_data)) + + # Check permission and invitation count before delete + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderPermissionRow) + .where(VFolderPermissionRow.vfolder == vfolder_id) + ) + permission_count = result.scalar() + assert permission_count == 1 + + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderInvitationRow) + .where(VFolderInvitationRow.vfolder == vfolder_id) + ) + invite_count = result.scalar() + assert invite_count == 1 + + # When + vfolder_repository = VFolderRepository(db=database_engine) + await vfolder_repository.delete_vfolder_by_id(vfolder_id=vfolder_id) + + # Then + async with database_engine.begin_session() as sess: + # Check vfolder status + result = await sess.execute(sa.select(VFolderRow.status).where(VFolderRow.id == vfolder_id)) + status = result.scalar() + assert status == VFolderOperationStatus.DELETE_PENDING + + # Check if permissions are deleted + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderPermissionRow) + .where(VFolderPermissionRow.vfolder == vfolder_id) + ) + permission_count = result.scalar() + assert permission_count == 0 + + # Check if invitations are deleted + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderInvitationRow) + .where(VFolderInvitationRow.vfolder == vfolder_id) + ) + invite_count = result.scalar() + assert invite_count == 0