Skip to content

Commit be01b27

Browse files
committed
feat: Add image_node and vfolder_node fields to ComputeSession schema
1 parent 10cfd3c commit be01b27

File tree

9 files changed

+169
-45
lines changed

9 files changed

+169
-45
lines changed

changes/2987.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add image_node and vfolder_node fields to ComputeSession schema

src/ai/backend/manager/api/schema.graphql

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,21 @@ type Queries {
156156
id: GlobalIDField!
157157

158158
"""Added in 24.09.0."""
159-
project_id: UUID!
159+
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0.")
160160

161161
"""Added in 24.09.0. Default is read_attribute."""
162162
permission: SessionPermissionValueField = "read_attribute"
163163
): ComputeSessionNode
164164

165165
"""Added in 24.09.0."""
166166
compute_session_nodes(
167+
"""
168+
Added in 24.12.0. Default value `system` queries across the entire system.
169+
"""
170+
scope_id: ScopeField
171+
167172
"""Added in 24.09.0."""
168-
project_id: UUID!
173+
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0. Use `scope_id` instead.")
169174

170175
"""Added in 24.09.0. Default is read_attribute."""
171176
permission: SessionPermissionValueField = "read_attribute"
@@ -579,6 +584,14 @@ type KernelNode implements Node {
579584
cluster_hostname: String
580585
session_id: UUID
581586
image: ImageNode
587+
588+
"""Added in 24.12.0."""
589+
image_reference: String
590+
591+
"""
592+
Added in 24.12.0. The architecture that the image of this kernel requires
593+
"""
594+
architecture: String
582595
status: String
583596
status_changed: DateTime
584597
status_info: String
@@ -1177,6 +1190,12 @@ type ComputeSessionNode implements Node {
11771190
vfolder_mounts: [String]
11781191
occupied_slots: JSONString
11791192
requested_slots: JSONString
1193+
1194+
"""Added in 24.12.0."""
1195+
image_references: [String]
1196+
1197+
"""Added in 24.12.0."""
1198+
vfolder_nodes: [VirtualFolderNode]
11801199
num_queries: BigInt
11811200
inference_metrics: JSONString
11821201
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection

src/ai/backend/manager/models/endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ async def from_row(
793793
return cls(
794794
endpoint_id=row.id,
795795
# image="", # deprecated, row.image_object.name,
796-
image_object=ImageNode.from_row(row.image_row),
796+
image_object=ImageNode.from_row(ctx, row.image_row),
797797
domain=row.domain,
798798
project=row.project,
799799
resource_group=row.resource_group,

src/ai/backend/manager/models/gql.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
LegacyComputeSessionList,
131131
)
132132
from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair
133-
from .rbac import ScopeType, SystemScope
133+
from .rbac import ProjectScope, ScopeType, SystemScope
134134
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
135135
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
136136
from .resource_policy import (
@@ -726,7 +726,11 @@ class Queries(graphene.ObjectType):
726726
ComputeSessionNode,
727727
description="Added in 24.09.0.",
728728
id=GlobalIDField(required=True),
729-
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
729+
project_id=graphene.UUID(
730+
required=False,
731+
description="Added in 24.09.0.",
732+
deprecation_reason="Deprecated since 24.12.0.",
733+
),
730734
permission=SessionPermissionValueField(
731735
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
732736
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
@@ -736,7 +740,15 @@ class Queries(graphene.ObjectType):
736740
compute_session_nodes = PaginatedConnectionField(
737741
ComputeSessionConnection,
738742
description="Added in 24.09.0.",
739-
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
743+
scope_id=ScopeField(
744+
required=False,
745+
description="Added in 24.12.0. Default value `system` queries across the entire system.",
746+
),
747+
project_id=graphene.UUID(
748+
required=False,
749+
description="Added in 24.09.0.",
750+
deprecation_reason="Deprecated since 24.12.0. Use `scope_id` instead.",
751+
),
740752
permission=SessionPermissionValueField(
741753
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
742754
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
@@ -2043,17 +2055,23 @@ async def resolve_compute_session_node(
20432055
info: graphene.ResolveInfo,
20442056
*,
20452057
id: ResolvedGlobalID,
2046-
project_id: uuid.UUID,
2058+
project_id: Optional[uuid.UUID] = None,
20472059
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
20482060
) -> ComputeSessionNode | None:
2049-
return await ComputeSessionNode.get_accessible_node(info, id, project_id, permission)
2061+
scope_id: ScopeType
2062+
if project_id is None:
2063+
scope_id = SystemScope()
2064+
else:
2065+
scope_id = ProjectScope(project_id=project_id)
2066+
return await ComputeSessionNode.get_accessible_node(info, id, scope_id, permission)
20502067

20512068
@staticmethod
20522069
async def resolve_compute_session_nodes(
20532070
root: Any,
20542071
info: graphene.ResolveInfo,
20552072
*,
2056-
project_id: uuid.UUID,
2073+
scope_id: Optional[ScopeType] = None,
2074+
project_id: Optional[uuid.UUID] = None,
20572075
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
20582076
filter: str | None = None,
20592077
order: str | None = None,
@@ -2063,9 +2081,17 @@ async def resolve_compute_session_nodes(
20632081
before: str | None = None,
20642082
last: int | None = None,
20652083
) -> ConnectionResolverResult[ComputeSessionNode]:
2084+
_scope_id: ScopeType
2085+
if scope_id is not None:
2086+
_scope_id = scope_id
2087+
else:
2088+
if project_id is not None:
2089+
_scope_id = ProjectScope(project_id=project_id)
2090+
else:
2091+
_scope_id = SystemScope()
20662092
return await ComputeSessionNode.get_accessible_connection(
20672093
info,
2068-
project_id,
2094+
_scope_id,
20692095
permission,
20702096
filter,
20712097
order,

src/ai/backend/manager/models/gql_models/image.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AsyncIterator,
1010
List,
1111
Optional,
12+
Self,
1213
overload,
1314
)
1415
from uuid import UUID
@@ -27,12 +28,13 @@
2728
ImageAlias,
2829
)
2930
from ai.backend.logging import BraceStyleAdapter
31+
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
3032
from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType
3133

3234
from ...api.exceptions import ImageNotFound, ObjectNotFound
3335
from ...defs import DEFAULT_IMAGE_ARCH
3436
from ..base import set_if_set
35-
from ..gql_relay import AsyncNode
37+
from ..gql_relay import AsyncNode, Connection
3638
from ..image import (
3739
ImageAliasRow,
3840
ImageIdentifier,
@@ -330,16 +332,37 @@ class Meta:
330332
graphene.String, description="Added in 24.03.4. The array of image aliases."
331333
)
332334

335+
@classmethod
336+
async def batch_load_by_name_and_arch(
337+
cls,
338+
graph_ctx: GraphQueryContext,
339+
name_and_arch: Sequence[tuple[str, str]],
340+
) -> Sequence[Sequence[ImageNode]]:
341+
query = (
342+
sa.select(ImageRow)
343+
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
344+
.options(selectinload(ImageRow.aliases))
345+
)
346+
async with graph_ctx.db.begin_readonly_session() as db_session:
347+
return await batch_multiresult_in_scalar_stream(
348+
graph_ctx,
349+
db_session,
350+
query,
351+
cls,
352+
name_and_arch,
353+
lambda row: (row.name, row.architecture),
354+
)
355+
333356
@overload
334357
@classmethod
335-
def from_row(cls, row: ImageRow) -> ImageNode: ...
358+
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...
336359

337360
@overload
338361
@classmethod
339-
def from_row(cls, row: None) -> None: ...
362+
def from_row(cls, graph_ctx: GraphQueryContext, row: None) -> None: ...
340363

341364
@classmethod
342-
def from_row(cls, row: ImageRow | None) -> ImageNode | None:
365+
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow | None) -> Self | None:
343366
if row is None:
344367
return None
345368
return cls(
@@ -401,7 +424,13 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
401424
image_row = await db_session.scalar(query)
402425
if image_row is None:
403426
raise ValueError(f"Image not found (id: {image_id})")
404-
return cls.from_row(image_row)
427+
return cls.from_row(graph_ctx, image_row)
428+
429+
430+
class ImageConnection(Connection):
431+
class Meta:
432+
node = ImageNode
433+
description = "Added in 24.12.0."
405434

406435

407436
class ForgetImageById(graphene.Mutation):
@@ -453,7 +482,7 @@ async def mutate(
453482
):
454483
return ForgetImageById(ok=False, msg="Forbidden")
455484
await session.delete(image_row)
456-
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
485+
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
457486

458487

459488
class ForgetImage(graphene.Mutation):
@@ -500,7 +529,7 @@ async def mutate(
500529
):
501530
return ForgetImage(ok=False, msg="Forbidden")
502531
await session.delete(image_row)
503-
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
532+
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
504533

505534

506535
class UntagImageFromRegistry(graphene.Mutation):
@@ -566,7 +595,7 @@ async def mutate(
566595
scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
567596
await scanner.untag(image_row.image_ref)
568597

569-
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
598+
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))
570599

571600

572601
class PreloadImage(graphene.Mutation):

src/ai/backend/manager/models/gql_models/kernel.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typing import (
55
TYPE_CHECKING,
66
Any,
7+
Optional,
78
Self,
9+
cast,
810
)
911

1012
import graphene
@@ -14,10 +16,7 @@
1416

1517
from ai.backend.common import msgpack, redis_helper
1618
from ai.backend.common.types import AgentId, KernelId, SessionId
17-
from ai.backend.manager.models.base import (
18-
batch_multiresult_in_scalar_stream,
19-
batch_multiresult_in_session,
20-
)
19+
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
2120

2221
from ..gql_relay import AsyncNode, Connection
2322
from ..kernel import KernelRow, KernelStatus
@@ -48,6 +47,10 @@ class Meta:
4847

4948
# image
5049
image = graphene.Field(ImageNode)
50+
image_reference = graphene.String(description="Added in 24.12.0.")
51+
architecture = graphene.String(
52+
description="Added in 24.12.0. The architecture that the image of this kernel requires"
53+
)
5154

5255
# status
5356
status = graphene.String()
@@ -75,11 +78,9 @@ async def batch_load_by_session_id(
7578
graph_ctx: GraphQueryContext,
7679
session_ids: Sequence[SessionId],
7780
) -> Sequence[Sequence[Self]]:
78-
from ..kernel import kernels
79-
8081
async with graph_ctx.db.begin_readonly_session() as db_sess:
81-
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
82-
return await batch_multiresult_in_session(
82+
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
83+
return await batch_multiresult_in_scalar_stream(
8384
graph_ctx,
8485
db_sess,
8586
query,
@@ -122,6 +123,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
122123
local_rank=row.local_rank,
123124
cluster_role=row.cluster_role,
124125
session_id=row.session_id,
126+
architecture=row.architecture,
127+
image_reference=row.image,
125128
status=row.status,
126129
status_changed=row.status_changed,
127130
status_info=row.status_info,
@@ -138,6 +141,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
138141
preopen_ports=row.preopen_ports,
139142
)
140143

144+
async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
145+
graph_ctx: GraphQueryContext = info.context
146+
loader = graph_ctx.dataloader_manager.get_loader_by_func(
147+
graph_ctx, ImageNode.batch_load_by_name_and_arch
148+
)
149+
images = cast(list[ImageNode], await loader.load((self.image_reference, self.architecture)))
150+
try:
151+
return images[0]
152+
except IndexError:
153+
return None
154+
141155
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
142156
graph_ctx: GraphQueryContext = info.context
143157
loader = graph_ctx.dataloader_manager.get_loader_by_func(

0 commit comments

Comments
 (0)