From 9c9ee04331e347aba4053c316cc8659eb4215866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 15:09:30 +0900 Subject: [PATCH 01/17] feat: Add vfolder repository in manager --- src/ai/backend/common/dto/manager/field.py | 4 +- src/ai/backend/common/dto/manager/request.py | 4 +- .../manager/api/vfolders/repositories.py | 414 ++++++++++++++++++ src/ai/backend/manager/data/vfolder/dto.py | 98 ++++- 4 files changed, 512 insertions(+), 8 deletions(-) create mode 100644 src/ai/backend/manager/api/vfolders/repositories.py diff --git a/src/ai/backend/common/dto/manager/field.py b/src/ai/backend/common/dto/manager/field.py index 6cc56d79c29..89c9d00ae85 100644 --- a/src/ai/backend/common/dto/manager/field.py +++ b/src/ai/backend/common/dto/manager/field.py @@ -47,8 +47,8 @@ class VFolderItemField(BaseModel): cloneable: bool status: VFolderOperationStatusField is_owner: bool - user_email: str - group_name: str + user_email: Optional[str] + group_name: Optional[str] type: str # legacy max_files: int cur_size: int diff --git a/src/ai/backend/common/dto/manager/request.py b/src/ai/backend/common/dto/manager/request.py index a9835adbc5d..c36250d516f 100644 --- a/src/ai/backend/common/dto/manager/request.py +++ b/src/ai/backend/common/dto/manager/request.py @@ -4,7 +4,7 @@ from pydantic import AliasChoices, BaseModel, Field from ai.backend.common import typed_validators as tv -from ai.backend.common.dto.manager.dto import VFolderPermissionDTO +from ai.backend.common.dto.manager.field import VFolderPermissionField from ai.backend.common.types import VFolderUsageMode @@ -17,7 +17,7 @@ class VFolderCreateReq(BaseModel): default=None, ) usage_mode: VFolderUsageMode = Field(default=VFolderUsageMode.GENERAL) - permission: VFolderPermissionDTO = Field(default=VFolderPermissionDTO.READ_WRITE) + permission: VFolderPermissionField = Field(default=VFolderPermissionField.READ_WRITE) unmanaged_path: Optional[str] = Field( validation_alias=AliasChoices("unmanaged_path", "unmanagedPath"), default=None, diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py new file mode 100644 index 00000000000..c00789d6314 --- /dev/null +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -0,0 +1,414 @@ +import uuid +from contextlib import AbstractAsyncContextManager as AbstractAsyncCtxMgr +from typing import Any, Awaitable, Callable, Iterable, Optional, ParamSpec, TypeVar + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession as SASession +from sqlalchemy.orm import selectinload + +from ai.backend.manager.api.exceptions import GroupNotFound, UserNotFound +from ai.backend.manager.data.vfolder.dto import ( + UserIdentity, + VFolderItem, + VFolderMetadataToCreate, + VFolderResourceLimit, +) +from ai.backend.manager.models import ( + HARD_DELETED_VFOLDER_STATUSES, + ProjectType, + UserRole, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, + VFolderStatusSet, + vfolder_status_map, +) +from ai.backend.manager.models.group import AssocGroupUserRow +from ai.backend.manager.models.utils import ( + ExtendedAsyncSAEngine, + execute_with_txn_retry, +) +from ai.backend.manager.models.vfolder import ( + GroupRow, + UserRow, + VFolderInvitationRow, + VFolderPermissionRow, + VFolderRow, +) + +_P = ParamSpec("_P") +_TQueryResult = TypeVar("_TQueryResult") + + +class VFolderRepository: + _db: ExtendedAsyncSAEngine + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db = db + + async def get_group_type(self, group_id: uuid.UUID) -> ProjectType: + async with self._db.begin_session() as sess: + query = sa.select(GroupRow).where(GroupRow.id == group_id) + group_row = await sess.scalar(query) + + if group_row is None: + raise GroupNotFound(extra_data=group_id) + + return group_row.type + + async def get_user_container_id(self, user_id: uuid.UUID) -> Optional[int]: + async with self._db.begin_session() as sess: + query = sa.select(UserRow.container_uid).where(UserRow.uuid == user_id) + result = await sess.scalar(query) + return result + + async def get_created_vfolder_count( + self, owner_id: uuid.UUID, ownership_type: VFolderOwnershipType + ) -> int: + async with self._db.begin_session() as sess: + if ownership_type == VFolderOwnershipType.USER: + ownership_type_caluse = VFolderRow.user == owner_id + else: + ownership_type_caluse = VFolderRow.group == owner_id + + query = ( + sa.select([sa.func.count()]) + .select_from(VFolderRow) + .where( + (ownership_type_caluse) + & (VFolderRow.status.not_in(HARD_DELETED_VFOLDER_STATUSES)) + ) + ) + result = await sess.scalar(query) + + return result + + async def get_user_vfolder_resource_limit( + self, user_identity: UserIdentity + ) -> VFolderResourceLimit: + async with self._db.begin_session() as sess: + query = ( + sa.select(UserRow) + .where(UserRow.uuid == user_identity.user_uuid) + .options(selectinload(UserRow.resource_policy_row)) + ) + user_row = await sess.scalar(query) + + if user_row is None: + raise UserNotFound(extra_data=user_identity.user_uuid) + + max_vfolder_count = user_row.resource_policy_row.max_vfolder_count + max_quota_scope_size = user_row.resource_policy_row.max_quota_scope_size + + return VFolderResourceLimit( + max_vfolder_count=max_vfolder_count, + max_quota_scope_size=max_quota_scope_size, + ) + + async def get_group_vfolder_resource_limit( + self, user_identity: UserIdentity, group_id: uuid.UUID + ) -> VFolderResourceLimit: + async with self._db.begin_session() as sess: + query = ( + sa.select(GroupRow) + .where( + (GroupRow.domain_name == user_identity.domain_name) & (GroupRow.id == group_id) + ) + .options(selectinload(GroupRow.resource_policy_row)) + ) + group_row = await sess.scalar(query) + + if group_row is None: + raise GroupNotFound(extra_data=group_id) + + max_vfolder_count = group_row.resource_policy_row.max_vfolder_count + max_quota_scope_size = group_row.resource_policy_row.max_quota_scope_size + + return VFolderResourceLimit( + max_vfolder_count=max_vfolder_count, + max_quota_scope_size=max_quota_scope_size, + ) + + async def persist_vfolder_metadata(self, metadata: VFolderMetadataToCreate) -> VFolderItem: + async with self._db.begin_session() as sess: + query = sa.insert(VFolderRow).values(metadata.to_dict()).returning(VFolderRow) + vfolder: VFolderRow = await sess.scalar(query) + vfolder_item = VFolderItem.from_orm(orm=vfolder, is_owner=True) + return vfolder_item + + async def create_vfolder_permission( + self, + user_id: uuid.UUID, + vfolder_id: uuid.UUID, + permission: VFolderPermission = VFolderPermission.OWNER_PERM, + ) -> None: + async with self._db.begin_session() as sess: + insert_value: dict[str, Any] = { + "user": user_id, + "vfolder": vfolder_id.hex, + "permission": permission, + } + + stmt = sa.insert(VFolderPermissionRow).values(insert_value) + await sess.execute(stmt) + + async def get_accessible_folders( + self, + user_identity: UserIdentity, + allowed_vfolder_types: list[str], + group_id: Optional[uuid.UUID] = None, + ) -> list[VFolderItem]: + all_entries: list[VFolderItem] = [] + if "user" in allowed_vfolder_types: + owned_vfolders = await self.query_owned_vfolders( + user_identity=user_identity, group_id=group_id + ) + all_entries.extend(owned_vfolders) + + shared_vfolders = await self.query_shared_vfolders(user_identity=user_identity) + all_entries.extend(shared_vfolders) + + if "group" in allowed_vfolder_types: + if group_id is not None: + group_vfolders = await self._query_specific_group_vfolders(user_identity, group_id) + else: + group_vfolders = await self._query_all_accessible_group_vfolders(user_identity) + + all_entries.extend(group_vfolders) + + return all_entries + + async def _query_specific_group_vfolders( + self, + user_identity: UserIdentity, + group_id: uuid.UUID, + ) -> list[VFolderItem]: + async with self._db.begin_session() as sess: + if user_identity.is_admin: + # check if group belongs to admin's domain + domain_check_query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where( + (GroupRow.id == group_id) + & (GroupRow.domain_name == user_identity.domain_name) + ) + ) + if await sess.scalar(domain_check_query) is None: + raise GroupNotFound( + extra_msg=f"group {group_id} does not belong to domain {user_identity.domain_name}" + ) + + if user_identity.is_normal_user: + # check if user is in the group + membership_query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from(AssocGroupUserRow) + .where( + (AssocGroupUserRow.user_id == user_identity.user_uuid) + & (AssocGroupUserRow.group_id == group_id) + ) + ) + if await sess.scalar(membership_query) is None: + raise GroupNotFound( + extra_msg=f"user {user_identity.user_uuid} is not a member of group {group_id}" + ) + + query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where( + ( + (VFolderRow.group == group_id) + | (VFolderRow.user.isnot(None)) + & VFolderRow.status.not_in( + vfolder_status_map[VFolderStatusSet.INACCESSIBLE] + ) + ) + ) + ) + vfolders: list[VFolderRow] = (await sess.scalars(query)).all() + + entries = [ + VFolderItem.from_orm( + orm=vfolder, + is_owner=user_identity.has_privilege_role, + include_relations=True, + override_with_group_member_permission=True, + ) + for vfolder in vfolders + ] + + return entries + + async def _query_all_accessible_group_vfolders( + self, + user_identity: UserIdentity, + ) -> list[VFolderItem]: + async with self._db.begin_session() as sess: + base_query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + + if user_identity.is_superadmin: + query = base_query + elif user_identity.is_admin: + query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where(GroupRow.domain_name == user_identity.domain_name) + ) + group_ids = await sess.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) + else: + query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from( + AssocGroupUserRow.join(UserRow, AssocGroupUserRow.user_id == UserRow.uuid) + ) + .where(AssocGroupUserRow.user_id == user_identity.user_uuid) + ) + group_ids = await sess.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) + + vfolders: list[VFolderRow] = (await sess.scalars(query)).all() + entries = [ + VFolderItem.from_orm( + orm=vfolder, is_owner=user_identity.has_privilege_role, include_relations=True + ) + for vfolder in vfolders + ] + + return entries + + async def query_owned_vfolders( + self, user_identity: UserIdentity, group_id: Optional[uuid.UUID] = None + ) -> list[VFolderItem]: + async with self._db.begin_session() as sess: + user_join = VFolderRow.join(UserRow, VFolderRow.user == UserRow.uuid) + + query = ( + sa.select(VFolderRow) + .select_from(user_join) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + # If group id is provided, filter user owned vfolders that are in certain group + if group_id is not None: + query = query.where((VFolderRow.group == group_id) | (VFolderRow.user.isnot(None))) + + if user_identity.user_role not in (UserRole.ADMIN, UserRole.SUPERADMIN): + query = query.where(VFolderRow.user == user_identity.user_uuid) + + vfolders: list[VFolderRow] = (await sess.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=True, include_relations=True) + for vfolder in vfolders + ] + + return entries + + async def query_shared_vfolders( + self, + user_identity: UserIdentity, + ) -> list[VFolderItem]: + async with self._db.begin_session() as sess: + shared_join = VFolderRow.join( + VFolderPermissionRow, + VFolderRow.id == VFolderPermissionRow.vfolder, + isouter=True, + ).join( + UserRow, + VFolderRow.user == UserRow.uuid, + isouter=True, + ) + + query = ( + sa.select(VFolderRow) + .select_from(shared_join) + .where( + (VFolderPermissionRow.user == user_identity.user_uuid) + & (VFolderRow.ownership_type == VFolderOwnershipType.USER) + & (VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + ) + + vfolders: list[VFolderRow] = (await sess.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=False, include_relations=True) + for vfolder in vfolders + ] + + return entries + + async def patch_vFolder_name(self, vfolder_id: uuid.UUID, new_name: str) -> None: + async with self._db.begin_session() as sess: + stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).values(name=new_name) + await sess.execute(stmt) + + async def _delete_vfolder_permission_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + stmt = sa.delete(VFolderInvitationRow).where( + VFolderInvitationRow.vfolder.in_(vfolder_row_ids) + ) + await db_session.execute(stmt) + + async def _retry( + self, + func: Callable[[SASession], Awaitable[_TQueryResult]], + db_session: Callable[..., AbstractAsyncCtxMgr], + ) -> None: + await execute_with_txn_retry( + txn_func=func, begin_trx=db_session, connection=self._db.connect() + ) + + async def _delete_vfolder_invitation_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + stmt = sa.delete(VFolderPermissionRow).where( + VFolderPermissionRow.vfolder.in_(vfolder_row_ids) + ) + await db_session.execute(stmt) + + async def _delete_vfolder_relation_rows( + self, + db_session: SASession, + vfolder_row_ids: Iterable[uuid.UUID], + ) -> None: + async def _delete(db_session: SASession) -> None: + await self._delete_vfolder_invitation_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) + await self._delete_vfolder_permission_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) + + await self._retry(func=_delete, db_session=db_session) + + async def _update_vfolder_status( + self, + db_session: SASession, + vfolder_id: uuid.UUID, + vfolder_status: VFolderOperationStatus, + ) -> None: + stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).value(status=vfolder_status) + await db_session.execute(stmt) + + async def delete_vFolder_by_id( + self, + vfolder_id: uuid.UUID, + ) -> None: + vfolder_ids = [vfolder_id] + async with self._db.begin_session() as sess: + await self._delete_vfolder_relation_rows(db_session=sess, vfolder_row_ids=vfolder_ids) + await self._update_vfolder_status( + db_session=sess, + vfolder_id=vfolder_id, + vfolder_status=VFolderOperationStatus.DELETE_PENDING, + ) diff --git a/src/ai/backend/manager/data/vfolder/dto.py b/src/ai/backend/manager/data/vfolder/dto.py index 14cf862c38c..75d38bddf84 100644 --- a/src/ai/backend/manager/data/vfolder/dto.py +++ b/src/ai/backend/manager/data/vfolder/dto.py @@ -1,5 +1,5 @@ import uuid -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Mapping, Optional, Self from ai.backend.common.dto.manager.context import KeypairCtx, UserIdentityCtx @@ -8,10 +8,12 @@ ) from ai.backend.common.dto.manager.request import VFolderCreateReq from ai.backend.common.types import VFolderUsageMode -from ai.backend.manager.models import ( +from ai.backend.manager.models.user import UserRole +from ai.backend.manager.models.vfolder import ( VFolderOperationStatus, VFolderOwnershipType, VFolderPermission, + VFolderRow, ) @@ -31,6 +33,22 @@ def from_ctx(cls, ctx: UserIdentityCtx) -> Self: domain_name=ctx.domain_name, ) + @property + def is_admin(self) -> bool: + return self.user_role == UserRole.ADMIN + + @property + def is_superadmin(self) -> bool: + return self.user_role == UserRole.SUPERADMIN + + @property + def has_privilege_role(self) -> bool: + return (self.user_role == UserRole.ADMIN) or (self.user_role == UserRole.SUPERADMIN) + + @property + def is_normal_user(self) -> bool: + return (self.user_role != UserRole.ADMIN) and (self.user_role != UserRole.SUPERADMIN) + @dataclass class Keypair: @@ -68,6 +86,26 @@ def from_request(cls, request: VFolderCreateReq) -> Self: ) +@dataclass +class VFolderMetadataToCreate: + name: str + domain_name: str + quota_scope_id: str + usage_mode: VFolderUsageMode + permission: VFolderPermission + host: str + creator: str + ownership_type: VFolderOwnershipType + user: str | None + group: str | None + unmanaged_path: str | None + cloneable: bool + status: VFolderOperationStatus = VFolderOperationStatus.READY + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + @dataclass class VFolderItem: id: str @@ -85,12 +123,58 @@ class VFolderItem: cloneable: bool status: VFolderOperationStatus is_owner: bool - user_email: str - group_name: str + user_email: Optional[str] + group_name: Optional[str] type: str # legacy max_files: int cur_size: int + @classmethod + def from_orm( + cls, + orm: VFolderRow, + is_owner, + include_relations=False, + override_with_group_member_permission=False, + ): + user_email = orm.user_row.email if include_relations and orm.user_row else None + group_name = orm.group_row.name if include_relations and orm.group_row else None + permission = ( + orm.permission_rows.permission + if override_with_group_member_permission + else orm.permission + ) + type = ( + VFolderOwnershipType.USER + if orm.ownership_type == VFolderOwnershipType.USER + else VFolderOwnershipType.GROUP + ) + user = str(orm.user) if type == VFolderOwnershipType.USER else None + group = str(orm.group) if type == VFolderOwnershipType.GROUP else None + + return cls( + id=orm.id.hex, + name=orm.name, + quota_scope_id=orm.quota_scope_id, + host=orm.host, + usage_mode=orm.usage_mode, + created_at=orm.created_at.isoformat(), + permission=permission, + max_size=orm.max_size, + creator=orm.creator, + ownership_type=orm.ownership_type, + user=user, + group=group, + cloneable=orm.cloneable, + status=orm.status, + is_owner=is_owner, + user_email=user_email, + group_name=group_name, + type=type, + max_files=orm.max_files, + cur_size=orm.cur_size, + ) + def to_field(self) -> VFolderItemField: return VFolderItemField( id=self.id, @@ -114,3 +198,9 @@ def to_field(self) -> VFolderItemField: max_files=self.max_files, cur_size=self.cur_size, ) + + +@dataclass +class VFolderResourceLimit: + max_vfolder_count: int + max_quota_scope_size: int From 4183cda92c3807701a04339bae87454b3aad4e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 15:13:19 +0900 Subject: [PATCH 02/17] chore: Change internal used method into private --- src/ai/backend/manager/api/vfolders/repositories.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index c00789d6314..61b47334bf7 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -160,12 +160,12 @@ async def get_accessible_folders( ) -> list[VFolderItem]: all_entries: list[VFolderItem] = [] if "user" in allowed_vfolder_types: - owned_vfolders = await self.query_owned_vfolders( + owned_vfolders = await self._query_owned_vfolders( user_identity=user_identity, group_id=group_id ) all_entries.extend(owned_vfolders) - shared_vfolders = await self.query_shared_vfolders(user_identity=user_identity) + shared_vfolders = await self._query_shared_vfolders(user_identity=user_identity) all_entries.extend(shared_vfolders) if "group" in allowed_vfolder_types: @@ -283,7 +283,7 @@ async def _query_all_accessible_group_vfolders( return entries - async def query_owned_vfolders( + async def _query_owned_vfolders( self, user_identity: UserIdentity, group_id: Optional[uuid.UUID] = None ) -> list[VFolderItem]: async with self._db.begin_session() as sess: @@ -309,7 +309,7 @@ async def query_owned_vfolders( return entries - async def query_shared_vfolders( + async def _query_shared_vfolders( self, user_identity: UserIdentity, ) -> list[VFolderItem]: From 83b1ab60911c8cf32753e3fde0fb17cff34bcfbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 06:27:18 +0000 Subject: [PATCH 03/17] chore: update api schema dump Co-authored-by: octodog --- docs/manager/rest-reference/openapi.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index 03b8c65b892..248d84cf02f 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -3,7 +3,7 @@ "info": { "title": "Backend.AI Manager API", "description": "Backend.AI Manager REST API specification", - "version": "25.1.1", + "version": "25.2.0", "contact": { "name": "Lablup Inc.", "url": "https://docs.backend.ai", From edaa63439ed59a2fb92fa1f8e8d60201d2d67f5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 15:28:19 +0900 Subject: [PATCH 04/17] doc: Add changelog of adding repository layer --- changes/3699.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3699.feat.md diff --git a/changes/3699.feat.md b/changes/3699.feat.md new file mode 100644 index 00000000000..6ea21892f78 --- /dev/null +++ b/changes/3699.feat.md @@ -0,0 +1 @@ +Add Repository for VFolder in manager \ No newline at end of file From b784fa8c6aad6baeb3407882ce001b848f0e6a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 16:03:38 +0900 Subject: [PATCH 05/17] fix: Fix wrong changelog number --- changes/{3699.feat.md => 3669.feat.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changes/{3699.feat.md => 3669.feat.md} (100%) diff --git a/changes/3699.feat.md b/changes/3669.feat.md similarity index 100% rename from changes/3699.feat.md rename to changes/3669.feat.md From 3c48af70de454ed379b1b53025e4570a760eb36f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 12 Feb 2025 17:14:11 +0900 Subject: [PATCH 06/17] refactor: remove session begin in private methods --- .../manager/api/vfolders/repositories.py | 256 +++++++++--------- 1 file changed, 131 insertions(+), 125 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index 61b47334bf7..bd54306de6b 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -159,75 +159,79 @@ async def get_accessible_folders( group_id: Optional[uuid.UUID] = None, ) -> list[VFolderItem]: all_entries: list[VFolderItem] = [] - if "user" in allowed_vfolder_types: - owned_vfolders = await self._query_owned_vfolders( - user_identity=user_identity, group_id=group_id - ) - all_entries.extend(owned_vfolders) + async with self._db.begin_session() as sess: + if "user" in allowed_vfolder_types: + owned_vfolders = await self._query_owned_vfolders( + db_session=sess, user_identity=user_identity, group_id=group_id + ) + all_entries.extend(owned_vfolders) - shared_vfolders = await self._query_shared_vfolders(user_identity=user_identity) - all_entries.extend(shared_vfolders) + shared_vfolders = await self._query_shared_vfolders( + db_session=sess, user_identity=user_identity + ) + all_entries.extend(shared_vfolders) - if "group" in allowed_vfolder_types: - if group_id is not None: - group_vfolders = await self._query_specific_group_vfolders(user_identity, group_id) - else: - group_vfolders = await self._query_all_accessible_group_vfolders(user_identity) + if "group" in allowed_vfolder_types: + if group_id is not None: + group_vfolders = await self._query_specific_group_vfolders( + db_session=sess, user_identity=user_identity, group_id=group_id + ) + else: + group_vfolders = await self._query_all_accessible_group_vfolders( + db_session=sess, user_identity=user_identity + ) - all_entries.extend(group_vfolders) + all_entries.extend(group_vfolders) return all_entries async def _query_specific_group_vfolders( self, + db_session: SASession, user_identity: UserIdentity, group_id: uuid.UUID, ) -> list[VFolderItem]: - async with self._db.begin_session() as sess: - if user_identity.is_admin: - # check if group belongs to admin's domain - domain_check_query = ( - sa.select(GroupRow.id) - .select_from(GroupRow) - .where( - (GroupRow.id == group_id) - & (GroupRow.domain_name == user_identity.domain_name) - ) + if user_identity.is_admin: + # check if group belongs to admin's domain + domain_check_query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where( + (GroupRow.id == group_id) & (GroupRow.domain_name == user_identity.domain_name) ) - if await sess.scalar(domain_check_query) is None: - raise GroupNotFound( - extra_msg=f"group {group_id} does not belong to domain {user_identity.domain_name}" - ) - - if user_identity.is_normal_user: - # check if user is in the group - membership_query = ( - sa.select(AssocGroupUserRow.group_id) - .select_from(AssocGroupUserRow) - .where( - (AssocGroupUserRow.user_id == user_identity.user_uuid) - & (AssocGroupUserRow.group_id == group_id) - ) + ) + if await db_session.scalar(domain_check_query) is None: + raise GroupNotFound( + extra_msg=f"group {group_id} does not belong to domain {user_identity.domain_name}" ) - if await sess.scalar(membership_query) is None: - raise GroupNotFound( - extra_msg=f"user {user_identity.user_uuid} is not a member of group {group_id}" - ) - query = ( - sa.select(VFolderRow) - .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + if user_identity.is_normal_user: + # check if user is in the group + membership_query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from(AssocGroupUserRow) .where( - ( - (VFolderRow.group == group_id) - | (VFolderRow.user.isnot(None)) - & VFolderRow.status.not_in( - vfolder_status_map[VFolderStatusSet.INACCESSIBLE] - ) - ) + (AssocGroupUserRow.user_id == user_identity.user_uuid) + & (AssocGroupUserRow.group_id == group_id) + ) + ) + if await db_session.scalar(membership_query) is None: + raise GroupNotFound( + extra_msg=f"user {user_identity.user_uuid} is not a member of group {group_id}" + ) + + query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where( + ( + (VFolderRow.group == group_id) + | (VFolderRow.user.isnot(None)) + & VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE]) ) ) - vfolders: list[VFolderRow] = (await sess.scalars(query)).all() + ) + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() entries = [ VFolderItem.from_orm( @@ -243,102 +247,104 @@ async def _query_specific_group_vfolders( async def _query_all_accessible_group_vfolders( self, + db_session: SASession, user_identity: UserIdentity, ) -> list[VFolderItem]: - async with self._db.begin_session() as sess: - base_query = ( - sa.select(VFolderRow) - .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) - .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) - ) + base_query = ( + sa.select(VFolderRow) + .select_from(VFolderRow.join(GroupRow, VFolderRow.group == GroupRow.id)) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) - if user_identity.is_superadmin: - query = base_query - elif user_identity.is_admin: - query = ( - sa.select(GroupRow.id) - .select_from(GroupRow) - .where(GroupRow.domain_name == user_identity.domain_name) - ) - group_ids = await sess.scalars(query) - query = base_query.where(VFolderRow.group.in_(group_ids)) - else: - query = ( - sa.select(AssocGroupUserRow.group_id) - .select_from( - AssocGroupUserRow.join(UserRow, AssocGroupUserRow.user_id == UserRow.uuid) - ) - .where(AssocGroupUserRow.user_id == user_identity.user_uuid) + if user_identity.is_superadmin: + query = base_query + elif user_identity.is_admin: + query = ( + sa.select(GroupRow.id) + .select_from(GroupRow) + .where(GroupRow.domain_name == user_identity.domain_name) + ) + group_ids = await db_session.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) + else: + query = ( + sa.select(AssocGroupUserRow.group_id) + .select_from( + AssocGroupUserRow.join(UserRow, AssocGroupUserRow.user_id == UserRow.uuid) ) - group_ids = await sess.scalars(query) - query = base_query.where(VFolderRow.group.in_(group_ids)) + .where(AssocGroupUserRow.user_id == user_identity.user_uuid) + ) + group_ids = await db_session.scalars(query) + query = base_query.where(VFolderRow.group.in_(group_ids)) - vfolders: list[VFolderRow] = (await sess.scalars(query)).all() - entries = [ - VFolderItem.from_orm( - orm=vfolder, is_owner=user_identity.has_privilege_role, include_relations=True - ) - for vfolder in vfolders - ] + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm( + orm=vfolder, is_owner=user_identity.has_privilege_role, include_relations=True + ) + for vfolder in vfolders + ] return entries async def _query_owned_vfolders( - self, user_identity: UserIdentity, group_id: Optional[uuid.UUID] = None + self, + db_session: SASession, + user_identity: UserIdentity, + group_id: Optional[uuid.UUID] = None, ) -> list[VFolderItem]: - async with self._db.begin_session() as sess: - user_join = VFolderRow.join(UserRow, VFolderRow.user == UserRow.uuid) + user_join = VFolderRow.join(UserRow, VFolderRow.user == UserRow.uuid) - query = ( - sa.select(VFolderRow) - .select_from(user_join) - .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) - ) - # If group id is provided, filter user owned vfolders that are in certain group - if group_id is not None: - query = query.where((VFolderRow.group == group_id) | (VFolderRow.user.isnot(None))) + query = ( + sa.select(VFolderRow) + .select_from(user_join) + .where(VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) + ) + # If group id is provided, filter user owned vfolders that are in certain group + if group_id is not None: + query = query.where((VFolderRow.group == group_id) | (VFolderRow.user.isnot(None))) - if user_identity.user_role not in (UserRole.ADMIN, UserRole.SUPERADMIN): - query = query.where(VFolderRow.user == user_identity.user_uuid) + if user_identity.user_role not in (UserRole.ADMIN, UserRole.SUPERADMIN): + query = query.where(VFolderRow.user == user_identity.user_uuid) - vfolders: list[VFolderRow] = (await sess.scalars(query)).all() - entries = [ - VFolderItem.from_orm(orm=vfolder, is_owner=True, include_relations=True) - for vfolder in vfolders - ] + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=True, include_relations=True) + for vfolder in vfolders + ] return entries async def _query_shared_vfolders( self, + db_session: SASession, user_identity: UserIdentity, ) -> list[VFolderItem]: - async with self._db.begin_session() as sess: - shared_join = VFolderRow.join( - VFolderPermissionRow, - VFolderRow.id == VFolderPermissionRow.vfolder, - isouter=True, - ).join( - UserRow, - VFolderRow.user == UserRow.uuid, - isouter=True, - ) + shared_join = VFolderRow.join( + VFolderPermissionRow, + VFolderRow.id == VFolderPermissionRow.vfolder, + isouter=True, + ).join( + UserRow, + VFolderRow.user == UserRow.uuid, + isouter=True, + ) - query = ( - sa.select(VFolderRow) - .select_from(shared_join) - .where( - (VFolderPermissionRow.user == user_identity.user_uuid) - & (VFolderRow.ownership_type == VFolderOwnershipType.USER) - & (VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) - ) + query = ( + sa.select(VFolderRow) + .select_from(shared_join) + .where( + (VFolderPermissionRow.user == user_identity.user_uuid) + & (VFolderRow.ownership_type == VFolderOwnershipType.USER) + & (VFolderRow.status.not_in(vfolder_status_map[VFolderStatusSet.INACCESSIBLE])) ) + ) - vfolders: list[VFolderRow] = (await sess.scalars(query)).all() - entries = [ - VFolderItem.from_orm(orm=vfolder, is_owner=False, include_relations=True) - for vfolder in vfolders - ] + vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() + entries = [ + VFolderItem.from_orm(orm=vfolder, is_owner=False, include_relations=True) + for vfolder in vfolders + ] return entries From b01baff0754c8c0b41ce75553d0b30d2e6de1eac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 09:36:14 +0900 Subject: [PATCH 07/17] fix: Fix _retry method to propery use transaction --- .../manager/api/vfolders/repositories.py | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index bd54306de6b..46264211702 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -1,5 +1,4 @@ import uuid -from contextlib import AbstractAsyncContextManager as AbstractAsyncCtxMgr from typing import Any, Awaitable, Callable, Iterable, Optional, ParamSpec, TypeVar import sqlalchemy as sa @@ -363,15 +362,6 @@ async def _delete_vfolder_permission_rows( ) await db_session.execute(stmt) - async def _retry( - self, - func: Callable[[SASession], Awaitable[_TQueryResult]], - db_session: Callable[..., AbstractAsyncCtxMgr], - ) -> None: - await execute_with_txn_retry( - txn_func=func, begin_trx=db_session, connection=self._db.connect() - ) - async def _delete_vfolder_invitation_rows( self, db_session: SASession, @@ -387,15 +377,12 @@ async def _delete_vfolder_relation_rows( db_session: SASession, vfolder_row_ids: Iterable[uuid.UUID], ) -> None: - async def _delete(db_session: SASession) -> None: - await self._delete_vfolder_invitation_rows( - db_session=db_session, vfolder_row_ids=vfolder_row_ids - ) - await self._delete_vfolder_permission_rows( - db_session=db_session, vfolder_row_ids=vfolder_row_ids - ) - - await self._retry(func=_delete, db_session=db_session) + await self._delete_vfolder_invitation_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) + await self._delete_vfolder_permission_rows( + db_session=db_session, vfolder_row_ids=vfolder_row_ids + ) async def _update_vfolder_status( self, @@ -406,15 +393,32 @@ async def _update_vfolder_status( stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).value(status=vfolder_status) await db_session.execute(stmt) + """ + NOTICE: _retry method must be used in top level function + """ + + async def _retry( + self, + func: Callable[[SASession], Awaitable[_TQueryResult]], + ) -> None: + await execute_with_txn_retry( + txn_func=func, begin_trx=self._db.begin_session, connection=self._db.connect() + ) + async def delete_vFolder_by_id( self, vfolder_id: uuid.UUID, ) -> None: vfolder_ids = [vfolder_id] - async with self._db.begin_session() as sess: - await self._delete_vfolder_relation_rows(db_session=sess, vfolder_row_ids=vfolder_ids) + + async def _delete_and_update(db_session: SASession) -> None: + await self._delete_vfolder_relation_rows( + db_session=db_session, vfolder_row_ids=vfolder_ids + ) await self._update_vfolder_status( - db_session=sess, + db_session=db_session, vfolder_id=vfolder_id, vfolder_status=VFolderOperationStatus.DELETE_PENDING, ) + + await self._retry(func=_delete_and_update) From 44aeabc4bd59296e923c228433d9c8ed6130bad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 09:56:19 +0900 Subject: [PATCH 08/17] refactor: Split user and group type vfolder get methods --- .../manager/api/vfolders/repositories.py | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index 46264211702..fa1ef8582b4 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -160,30 +160,48 @@ async def get_accessible_folders( all_entries: list[VFolderItem] = [] async with self._db.begin_session() as sess: if "user" in allowed_vfolder_types: - owned_vfolders = await self._query_owned_vfolders( + user_type_vfolders: list[ + VFolderItem + ] = await self._query_accessible_user_type_vfolders( db_session=sess, user_identity=user_identity, group_id=group_id ) - all_entries.extend(owned_vfolders) - - shared_vfolders = await self._query_shared_vfolders( - db_session=sess, user_identity=user_identity - ) - all_entries.extend(shared_vfolders) + all_entries.extend(user_type_vfolders) if "group" in allowed_vfolder_types: - if group_id is not None: - group_vfolders = await self._query_specific_group_vfolders( - db_session=sess, user_identity=user_identity, group_id=group_id - ) - else: - group_vfolders = await self._query_all_accessible_group_vfolders( - db_session=sess, user_identity=user_identity - ) + group_type_vfolders: list[ + VFolderItem + ] = await self._query_accessible_group_type_vfolders( + db_session=sess, user_identity=user_identity, group_id=group_id + ) - all_entries.extend(group_vfolders) + all_entries.extend(group_type_vfolders) return all_entries + async def _query_accessible_user_type_vfolders( + self, db_session: SASession, user_identity: UserIdentity, group_id: Optional[uuid.UUID] + ) -> list[VFolderItem]: + owned: list[VFolderItem] = await self._query_owned_vfolders( + db_session=db_session, user_identity=user_identity, group_id=group_id + ) + shared: list[VFolderItem] = await self._query_shared_vfolders( + db_session=db_session, user_identity=user_identity + ) + + return [*owned, *shared] + + async def _query_accessible_group_type_vfolders( + self, db_session: SASession, user_identity: UserIdentity, group_id: Optional[uuid.UUID] + ) -> list[VFolderItem]: + if group_id is not None: + return await self._query_specific_group_vfolders( + db_session=db_session, user_identity=user_identity, group_id=group_id + ) + + return await self._query_all_accessible_group_vfolders( + db_session=db_session, user_identity=user_identity + ) + async def _query_specific_group_vfolders( self, db_session: SASession, From 16fe4e4083495d6a6a208a2aa813b4bafa5a4b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 10:04:27 +0900 Subject: [PATCH 09/17] refactor: Encapsulated user role checking --- src/ai/backend/manager/api/vfolders/repositories.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index fa1ef8582b4..fcd55ad55c1 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -15,7 +15,6 @@ from ai.backend.manager.models import ( HARD_DELETED_VFOLDER_STATUSES, ProjectType, - UserRole, VFolderOperationStatus, VFolderOwnershipType, VFolderPermission, @@ -321,7 +320,7 @@ async def _query_owned_vfolders( if group_id is not None: query = query.where((VFolderRow.group == group_id) | (VFolderRow.user.isnot(None))) - if user_identity.user_role not in (UserRole.ADMIN, UserRole.SUPERADMIN): + if user_identity.is_normal_user: query = query.where(VFolderRow.user == user_identity.user_uuid) vfolders: list[VFolderRow] = (await db_session.scalars(query)).all() From 931812530b5d8a4e858284cbca36b24f8ade199d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 15:07:52 +0900 Subject: [PATCH 10/17] fix: Change back to using queries that explicitly select --- src/ai/backend/manager/api/vfolders/repositories.py | 5 ++++- src/ai/backend/manager/data/vfolder/dto.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index fcd55ad55c1..2d6a832cc27 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -129,7 +129,10 @@ async def get_group_vfolder_resource_limit( async def persist_vfolder_metadata(self, metadata: VFolderMetadataToCreate) -> VFolderItem: async with self._db.begin_session() as sess: - query = sa.insert(VFolderRow).values(metadata.to_dict()).returning(VFolderRow) + insert_query = sa.insert(VFolderRow).values(metadata.to_dict()).returning(VFolderRow.id) + vfolder_id = await sess.scalar(insert_query) + + query = sa.select(VFolderRow).where(VFolderRow.id == vfolder_id) vfolder: VFolderRow = await sess.scalar(query) vfolder_item = VFolderItem.from_orm(orm=vfolder, is_owner=True) return vfolder_item diff --git a/src/ai/backend/manager/data/vfolder/dto.py b/src/ai/backend/manager/data/vfolder/dto.py index 75d38bddf84..7425633a726 100644 --- a/src/ai/backend/manager/data/vfolder/dto.py +++ b/src/ai/backend/manager/data/vfolder/dto.py @@ -96,10 +96,10 @@ class VFolderMetadataToCreate: host: str creator: str ownership_type: VFolderOwnershipType - user: str | None - group: str | None - unmanaged_path: str | None cloneable: bool + user: str | None = None + group: str | None = None + unmanaged_path: str | None = None status: VFolderOperationStatus = VFolderOperationStatus.READY def to_dict(self) -> dict[str, Any]: From 2d537379a810bcb34506517c95932f934b38eca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 17:04:29 +0900 Subject: [PATCH 11/17] test: add test methods for some methods --- tests/manager/api/vfolders/conftest.py | 248 +++++++++++++ .../manager/api/vfolders/test_repositories.py | 333 ++++++++++++++++++ 2 files changed, 581 insertions(+) create mode 100644 tests/manager/api/vfolders/test_repositories.py diff --git a/tests/manager/api/vfolders/conftest.py b/tests/manager/api/vfolders/conftest.py index 9618f350832..cece4b0358b 100644 --- a/tests/manager/api/vfolders/conftest.py +++ b/tests/manager/api/vfolders/conftest.py @@ -1,9 +1,37 @@ import uuid +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable, Optional from unittest.mock import MagicMock import pytest +import sqlalchemy as sa from pydantic import BaseModel +from ai.backend.common.types import ( + QuotaScopeID, + QuotaScopeType, + VFolderHostPermission, + VFolderUsageMode, +) +from ai.backend.manager.data.vfolder.dto import UserIdentity +from ai.backend.manager.models import ( + DomainRow, + GroupRow, + ProjectResourcePolicyRow, + ProjectType, + UserResourcePolicyRow, + UserRole, + UserRow, + UserStatus, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, +) +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.models.vfolder import ( + VFolderRow, +) + @pytest.fixture def mock_authenticated_request(): @@ -33,3 +61,223 @@ class TestResponse(BaseModel): @pytest.fixture def mock_success_response() -> TestResponse: return TestResponse(test="response") + + +@pytest.fixture +async def create_user_resource_policy( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[[str], Awaitable[str]]: + async def _create_user_resource_policy( + name: str, + max_vfolder_count: int = 0, + max_quota_scope_size: int = -1, + max_session_count_per_model_session: int = 5, + max_customized_image_count: int = 3, + ) -> str: + async with database_engine.begin() as conn: + policy_data = { + "name": name, + "max_vfolder_count": max_vfolder_count, + "max_quota_scope_size": max_quota_scope_size, + "max_session_count_per_model_session": max_session_count_per_model_session, + "max_customized_image_count": max_customized_image_count, + } + await conn.execute( + sa.insert(UserResourcePolicyRow) + .values(policy_data) + .returning(UserResourcePolicyRow) + ) + return name + + return _create_user_resource_policy + + +@pytest.fixture +async def create_project_resource_policy( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[[str], Awaitable[str]]: + async def _create_project_resource_policy( + name: str, + max_vfolder_count: int = 0, + max_quota_scope_size: int = -1, + max_network_count: int = 0, + ) -> str: + async with database_engine.begin() as conn: + policy_data = { + "name": name, + "max_vfolder_count": max_vfolder_count, + "max_quota_scope_size": max_quota_scope_size, + "max_network_count": max_network_count, + } + await conn.execute( + sa.insert(ProjectResourcePolicyRow) + .values(policy_data) + .returning(ProjectResourcePolicyRow) + ) + return name + + return _create_project_resource_policy + + +@pytest.fixture +async def create_domain( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[str]]: + async def _create_domain(name: str = "test-domain") -> str: + async with database_engine.begin() as conn: + domain_name = name + domain_data: dict[str, Any] = { + "name": domain_name, + "description": f"Test Domain for {name}", + "is_active": True, + "total_resource_slots": {}, + "allowed_vfolder_hosts": { + "local": [VFolderHostPermission.CREATE], + }, + "allowed_docker_registries": [], + "integration_id": None, + } + await conn.execute(sa.insert(DomainRow).values(domain_data).returning(DomainRow)) + return domain_name + + return _create_domain + + +@pytest.fixture +async def create_user_with_role( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + """ + NOTICE: To use 'default' resource policy, you must use `database_fixture` concurrently in test function + """ + + async def _create_user( + domain_name: str, + role: UserRole, + name: str, + container_uid: Optional[int] = 1000, + resource_policy_name: str = "default", + ) -> uuid.UUID: + async with database_engine.begin() as conn: + user_id = uuid.uuid4() + username = name + user_data = { + "uuid": user_id, + "username": username, + "email": f"{username}@test.com", + "password": "sample_password", + "need_password_change": False, + "full_name": "Sample User", + "description": "Test user", + "status": UserStatus.ACTIVE, + "status_info": None, + "domain_name": domain_name, + "role": role, + "resource_policy": resource_policy_name, + "totp_activated": False, + "sudo_session_enabled": False, + "container_uid": container_uid, + } + await conn.execute(sa.insert(UserRow).values(user_data)) + await conn.execute(sa.select(UserRow).where(UserRow.uuid == user_id)) + return user_id + + return _create_user + + +@pytest.fixture +async def create_identity( + create_user_with_role: Callable, +) -> Callable[..., Awaitable[tuple[uuid.UUID, UserIdentity]]]: + """ + NOTICE: To use 'default' resource policy in create_user_with_role, you must use `database_fixture` concurrently in test function + """ + + async def _create_identity( + domain_name: str, role: UserRole, name: str, resource_policy_name: str = "default" + ) -> tuple[uuid.UUID, UserIdentity]: + user_id = await create_user_with_role(domain_name, role, name, resource_policy_name) + identity = UserIdentity( + user_uuid=user_id, user_role=role, domain_name=domain_name, user_email="test@email.com" + ) + return user_id, identity + + return _create_identity + + +@pytest.fixture +async def create_vfolder( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + async def _create_vfolder( + domain_name: str, + user_id: Optional[uuid.UUID], + group_id: Optional[uuid.UUID], + name: str, + ) -> uuid.UUID: + assert (user_id is not None) or (group_id is not None) + quota_scope_type = QuotaScopeType.USER if user_id else QuotaScopeType.PROJECT + scope_id = user_id if user_id is not None else group_id + assert scope_id is not None + quota_scope_id = str(QuotaScopeID(quota_scope_type, scope_id)) + async with database_engine.begin() as conn: + vfolder_id = uuid.uuid4() + vfolder_data = { + "id": vfolder_id, + "host": "local", + "name": name, + "domain_name": domain_name, + "user": user_id, + "group": group_id, + "quota_scope_id": quota_scope_id, + "usage_mode": VFolderUsageMode.GENERAL, + "permission": VFolderPermission.READ_WRITE, + "ownership_type": VFolderOwnershipType.USER + if user_id + else VFolderOwnershipType.GROUP, + "status": VFolderOperationStatus.READY, + "max_files": 1000, + "max_size": 1024 * 1024, + "created_at": datetime.now(timezone.utc), + "last_used": None, + "cloneable": True, + } + await conn.execute(sa.insert(VFolderRow).values(vfolder_data)) + await conn.execute(sa.select(VFolderRow).where(VFolderRow.id == vfolder_id)) + return vfolder_id + + return _create_vfolder + + +@pytest.fixture +async def create_group( + database_engine: ExtendedAsyncSAEngine, +) -> Callable[..., Awaitable[uuid.UUID]]: + """ + NOTICE: To use 'default' resource policy, you must use `database_fixture` concurrently in test function + """ + + async def _create_group( + domain_name: str, + name: str, + type: ProjectType = ProjectType.GENERAL, + resource_policy_name: str = "default", + ) -> uuid.UUID: + async with database_engine.begin() as conn: + group_id = uuid.uuid4() + group_data = { + "id": group_id, + "name": name, + "description": "Test group", + "is_active": True, + "domain_name": domain_name, + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "resource_policy": resource_policy_name, + "type": type, + } + await conn.execute(sa.insert(GroupRow).values(group_data)) + await conn.execute(sa.select(GroupRow).where(GroupRow.id == group_id)) + return group_id + + return _create_group diff --git a/tests/manager/api/vfolders/test_repositories.py b/tests/manager/api/vfolders/test_repositories.py new file mode 100644 index 00000000000..e26f9d5ff85 --- /dev/null +++ b/tests/manager/api/vfolders/test_repositories.py @@ -0,0 +1,333 @@ +import uuid +from typing import Callable + +import pytest +import sqlalchemy as sa + +from ai.backend.common.types import ( + QuotaScopeID, + QuotaScopeType, + VFolderUsageMode, +) +from ai.backend.manager.api.exceptions import GroupNotFound, UserNotFound +from ai.backend.manager.api.vfolders.repositories import VFolderRepository +from ai.backend.manager.data.vfolder.dto import UserIdentity, VFolderItem, VFolderMetadataToCreate +from ai.backend.manager.models import ( + ProjectType, + UserRole, + VFolderOperationStatus, + VFolderOwnershipType, + VFolderPermission, +) +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.models.vfolder import ( + VFolderRow, +) + +""" +This test file also use fixtures created in `database_fixture` in tests/conftest.py, especially when using 'default' resource policy. +""" + + +@pytest.mark.parametrize("project_type", [ProjectType.GENERAL, ProjectType.MODEL_STORE]) +@pytest.mark.asyncio +async def test_get_group_type( + database_fixture, + create_domain: Callable, + create_group: Callable, + database_engine: ExtendedAsyncSAEngine, + project_type, +): + domain_name = await create_domain(name="test_add_permission") + group_id = await create_group( + domain_name=domain_name, name="test_get_group_type", type=project_type + ) + vfolder_repository = VFolderRepository(db=database_engine) + group_type = await vfolder_repository.get_group_type(group_id=group_id) + assert isinstance(group_type, ProjectType) + assert group_type == project_type + + +@pytest.mark.asyncio +async def test_get_container_id( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_user_with_role: Callable, + create_domain: Callable, +): + domain_name = await create_domain(name="test_get_container_id") + expected_container_id = 1234 + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test_get_container_id", + container_uid=expected_container_id, + ) + + repo = VFolderRepository(database_engine) + actual_container_id = await repo.get_user_container_id(user_id=user_id) + assert actual_container_id == expected_container_id + + +@pytest.mark.asyncio +async def test_get_created_vfolder_count( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_vfolder: Callable, + create_group: Callable, +): + domain_name = await create_domain(name="test_get_created_vfolder_count") + user_id = await create_user_with_role( + domain_name=domain_name, role=UserRole.USER, name="test_get_created_vfolder_count" + ) + group_id = await create_group(domain_name=domain_name, name="test_get_created_vfolder_count") + + # Create 3 user vfolders and delete 1 + await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="user-vfolder-1" + ) + await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="user-vfolder-2" + ) + deleted_vfolder_id = await create_vfolder( + domain_name=domain_name, user_id=user_id, group_id=None, name="deleted-vfolder" + ) + async with database_engine.begin() as conn: + await conn.execute( + sa.update(VFolderRow) + .where(VFolderRow.id == deleted_vfolder_id) + .values(status=VFolderOperationStatus.DELETE_COMPLETE) + ) + + # Create 3 group vfolders + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-1" + ) + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-2" + ) + await create_vfolder( + domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-3" + ) + + vfolder_repository = VFolderRepository(db=database_engine) + user_count = await vfolder_repository.get_created_vfolder_count( + user_id, VFolderOwnershipType.USER + ) + assert user_count == 2 + + group_count = await vfolder_repository.get_created_vfolder_count( + group_id, VFolderOwnershipType.GROUP + ) + assert group_count == 3 + + +@pytest.mark.asyncio +async def test_persist_vfolder_metadata( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, +): + domain_name = await create_domain(name="test_persist_vfolder_metadata") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test_persist_vfolder_metadata", + ) + quota_scope_id = QuotaScopeID(QuotaScopeType.USER, user_id) + host = "local" + + vfolder_name = "test_persist_vfolder_metadata" + metadata = VFolderMetadataToCreate( + name=vfolder_name, + domain_name=domain_name, + quota_scope_id=str(quota_scope_id), + usage_mode=VFolderUsageMode.GENERAL, + permission=VFolderPermission.READ_WRITE, + host=host, + creator=str(user_id), + ownership_type=VFolderOwnershipType.USER, + cloneable=True, + ) + vfolder_repository = VFolderRepository(db=database_engine) + vfolder_item: VFolderItem = await vfolder_repository.persist_vfolder_metadata(metadata=metadata) + assert vfolder_item.name == vfolder_name + assert vfolder_item.host == host + assert vfolder_item.ownership_type == VFolderOwnershipType.USER + assert vfolder_item.quota_scope_id == quota_scope_id + assert vfolder_item.permission == VFolderPermission.READ_WRITE + assert vfolder_item.creator == str(user_id) + assert vfolder_item.cloneable == True # noqa: E712 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "policy_name,vfolder_count,quota_size", + [ + ("test-user-policy-1", 10, 100000), + ("test-user-policy-2", 20, 200000), + ("test-user-policy-3", 0, -1), + ], +) +async def test_get_user_vfolder_resource_limit( + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_user_resource_policy: Callable, + policy_name: str, + vfolder_count: int, + quota_size: int, +): + # Given + domain_name = await create_domain(name=f"test-domain-{policy_name}") + policy_name = await create_user_resource_policy( + name=policy_name, + max_vfolder_count=vfolder_count, + max_quota_scope_size=quota_size, + ) + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name=f"test-user-{policy_name}", + resource_policy_name=policy_name, + ) + + # When + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + vfolder_repository = VFolderRepository(db=database_engine) + resource_limit = await vfolder_repository.get_user_vfolder_resource_limit(user_identity) + + # Then + assert resource_limit.max_vfolder_count == vfolder_count + assert resource_limit.max_quota_scope_size == quota_size + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "policy_name,vfolder_count,quota_size", + [ + ("test-project-policy-1", 20, 20000), + ("test-project-policy-2", 50, 50000), + ("test-project-policy-3", 0, -1), + ], +) +async def test_get_group_vfolder_resource_limit( + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_group: Callable, + create_project_resource_policy: Callable, + create_user_resource_policy: Callable, + policy_name: str, + vfolder_count: int, + quota_size: int, +): + # Given + domain_name = await create_domain(name=f"test-domain-{policy_name}") + + # Create user resource policy for the user + user_policy_name = f"user-{policy_name}" + await create_user_resource_policy( + name=user_policy_name, + max_vfolder_count=10, + max_quota_scope_size=100000, + ) + + # Create project resource policy for the group + project_policy_name = await create_project_resource_policy( + name=policy_name, + max_vfolder_count=vfolder_count, + max_quota_scope_size=quota_size, + ) + + # Create user with user resource policy + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name=f"test-user-{policy_name}", + resource_policy_name=user_policy_name, + ) + + # Create group with project resource policy + group_id = await create_group( + domain_name=domain_name, + name=f"test-group-{policy_name}", + resource_policy_name=project_policy_name, + ) + + # When + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + vfolder_repository = VFolderRepository(db=database_engine) + resource_limit = await vfolder_repository.get_group_vfolder_resource_limit( + user_identity=user_identity, + group_id=group_id, + ) + + # Then + assert resource_limit.max_vfolder_count == vfolder_count + assert resource_limit.max_quota_scope_size == quota_size + + +@pytest.mark.asyncio +async def test_get_user_vfolder_resource_limit_not_found( + database_fixture, + database_engine: ExtendedAsyncSAEngine, +): + # Given + non_exist_user_id = uuid.uuid4() + user_identity = UserIdentity( + user_uuid=non_exist_user_id, + domain_name="test-domain", + user_role=UserRole.USER, + user_email="test@example.com", + ) + + # When & Then + vfolder_repository = VFolderRepository(db=database_engine) + with pytest.raises(UserNotFound): + await vfolder_repository.get_user_vfolder_resource_limit(user_identity) + + +@pytest.mark.asyncio +async def test_get_group_vfolder_resource_limit_not_found( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, +): + # Given + domain_name = await create_domain("test-domain") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user", + ) + non_exist_group_id = uuid.uuid4() + + user_identity = UserIdentity( + user_uuid=user_id, + domain_name=domain_name, + user_role=UserRole.USER, + user_email="test@example.com", + ) + + # When & Then + vfolder_repository = VFolderRepository(db=database_engine) + with pytest.raises(GroupNotFound): + await vfolder_repository.get_group_vfolder_resource_limit( + user_identity=user_identity, + group_id=non_exist_group_id, + ) From 74df1139b392d571df091749959b12ed89fbc4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 17:56:38 +0900 Subject: [PATCH 12/17] test: add test for patch vfolder name --- .../manager/api/vfolders/test_repositories.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/manager/api/vfolders/test_repositories.py b/tests/manager/api/vfolders/test_repositories.py index e26f9d5ff85..92c56403ab6 100644 --- a/tests/manager/api/vfolders/test_repositories.py +++ b/tests/manager/api/vfolders/test_repositories.py @@ -38,12 +38,17 @@ async def test_get_group_type( database_engine: ExtendedAsyncSAEngine, project_type, ): + # Given domain_name = await create_domain(name="test_add_permission") group_id = await create_group( domain_name=domain_name, name="test_get_group_type", type=project_type ) + + # When vfolder_repository = VFolderRepository(db=database_engine) group_type = await vfolder_repository.get_group_type(group_id=group_id) + + # Then assert isinstance(group_type, ProjectType) assert group_type == project_type @@ -55,6 +60,7 @@ async def test_get_container_id( create_user_with_role: Callable, create_domain: Callable, ): + # Given domain_name = await create_domain(name="test_get_container_id") expected_container_id = 1234 user_id = await create_user_with_role( @@ -64,8 +70,11 @@ async def test_get_container_id( container_uid=expected_container_id, ) + # When repo = VFolderRepository(database_engine) actual_container_id = await repo.get_user_container_id(user_id=user_id) + + # Then assert actual_container_id == expected_container_id @@ -78,6 +87,7 @@ async def test_get_created_vfolder_count( create_vfolder: Callable, create_group: Callable, ): + # Given domain_name = await create_domain(name="test_get_created_vfolder_count") user_id = await create_user_with_role( domain_name=domain_name, role=UserRole.USER, name="test_get_created_vfolder_count" @@ -112,6 +122,7 @@ async def test_get_created_vfolder_count( domain_name=domain_name, user_id=None, group_id=group_id, name="group-vfolder-3" ) + # When & Then vfolder_repository = VFolderRepository(db=database_engine) user_count = await vfolder_repository.get_created_vfolder_count( user_id, VFolderOwnershipType.USER @@ -131,6 +142,7 @@ async def test_persist_vfolder_metadata( create_domain: Callable, create_user_with_role: Callable, ): + # Given domain_name = await create_domain(name="test_persist_vfolder_metadata") user_id = await create_user_with_role( domain_name=domain_name, @@ -152,8 +164,12 @@ async def test_persist_vfolder_metadata( ownership_type=VFolderOwnershipType.USER, cloneable=True, ) + + # When vfolder_repository = VFolderRepository(db=database_engine) vfolder_item: VFolderItem = await vfolder_repository.persist_vfolder_metadata(metadata=metadata) + + # Then assert vfolder_item.name == vfolder_name assert vfolder_item.host == host assert vfolder_item.ownership_type == VFolderOwnershipType.USER @@ -331,3 +347,46 @@ async def test_get_group_vfolder_resource_limit_not_found( user_identity=user_identity, group_id=non_exist_group_id, ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "old_name,new_name", + [ + ("test-folder", "renamed-folder"), + ("before_change", "after_change"), + ("test-1", "test-2"), + ], +) +async def test_patch_vfolder_name( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_vfolder: Callable, + old_name: str, + new_name: str, +): + # Given + domain_name = await create_domain(name="test-patch-vfolder-name") + user_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user", + ) + vfolder_id = await create_vfolder( + domain_name=domain_name, + user_id=user_id, + group_id=None, + name=old_name, + ) + + # When + vfolder_repository = VFolderRepository(db=database_engine) + await vfolder_repository.patch_vFolder_name(vfolder_id, new_name) + + # Then + async with database_engine.begin() as conn: + result = await conn.execute(sa.select(VFolderRow.name).where(VFolderRow.id == vfolder_id)) + updated_name = result.scalar() + assert updated_name == new_name From c5983d977ac9e0cdc94726f77451a1c81a3e20ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 18:33:03 +0900 Subject: [PATCH 13/17] fix: Add async with context in _retry --- .../manager/api/vfolders/repositories.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index 2d6a832cc27..df3797ebed1 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -377,8 +377,8 @@ async def _delete_vfolder_permission_rows( db_session: SASession, vfolder_row_ids: Iterable[uuid.UUID], ) -> None: - stmt = sa.delete(VFolderInvitationRow).where( - VFolderInvitationRow.vfolder.in_(vfolder_row_ids) + stmt = sa.delete(VFolderPermissionRow).where( + VFolderPermissionRow.vfolder.in_(vfolder_row_ids) ) await db_session.execute(stmt) @@ -387,8 +387,8 @@ async def _delete_vfolder_invitation_rows( db_session: SASession, vfolder_row_ids: Iterable[uuid.UUID], ) -> None: - stmt = sa.delete(VFolderPermissionRow).where( - VFolderPermissionRow.vfolder.in_(vfolder_row_ids) + stmt = sa.delete(VFolderInvitationRow).where( + VFolderInvitationRow.vfolder.in_(vfolder_row_ids) ) await db_session.execute(stmt) @@ -410,7 +410,9 @@ async def _update_vfolder_status( vfolder_id: uuid.UUID, vfolder_status: VFolderOperationStatus, ) -> None: - stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).value(status=vfolder_status) + stmt = ( + sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).values(status=vfolder_status) + ) await db_session.execute(stmt) """ @@ -421,9 +423,10 @@ async def _retry( self, func: Callable[[SASession], Awaitable[_TQueryResult]], ) -> None: - await execute_with_txn_retry( - txn_func=func, begin_trx=self._db.begin_session, connection=self._db.connect() - ) + async with self._db.connect() as conn: + await execute_with_txn_retry( + txn_func=func, begin_trx=self._db.begin_session, connection=conn + ) async def delete_vFolder_by_id( self, From 9d40d5d2581d32bc691ef06e358801e7c943f4d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 18:33:24 +0900 Subject: [PATCH 14/17] test: add test for delete vfolder --- .../manager/api/vfolders/test_repositories.py | 110 +++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) diff --git a/tests/manager/api/vfolders/test_repositories.py b/tests/manager/api/vfolders/test_repositories.py index 92c56403ab6..e06d2220d20 100644 --- a/tests/manager/api/vfolders/test_repositories.py +++ b/tests/manager/api/vfolders/test_repositories.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime, timezone from typing import Callable import pytest @@ -15,14 +16,13 @@ from ai.backend.manager.models import ( ProjectType, UserRole, + VFolderInvitationState, VFolderOperationStatus, VFolderOwnershipType, VFolderPermission, ) from ai.backend.manager.models.utils import ExtendedAsyncSAEngine -from ai.backend.manager.models.vfolder import ( - VFolderRow, -) +from ai.backend.manager.models.vfolder import VFolderInvitationRow, VFolderPermissionRow, VFolderRow """ This test file also use fixtures created in `database_fixture` in tests/conftest.py, especially when using 'default' resource policy. @@ -390,3 +390,107 @@ async def test_patch_vfolder_name( result = await conn.execute(sa.select(VFolderRow.name).where(VFolderRow.id == vfolder_id)) updated_name = result.scalar() assert updated_name == new_name + + +@pytest.mark.asyncio +async def test_delete_vfolder_by_id( + database_fixture, + database_engine: ExtendedAsyncSAEngine, + create_domain: Callable, + create_user_with_role: Callable, + create_group: Callable, + create_vfolder: Callable, +): + # Given + domain_name = await create_domain(name="test_delete_vfolder_by_id") + user1_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user1", + ) + user1_email = "test-user1@test.com" + + user2_id = await create_user_with_role( + domain_name=domain_name, + role=UserRole.USER, + name="test-user2", + ) + user2_email = "test-user2@test.com" + + group_id = await create_group( + domain_name=domain_name, + name="test-group", + ) + + # Create a vfolder owned by user1 + vfolder_id = await create_vfolder( + domain_name=domain_name, + user_id=user1_id, + group_id=group_id, + name="test-folder", + ) + + # Create permissions and invitations between users (invitation from user1 to user2) + async with database_engine.begin_session() as sess: + perm_data = { + "user": user2_id, + "vfolder": vfolder_id, + "permission": VFolderPermission.READ_WRITE, + } + await sess.execute(sa.insert(VFolderPermissionRow).values(perm_data)) + + invite_data = { + "id": uuid.uuid4(), + "vfolder": vfolder_id, + "inviter": user1_email, + "invitee": user2_email, + "created_at": datetime.now(timezone.utc), + "state": VFolderInvitationState.ACCEPTED, + } + await sess.execute(sa.insert(VFolderInvitationRow).values(invite_data)) + + # Check permission and invitation count before delete + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderPermissionRow) + .where(VFolderPermissionRow.vfolder == vfolder_id) + ) + permission_count = result.scalar() + assert permission_count == 1 + + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderInvitationRow) + .where(VFolderInvitationRow.vfolder == vfolder_id) + ) + invite_count = result.scalar() + assert invite_count == 1 + + # When + vfolder_repository = VFolderRepository(db=database_engine) + await vfolder_repository.delete_vFolder_by_id(vfolder_id=vfolder_id) + + # Then + async with database_engine.begin_session() as sess: + # Check vfolder status + result = await sess.execute(sa.select(VFolderRow.status).where(VFolderRow.id == vfolder_id)) + status = result.scalar() + assert status == VFolderOperationStatus.DELETE_PENDING + + # Check if permissions are deleted + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderPermissionRow) + .where(VFolderPermissionRow.vfolder == vfolder_id) + ) + permission_count = result.scalar() + assert permission_count == 0 + + # Check if invitations are deleted + result = await sess.execute( + sa.select(sa.func.count()) + .select_from(VFolderInvitationRow) + .where(VFolderInvitationRow.vfolder == vfolder_id) + ) + invite_count = result.scalar() + assert invite_count == 0 From d362f564f4094de9521493570beb3144084f919a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 18:49:44 +0900 Subject: [PATCH 15/17] fix: Fix vfolder name with convention --- src/ai/backend/manager/api/vfolders/repositories.py | 4 ++-- tests/manager/api/vfolders/test_repositories.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index df3797ebed1..3104fc41ec3 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -367,7 +367,7 @@ async def _query_shared_vfolders( return entries - async def patch_vFolder_name(self, vfolder_id: uuid.UUID, new_name: str) -> None: + async def patch_vfolder_name(self, vfolder_id: uuid.UUID, new_name: str) -> None: async with self._db.begin_session() as sess: stmt = sa.update(VFolderRow).where(VFolderRow.id == vfolder_id).values(name=new_name) await sess.execute(stmt) @@ -428,7 +428,7 @@ async def _retry( txn_func=func, begin_trx=self._db.begin_session, connection=conn ) - async def delete_vFolder_by_id( + async def delete_vfolder_by_id( self, vfolder_id: uuid.UUID, ) -> None: diff --git a/tests/manager/api/vfolders/test_repositories.py b/tests/manager/api/vfolders/test_repositories.py index e06d2220d20..ee6695fc318 100644 --- a/tests/manager/api/vfolders/test_repositories.py +++ b/tests/manager/api/vfolders/test_repositories.py @@ -383,7 +383,7 @@ async def test_patch_vfolder_name( # When vfolder_repository = VFolderRepository(db=database_engine) - await vfolder_repository.patch_vFolder_name(vfolder_id, new_name) + await vfolder_repository.patch_vfolder_name(vfolder_id, new_name) # Then async with database_engine.begin() as conn: @@ -468,7 +468,7 @@ async def test_delete_vfolder_by_id( # When vfolder_repository = VFolderRepository(db=database_engine) - await vfolder_repository.delete_vFolder_by_id(vfolder_id=vfolder_id) + await vfolder_repository.delete_vfolder_by_id(vfolder_id=vfolder_id) # Then async with database_engine.begin_session() as sess: From 749d43e539e2ad2dade2406e27db90d1231c04d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 13 Feb 2025 18:55:42 +0900 Subject: [PATCH 16/17] fix: Remove list in sa.select --- src/ai/backend/manager/api/vfolders/repositories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ai/backend/manager/api/vfolders/repositories.py b/src/ai/backend/manager/api/vfolders/repositories.py index 3104fc41ec3..e39d7b0f09b 100644 --- a/src/ai/backend/manager/api/vfolders/repositories.py +++ b/src/ai/backend/manager/api/vfolders/repositories.py @@ -70,7 +70,7 @@ async def get_created_vfolder_count( ownership_type_caluse = VFolderRow.group == owner_id query = ( - sa.select([sa.func.count()]) + sa.select(sa.func.count()) .select_from(VFolderRow) .where( (ownership_type_caluse) From 5cdc8c02a7464736098aa3300c6222bd970a811d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 14 Feb 2025 13:20:58 +0900 Subject: [PATCH 17/17] chore: change name of dto --- src/ai/backend/common/dto/manager/field.py | 4 ++-- src/ai/backend/common/dto/manager/request.py | 4 ++-- src/ai/backend/manager/models/vfolder.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ai/backend/common/dto/manager/field.py b/src/ai/backend/common/dto/manager/field.py index 89c9d00ae85..a14ba581e45 100644 --- a/src/ai/backend/common/dto/manager/field.py +++ b/src/ai/backend/common/dto/manager/field.py @@ -6,7 +6,7 @@ from ai.backend.common.types import VFolderUsageMode -class VFolderPermissionField(enum.StrEnum): +class VFolderMountPermissionField(enum.StrEnum): READ_ONLY = "ro" READ_WRITE = "rw" RW_DELETE = "wd" @@ -38,7 +38,7 @@ class VFolderItemField(BaseModel): host: str usage_mode: VFolderUsageMode created_at: str - permission: VFolderPermissionField + permission: VFolderMountPermissionField max_size: int creator: str ownership_type: VFolderOwnershipTypeField diff --git a/src/ai/backend/common/dto/manager/request.py b/src/ai/backend/common/dto/manager/request.py index c36250d516f..e904bd61ca7 100644 --- a/src/ai/backend/common/dto/manager/request.py +++ b/src/ai/backend/common/dto/manager/request.py @@ -4,7 +4,7 @@ from pydantic import AliasChoices, BaseModel, Field from ai.backend.common import typed_validators as tv -from ai.backend.common.dto.manager.field import VFolderPermissionField +from ai.backend.common.dto.manager.field import VFolderMountPermissionField from ai.backend.common.types import VFolderUsageMode @@ -17,7 +17,7 @@ class VFolderCreateReq(BaseModel): default=None, ) usage_mode: VFolderUsageMode = Field(default=VFolderUsageMode.GENERAL) - permission: VFolderPermissionField = Field(default=VFolderPermissionField.READ_WRITE) + permission: VFolderMountPermissionField = Field(default=VFolderMountPermissionField.READ_WRITE) unmanaged_path: Optional[str] = Field( validation_alias=AliasChoices("unmanaged_path", "unmanagedPath"), default=None, diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py index 1a74190b6a2..b58dc894fbf 100644 --- a/src/ai/backend/manager/models/vfolder.py +++ b/src/ai/backend/manager/models/vfolder.py @@ -41,9 +41,9 @@ from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.defs import MODEL_VFOLDER_LENGTH_LIMIT from ai.backend.common.dto.manager.field import ( + VFolderMountPermissionField, VFolderOperationStatusField, VFolderOwnershipTypeField, - VFolderPermissionField, ) from ai.backend.common.types import ( MountPermission, @@ -177,8 +177,8 @@ class VFolderPermission(enum.StrEnum): RW_DELETE = "wd" OWNER_PERM = "wd" # resolved as RW_DELETE - def to_field(self) -> VFolderPermissionField: - return VFolderPermissionField(self) + def to_field(self) -> VFolderMountPermissionField: + return VFolderMountPermissionField(self) class VFolderPermissionValidator(t.Trafaret):