Skip to content

Commit

Permalink
feat: Add image_node and vfolder_node fields to ComputeSession schema
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 31, 2024
1 parent 10cfd3c commit be01b27
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 45 deletions.
1 change: 1 addition & 0 deletions changes/2987.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add image_node and vfolder_node fields to ComputeSession schema
23 changes: 21 additions & 2 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,21 @@ type Queries {
id: GlobalIDField!

"""Added in 24.09.0."""
project_id: UUID!
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0.")

Check notice on line 159 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Type for argument 'project_id' on field 'Queries.compute_session_node' changed from 'UUID!' to 'UUID'

Changing an input field from non-null to null is considered non-breaking.

"""Added in 24.09.0. Default is read_attribute."""
permission: SessionPermissionValueField = "read_attribute"
): ComputeSessionNode

"""Added in 24.09.0."""
compute_session_nodes(
"""
Added in 24.12.0. Default value `system` queries across the entire system.
"""
scope_id: ScopeField

Check warning on line 170 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Argument 'scope_id: ScopeField' added to field 'Queries.compute_session_nodes'

Adding a new argument to an existing field may involve a change in resolve function logic that potentially may cause some side effects.

"""Added in 24.09.0."""
project_id: UUID!
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0. Use `scope_id` instead.")

Check notice on line 173 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Type for argument 'project_id' on field 'Queries.compute_session_nodes' changed from 'UUID!' to 'UUID'

Changing an input field from non-null to null is considered non-breaking.

"""Added in 24.09.0. Default is read_attribute."""
permission: SessionPermissionValueField = "read_attribute"
Expand Down Expand Up @@ -579,6 +584,14 @@ type KernelNode implements Node {
cluster_hostname: String
session_id: UUID
image: ImageNode

"""Added in 24.12.0."""
image_reference: String

Check notice on line 589 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'image_reference' was added to object type 'KernelNode'

Field 'image_reference' was added to object type 'KernelNode'

"""
Added in 24.12.0. The architecture that the image of this kernel requires
"""
architecture: String

Check notice on line 594 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'architecture' was added to object type 'KernelNode'

Field 'architecture' was added to object type 'KernelNode'
status: String
status_changed: DateTime
status_info: String
Expand Down Expand Up @@ -1177,6 +1190,12 @@ type ComputeSessionNode implements Node {
vfolder_mounts: [String]
occupied_slots: JSONString
requested_slots: JSONString

"""Added in 24.12.0."""
image_references: [String]

Check notice on line 1195 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'image_references' was added to object type 'ComputeSessionNode'

Field 'image_references' was added to object type 'ComputeSessionNode'

"""Added in 24.12.0."""
vfolder_nodes: [VirtualFolderNode]

Check notice on line 1198 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'vfolder_nodes' was added to object type 'ComputeSessionNode'

Field 'vfolder_nodes' was added to object type 'ComputeSessionNode'
num_queries: BigInt
inference_metrics: JSONString
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ async def from_row(
return cls(
endpoint_id=row.id,
# image="", # deprecated, row.image_object.name,
image_object=ImageNode.from_row(row.image_row),
image_object=ImageNode.from_row(ctx, row.image_row),
domain=row.domain,
project=row.project,
resource_group=row.resource_group,
Expand Down
40 changes: 33 additions & 7 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
LegacyComputeSessionList,
)
from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair
from .rbac import ScopeType, SystemScope
from .rbac import ProjectScope, ScopeType, SystemScope
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
from .resource_policy import (
Expand Down Expand Up @@ -726,7 +726,11 @@ class Queries(graphene.ObjectType):
ComputeSessionNode,
description="Added in 24.09.0.",
id=GlobalIDField(required=True),
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
project_id=graphene.UUID(
required=False,
description="Added in 24.09.0.",
deprecation_reason="Deprecated since 24.12.0.",
),
permission=SessionPermissionValueField(
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
Expand All @@ -736,7 +740,15 @@ class Queries(graphene.ObjectType):
compute_session_nodes = PaginatedConnectionField(
ComputeSessionConnection,
description="Added in 24.09.0.",
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
scope_id=ScopeField(
required=False,
description="Added in 24.12.0. Default value `system` queries across the entire system.",
),
project_id=graphene.UUID(
required=False,
description="Added in 24.09.0.",
deprecation_reason="Deprecated since 24.12.0. Use `scope_id` instead.",
),
permission=SessionPermissionValueField(
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
Expand Down Expand Up @@ -2043,17 +2055,23 @@ async def resolve_compute_session_node(
info: graphene.ResolveInfo,
*,
id: ResolvedGlobalID,
project_id: uuid.UUID,
project_id: Optional[uuid.UUID] = None,
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
) -> ComputeSessionNode | None:
return await ComputeSessionNode.get_accessible_node(info, id, project_id, permission)
scope_id: ScopeType
if project_id is None:
scope_id = SystemScope()
else:
scope_id = ProjectScope(project_id=project_id)
return await ComputeSessionNode.get_accessible_node(info, id, scope_id, permission)

@staticmethod
async def resolve_compute_session_nodes(
root: Any,
info: graphene.ResolveInfo,
*,
project_id: uuid.UUID,
scope_id: Optional[ScopeType] = None,
project_id: Optional[uuid.UUID] = None,
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
filter: str | None = None,
order: str | None = None,
Expand All @@ -2063,9 +2081,17 @@ async def resolve_compute_session_nodes(
before: str | None = None,
last: int | None = None,
) -> ConnectionResolverResult[ComputeSessionNode]:
_scope_id: ScopeType
if scope_id is not None:
_scope_id = scope_id
else:
if project_id is not None:
_scope_id = ProjectScope(project_id=project_id)
else:
_scope_id = SystemScope()
return await ComputeSessionNode.get_accessible_connection(
info,
project_id,
_scope_id,
permission,
filter,
order,
Expand Down
45 changes: 37 additions & 8 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AsyncIterator,
List,
Optional,
Self,
overload,
)
from uuid import UUID
Expand All @@ -27,12 +28,13 @@
ImageAlias,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType

from ...api.exceptions import ImageNotFound, ObjectNotFound
from ...defs import DEFAULT_IMAGE_ARCH
from ..base import set_if_set
from ..gql_relay import AsyncNode
from ..gql_relay import AsyncNode, Connection
from ..image import (
ImageAliasRow,
ImageIdentifier,
Expand Down Expand Up @@ -330,16 +332,37 @@ class Meta:
graphene.String, description="Added in 24.03.4. The array of image aliases."
)

@classmethod
async def batch_load_by_name_and_arch(
cls,
graph_ctx: GraphQueryContext,
name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
query = (
sa.select(ImageRow)
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
.options(selectinload(ImageRow.aliases))
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_session,
query,
cls,
name_and_arch,
lambda row: (row.name, row.architecture),
)

@overload
@classmethod
def from_row(cls, row: ImageRow) -> ImageNode: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...

@overload
@classmethod
def from_row(cls, row: None) -> None: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: None) -> None: ...

@classmethod
def from_row(cls, row: ImageRow | None) -> ImageNode | None:
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow | None) -> Self | None:
if row is None:
return None
return cls(
Expand Down Expand Up @@ -401,7 +424,13 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
image_row = await db_session.scalar(query)
if image_row is None:
raise ValueError(f"Image not found (id: {image_id})")
return cls.from_row(image_row)
return cls.from_row(graph_ctx, image_row)


class ImageConnection(Connection):
class Meta:
node = ImageNode
description = "Added in 24.12.0."


class ForgetImageById(graphene.Mutation):
Expand Down Expand Up @@ -453,7 +482,7 @@ async def mutate(
):
return ForgetImageById(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class ForgetImage(graphene.Mutation):
Expand Down Expand Up @@ -500,7 +529,7 @@ async def mutate(
):
return ForgetImage(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class UntagImageFromRegistry(graphene.Mutation):
Expand Down Expand Up @@ -566,7 +595,7 @@ async def mutate(
scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
await scanner.untag(image_row.image_ref)

return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class PreloadImage(graphene.Mutation):
Expand Down
30 changes: 22 additions & 8 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Self,
cast,
)

import graphene
Expand All @@ -14,10 +16,7 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, KernelId, SessionId
from ai.backend.manager.models.base import (
batch_multiresult_in_scalar_stream,
batch_multiresult_in_session,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
Expand Down Expand Up @@ -48,6 +47,10 @@ class Meta:

# image
image = graphene.Field(ImageNode)
image_reference = graphene.String(description="Added in 24.12.0.")
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)

# status
status = graphene.String()
Expand Down Expand Up @@ -75,11 +78,9 @@ async def batch_load_by_session_id(
graph_ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
from ..kernel import kernels

async with graph_ctx.db.begin_readonly_session() as db_sess:
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
return await batch_multiresult_in_session(
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
query,
Expand Down Expand Up @@ -122,6 +123,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
local_rank=row.local_rank,
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image_reference=row.image,
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
Expand All @@ -138,6 +141,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
preopen_ports=row.preopen_ports,
)

async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, ImageNode.batch_load_by_name_and_arch
)
images = cast(list[ImageNode], await loader.load((self.image_reference, self.architecture)))
try:
return images[0]
except IndexError:
return None

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
Expand Down
Loading

0 comments on commit be01b27

Please sign in to comment.