diff --git a/changes/2987.feature.md b/changes/2987.feature.md
new file mode 100644
index 00000000000..9b51929f852
--- /dev/null
+++ b/changes/2987.feature.md
@@ -0,0 +1 @@
+Add `image_references` and `vfolder_nodes` fields to the ComputeSessionNode GraphQL schema
diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql
index 74d56366fcd..99e4debb61b 100644
--- a/docs/manager/graphql-reference/schema.graphql
+++ b/docs/manager/graphql-reference/schema.graphql
@@ -655,6 +655,14 @@ type KernelNode implements Node {
   cluster_hostname: String
   session_id: UUID
   image: ImageNode
+
+  """Added in 25.3.1."""
+  image_reference: String
+
+  """
+  Added in 25.3.1. The architecture that the image of this kernel requires
+  """
+  architecture: String
   status: String
   status_changed: DateTime
   status_info: String
@@ -1332,6 +1340,12 @@ type ComputeSessionNode implements Node {
   vfolder_mounts: [String]
   occupied_slots: JSONString
   requested_slots: JSONString
+
+  """Added in 25.3.1."""
+  image_references: [String]
+
+  """Added in 25.3.1."""
+  vfolder_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): VirtualFolderConnection
   num_queries: BigInt
   inference_metrics: JSONString
   kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection
diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json
index 8c2999a60db..9eba68e8eb0 100644
--- a/docs/manager/rest-reference/openapi.json
+++ b/docs/manager/rest-reference/openapi.json
@@ -3,7 +3,7 @@
   "info": {
     "title": "Backend.AI Manager API",
     "description": "Backend.AI Manager REST API specification",
-    "version": "25.2.0",
+    "version": "25.3.0",
     "contact": {
       "name": "Lablup Inc.",
       "url": "https://docs.backend.ai",
diff --git a/src/ai/backend/manager/models/endpoint.py b/src/ai/backend/manager/models/endpoint.py
index 10d49c8b2fd..7279e4cb183 100644
--- a/src/ai/backend/manager/models/endpoint.py
+++ b/src/ai/backend/manager/models/endpoint.py
@@ -1039,7 +1039,7 @@ async def from_row(
         return cls(
             endpoint_id=row.id,
             # image="",  # deprecated, row.image_object.name,
-            image_object=ImageNode.from_row(row.image_row),
+            image_object=ImageNode.from_row(ctx, row.image_row),
             domain=row.domain,
             project=row.project,
             resource_group=row.resource_group,
diff --git a/src/ai/backend/manager/models/gql_models/image.py b/src/ai/backend/manager/models/gql_models/image.py
index 8294848fa1e..16bda8f71a5 100644
--- a/src/ai/backend/manager/models/gql_models/image.py
+++ b/src/ai/backend/manager/models/gql_models/image.py
@@ -449,23 +449,27 @@ async def batch_load_by_image_identifier(
 
     @overload
     @classmethod
-    def from_row(cls, row: ImageRow) -> ImageNode: ...
+    def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...
 
     @overload
     @classmethod
     def from_row(
-        cls, row: ImageRow, *, permissions: Optional[Iterable[ImagePermission]] = None
+        cls, graph_ctx, row: ImageRow, *, permissions: Optional[Iterable[ImagePermission]] = None
    ) -> ImageNode: ...
 
     @overload
     @classmethod
     def from_row(
-        cls, row: None, *, permissions: Optional[Iterable[ImagePermission]] = None
+        cls, graph_ctx, row: None, *, permissions: Optional[Iterable[ImagePermission]] = None
    ) -> None: ...
 
     @classmethod
     def from_row(
-        cls, row: ImageRow | None, *, permissions: Optional[Iterable[ImagePermission]] = None
+        cls,
+        graph_ctx,
+        row: Optional[ImageRow],
+        *,
+        permissions: Optional[Iterable[ImagePermission]] = None,
     ) -> ImageNode | None:
         if row is None:
             return None
@@ -565,6 +569,7 @@ async def get_node(
             return None
 
         return cls.from_row(
+            graph_ctx,
             image_row,
             permissions=await permission_ctx.calculate_final_permission(image_row),
         )
@@ -629,6 +634,7 @@ async def get_connection(
             total_cnt = await db_session.scalar(cnt_query)
             result: list[Self] = [
                 cls.from_row(
+                    graph_ctx,
                     row,
                     permissions=await permission_ctx.calculate_final_permission(row),
                 )
@@ -692,7 +698,7 @@ async def mutate(
             ):
                 return ForgetImageById(ok=False, msg="Forbidden")
             await session.delete(image_row)
-            return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
+            return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
 
 
 class ForgetImage(graphene.Mutation):
@@ -739,7 +745,7 @@ async def mutate(
             ):
                 return ForgetImage(ok=False, msg="Forbidden")
             await session.delete(image_row)
-            return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
+            return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
 
 
 class UntagImageFromRegistry(graphene.Mutation):
@@ -805,7 +811,7 @@ async def mutate(
             scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
             await scanner.untag(image_row.image_ref)
 
-        return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
+        return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
 
 
 class PreloadImage(graphene.Mutation):
diff --git a/src/ai/backend/manager/models/gql_models/kernel.py b/src/ai/backend/manager/models/gql_models/kernel.py
index cbb52db6858..8711676d1ab 100644
--- a/src/ai/backend/manager/models/gql_models/kernel.py
+++ b/src/ai/backend/manager/models/gql_models/kernel.py
@@ -4,7 +4,9 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Optional,
     Self,
+    cast,
 )
 
 import graphene
@@ -14,10 +16,7 @@
 
 from ai.backend.common import msgpack, redis_helper
 from ai.backend.common.types import AgentId, KernelId, SessionId
-from ai.backend.manager.models.base import (
-    batch_multiresult_in_scalar_stream,
-    batch_multiresult_in_session,
-)
+from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
 
 from ..gql_relay import AsyncNode, Connection
 from ..kernel import KernelRow, KernelStatus
@@ -48,6 +47,10 @@ class Meta:
 
     # image
     image = graphene.Field(ImageNode)
+    image_reference = graphene.String(description="Added in 25.3.1.")
+    architecture = graphene.String(
+        description="Added in 25.3.1. The architecture that the image of this kernel requires"
+    )
 
     # status
     status = graphene.String()
@@ -75,11 +78,9 @@ async def batch_load_by_session_id(
         graph_ctx: GraphQueryContext,
         session_ids: Sequence[SessionId],
     ) -> Sequence[Sequence[Self]]:
-        from ..kernel import kernels
-
         async with graph_ctx.db.begin_readonly_session() as db_sess:
-            query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
-            return await batch_multiresult_in_session(
+            query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
+            return await batch_multiresult_in_scalar_stream(
                 graph_ctx,
                 db_sess,
                 query,
@@ -122,6 +123,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
             local_rank=row.local_rank,
             cluster_role=row.cluster_role,
             session_id=row.session_id,
+            architecture=row.architecture,
+            image_reference=row.image,
             status=row.status,
             status_changed=row.status_changed,
             status_info=row.status_info,
@@ -138,6 +141,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
             preopen_ports=row.preopen_ports,
         )
 
+    async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
+        graph_ctx: GraphQueryContext = info.context
+        loader = graph_ctx.dataloader_manager.get_loader_by_func(
+            graph_ctx, ImageNode.batch_load_by_name_and_arch
+        )
+        images = cast(list[ImageNode], await loader.load((self.image_reference, self.architecture)))
+        try:
+            return images[0]
+        except IndexError:
+            return None
+
     async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
         graph_ctx: GraphQueryContext = info.context
         loader = graph_ctx.dataloader_manager.get_loader_by_func(
diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py
index e480ea5b6e1..e023ab1bd8f 100644
--- a/src/ai/backend/manager/models/gql_models/session.py
+++ b/src/ai/backend/manager/models/gql_models/session.py
@@ -13,6 +13,7 @@
 
 import graphene
 import graphql
+import more_itertools
 import sqlalchemy as sa
 import trafaret as t
 from dateutil.parser import parse as dtparse
@@ -59,6 +60,7 @@
 from ..user import UserRole
 from ..utils import execute_with_txn_retry
 from .kernel import KernelConnection, KernelNode
+from .vfolder import VirtualFolderConnection, VirtualFolderNode
 
 if TYPE_CHECKING:
     from ..gql import GraphQueryContext
@@ -195,6 +197,14 @@ class Meta:
     vfolder_mounts = graphene.List(lambda: graphene.String)
     occupied_slots = graphene.JSONString()
     requested_slots = graphene.JSONString()
+    image_references = graphene.List(
+        lambda: graphene.String,
+        description="Added in 25.3.1.",
+    )
+    vfolder_nodes = PaginatedConnectionField(
+        VirtualFolderConnection,
+        description="Added in 25.3.1.",
+    )
 
     # statistics
     num_queries = BigInt()
@@ -264,6 +274,7 @@ def from_row(
             vfolder_mounts=[vf.vfid.folder_id for vf in row.vfolders_sorted_by_id],
             occupied_slots=row.occupying_slots.to_json(),
             requested_slots=row.requested_slots.to_json(),
+            image_references=row.images,
             # statistics
             num_queries=row.num_queries,
         )
@@ -277,19 +288,27 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
         )
         return await loader.load(self.row_id)
 
+    async def resolve_vfolder_nodes(
+        self,
+        info: graphene.ResolveInfo,
+    ) -> ConnectionResolverResult[VirtualFolderNode]:
+        ctx: GraphQueryContext = info.context
+        _folder_ids = cast(list[uuid.UUID], self.vfolder_mounts)
+        loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
+        result = cast(list[list[VirtualFolderNode]], await loader.load_many(_folder_ids))
+
+        vf_nodes = cast(list[VirtualFolderNode], list(more_itertools.flatten(result)))
+        return ConnectionResolverResult(vf_nodes, None, None, None, total_count=len(vf_nodes))
+
     async def resolve_kernel_nodes(
         self,
         info: graphene.ResolveInfo,
     ) -> ConnectionResolverResult[KernelNode]:
         ctx: GraphQueryContext = info.context
         loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
-        kernels = await loader.load(self.row_id)
+        kernel_nodes = await loader.load(self.row_id)
         return ConnectionResolverResult(
-            kernels,
-            None,
-            None,
-            None,
-            total_count=len(kernels),
+            kernel_nodes, None, None, None, total_count=len(kernel_nodes)
         )
 
     async def resolve_dependees(
@@ -492,7 +511,6 @@ async def get_accessible_connection(
             before=before,
             last=last,
         )
-        query = query.options(selectinload(SessionRow.kernels))
         async with graph_ctx.db.connect() as db_conn:
             user = graph_ctx.user
             client_ctx = ClientContext(
diff --git a/src/ai/backend/manager/models/gql_models/vfolder.py b/src/ai/backend/manager/models/gql_models/vfolder.py
index 63795d7f0d9..f8d9824cbdf 100644
--- a/src/ai/backend/manager/models/gql_models/vfolder.py
+++ b/src/ai/backend/manager/models/gql_models/vfolder.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import uuid
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from datetime import datetime
 from typing import (
     TYPE_CHECKING,
@@ -25,6 +25,7 @@
     VFolderID,
     VFolderUsageMode,
 )
+from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
 
 from ...api.exceptions import (
     VFolderOperationFailed,
@@ -216,6 +217,25 @@ def from_row(
         result.permissions = [] if permissions is None else permissions
         return result
 
+    @classmethod
+    async def batch_load_by_id(
+        cls,
+        graph_ctx: GraphQueryContext,
+        folder_ids: Sequence[uuid.UUID],
+    ) -> Sequence[Sequence[Self]]:
+        query = (
+            sa.select(VFolderRow)
+            .where(VFolderRow.id.in_(folder_ids))
+            .options(
+                joinedload(VFolderRow.user_row),
+                joinedload(VFolderRow.group_row),
+            )
+        )
+        async with graph_ctx.db.begin_readonly_session() as db_session:
+            return await batch_multiresult_in_scalar_stream(
+                graph_ctx, db_session, query, cls, folder_ids, lambda row: row.id
+            )
+
     @classmethod
     async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self:
         graph_ctx: GraphQueryContext = info.context
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py
index e42355e20be..dd0cd204699 100644
--- a/src/ai/backend/manager/models/kernel.py
+++ b/src/ai/backend/manager/models/kernel.py
@@ -941,7 +941,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
             "session_id": row.session_id,
             # image
             "image": row.image,
-            "image_object": ImageNode.from_row(row.image_row),
+            "image_object": ImageNode.from_row(ctx, row.image_row),
             "architecture": row.architecture,
             "registry": row.registry,
             # status
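
Example usage (not part of the patch): a minimal sketch of a client query exercising the fields added in 25.3.1. Only image_references, vfolder_nodes, image_reference, and architecture come from this diff; the root query field and its argument, and the connection sub-fields (count, edges, node, name), are assumptions based on the existing ComputeSessionNode and KernelConnection conventions and should be checked against the deployed schema.

    # Hypothetical query; requires a manager running 25.3.1 or later.
    query {
      compute_session_node(id: "<session-global-id>") {
        image_references              # added in 25.3.1
        vfolder_nodes(first: 10) {    # added in 25.3.1
          count
          edges {
            node {
              id
              name
            }
          }
        }
        kernel_nodes(first: 10) {
          edges {
            node {
              image_reference         # added in 25.3.1
              architecture            # added in 25.3.1
            }
          }
        }
      }
    }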