@classmethod
async def batch_load_by_name_and_arch(
    cls,
    graph_ctx: GraphQueryContext,
    name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
    """Batch-load image nodes for a list of ``(name, architecture)`` keys.

    Selects every ``ImageRow`` whose ``(name, architecture)`` pair is contained
    in ``name_and_arch`` (aliases are eager-loaded through ``selectinload``) and
    passes the statement, the requested keys and a key extractor to
    ``batch_multiresult_in_scalar_stream`` inside a read-only DB session.

    NOTE(review): how the rows are grouped and ordered in the result is decided
    by ``batch_multiresult_in_scalar_stream``, which is not visible here --
    presumably one sequence per requested key, in key order; confirm.
    """

    def _key_of(row: ImageRow) -> tuple[str, str]:
        # Must produce the same tuple shape as the entries of ``name_and_arch``.
        return (row.name, row.architecture)

    key_columns = sa.tuple_(ImageRow.name, ImageRow.architecture)
    stmt = sa.select(ImageRow).where(key_columns.in_(name_and_arch))
    stmt = stmt.options(selectinload(ImageRow.aliases))
    async with graph_ctx.db.begin_readonly_session() as db_session:
        grouped = await batch_multiresult_in_scalar_stream(
            graph_ctx,
            db_session,
            stmt,
            cls,
            name_and_arch,
            _key_of,
        )
    return grouped
async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
    """Resolve the ``image`` field of this kernel node.

    Uses the dataloader registered for ``ImageNode.batch_load_by_name_and_arch``
    with the key ``(self.image_reference, self.architecture)`` and returns the
    first loaded node, or ``None`` when the loader produced no match.
    """
    ctx: GraphQueryContext = info.context
    image_loader = ctx.dataloader_manager.get_loader_by_func(
        ctx, ImageNode.batch_load_by_name_and_arch
    )
    lookup_key = (self.image_reference, self.architecture)
    matches = cast(list[ImageNode], await image_loader.load(lookup_key))
    # An empty result means no image row matched this (name, architecture) pair.
    return next(iter(matches), None)
async def resolve_vfolder_nodes(
    self,
    info: graphene.ResolveInfo,
) -> list[VirtualFolderNode]:
    """Resolve ``vfolder_nodes`` as a flat list of the mounted virtual folders.

    The folder ids are taken from ``self.vfolder_mounts`` and loaded through
    the dataloader registered for ``VirtualFolderNode.batch_load_by_id``.

    That batch function is annotated ``Sequence[Sequence[Self]]`` and delegates
    to ``batch_multiresult_in_scalar_stream``, so ``load_many`` yields one group
    of nodes per requested id rather than one node per id.  The groups are
    flattened here because the field is declared as
    ``graphene.List(VirtualFolderNode)``; returning the nested value as-is
    would hand graphene a list of lists.
    """
    ctx: GraphQueryContext = info.context
    loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
    vfolder_mounts = cast(list[VFolderMount], self.vfolder_mounts)
    folder_ids = [mount.vfid.folder_id for mount in vfolder_mounts]
    per_id_nodes = cast(
        list[list[VirtualFolderNode]],
        await loader.load_many(folder_ids),
    )
    # Flatten while keeping the order of the mounts; an id with no matching
    # row contributes an empty group and therefore no entry.
    return [node for nodes in per_id_nodes for node in nodes]


async def resolve_kernel_nodes(
    self,
    info: graphene.ResolveInfo,
) -> list[KernelNode]:
    """Resolve ``kernel_nodes`` for this session.

    Loads with ``self.row_id`` as the key from the dataloader named
    ``"KernelNode.by_session_id"`` and returns its result unchanged.
    """
    ctx: GraphQueryContext = info.context
    loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
    return await loader.load(self.row_id)