Skip to content

Commit f2afa38

Browse files
fix: Make ScanArtifactModelsInput's registry_id to optional (#5733)
Co-authored-by: octodog <mu001@lablup.com>
1 parent 011db8f commit f2afa38

15 files changed

Lines changed: 152 additions & 44 deletions

File tree

docs/manager/graphql-reference/supergraph.graphql

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ enum ArtifactOrderField
271271
NAME @join__enumValue(graph: STRAWBERRY)
272272
TYPE @join__enumValue(graph: STRAWBERRY)
273273
SIZE @join__enumValue(graph: STRAWBERRY)
274+
SCANNED_AT @join__enumValue(graph: STRAWBERRY)
275+
UPDATED_AT @join__enumValue(graph: STRAWBERRY)
274276
}
275277

276278
"""Added in 25.14.0"""
@@ -2912,7 +2914,7 @@ input ModelTarget
29122914
@join__type(graph: STRAWBERRY)
29132915
{
29142916
modelId: String!
2915-
revision: String!
2917+
revision: String = null
29162918
}
29172919

29182920
type ModifyAgent
@@ -4861,7 +4863,7 @@ input ScanArtifactModelsInput
48614863
@join__type(graph: STRAWBERRY)
48624864
{
48634865
models: [ModelTarget!]!
4864-
registryId: UUID!
4866+
registryId: UUID = null
48654867
}
48664868

48674869
"""Added in 25.14.0"""

docs/manager/graphql-reference/v2-schema.graphql

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ enum ArtifactOrderField {
7676
NAME
7777
TYPE
7878
SIZE
79+
SCANNED_AT
80+
UPDATED_AT
7981
}
8082

8183
"""Added in 25.14.0"""
@@ -615,7 +617,7 @@ input ModelRuntimeConfigInput {
615617
"""Added in 25.14.0"""
616618
input ModelTarget {
617619
modelId: String!
618-
revision: String!
620+
revision: String = null
619621
}
620622

621623
type Mutation {
@@ -951,7 +953,7 @@ type ScalingRule {
951953
"""Added in 25.14.0"""
952954
input ScanArtifactModelsInput {
953955
models: [ModelTarget!]!
954-
registryId: UUID!
956+
registryId: UUID = null
955957
}
956958

957959
"""Added in 25.14.0"""

src/ai/backend/common/data/storage/registries/types.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
from pydantic import BaseModel, Field
66

7+
from ai.backend.common.data.artifact.types import ArtifactRegistryType
8+
from ai.backend.common.exception import ArtifactDefaultRevisionResolveError
9+
710

811
class ModelSortKey(enum.StrEnum):
912
LAST_MODIFIED = "last_modified"
@@ -13,6 +16,7 @@ class ModelSortKey(enum.StrEnum):
1316
LIKES = "likes"
1417

1518

19+
# TODO: Separate of the ModelTarget type used in the storage proxy and the one used in the manager
1620
class ModelTarget(BaseModel):
1721
model_id: str = Field(
1822
description="""
@@ -21,15 +25,27 @@ class ModelTarget(BaseModel):
2125
""",
2226
examples=["microsoft/DialoGPT-medium", "openai/gpt-2", "bert-base-uncased"],
2327
)
24-
revision: str = Field(
25-
default="main",
28+
revision: Optional[str] = Field(
29+
default=None,
2630
description="""
2731
Specific revision (branch or tag) of the model to import.
2832
Defaults to 'main' if not specified.
2933
""",
3034
examples=["main", "v1.0", "latest"],
3135
)
3236

37+
def resolve_revision(self, registry_type: ArtifactRegistryType) -> str:
38+
if self.revision is not None:
39+
return self.revision
40+
41+
match registry_type:
42+
case ArtifactRegistryType.HUGGINGFACE:
43+
return "main"
44+
case _:
45+
raise ArtifactDefaultRevisionResolveError(
46+
f"Cannot resolve default revision for registry type: {registry_type}"
47+
)
48+
3349

3450
class FileObjectData(BaseModel):
3551
"""

src/ai/backend/common/exception.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,16 @@ def error_code(cls) -> ErrorCode:
552552
operation=ErrorOperation.GENERIC,
553553
error_detail=ErrorDetail.INTERNAL_ERROR,
554554
)
555+
556+
557+
class ArtifactDefaultRevisionResolveError(BackendAIError, web.HTTPBadRequest):
558+
error_type = "https://api.backend.ai/probs/artifact-revision-resolve-failed"
559+
error_title = "Cannot Resolve Artifact Default Revision"
560+
561+
@classmethod
562+
def error_code(cls) -> ErrorCode:
563+
return ErrorCode(
564+
domain=ErrorDomain.ARTIFACT,
565+
operation=ErrorOperation.REQUEST,
566+
error_detail=ErrorDetail.BAD_REQUEST,
567+
)

src/ai/backend/manager/api/gql/artifact.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ class ArtifactStatusChangedInput:
205205
@strawberry.input(description="Added in 25.14.0")
206206
class ModelTarget:
207207
model_id: str
208-
revision: str
208+
revision: Optional[str] = None
209209

210210
def to_dataclass(self) -> ModelTargetData:
211211
return ModelTargetData(model_id=self.model_id, revision=self.revision)
@@ -214,7 +214,7 @@ def to_dataclass(self) -> ModelTargetData:
214214
@strawberry.input(description="Added in 25.14.0")
215215
class ScanArtifactModelsInput:
216216
models: list[ModelTarget]
217-
registry_id: uuid.UUID
217+
registry_id: Optional[uuid.UUID] = None
218218

219219

220220
# Object Types
@@ -280,18 +280,6 @@ async def revisions(
280280
offset=offset,
281281
)
282282

283-
# TODO: Is this necessary?
284-
# @strawberry.field
285-
# async def updated_at(self, info: Info[StrawberryGQLContext]) -> Optional[datetime]:
286-
# action_result = await info.context.processors.artifact.get_revisions.wait_for_complete(
287-
# GetArtifactRevisionsAction(uuid.UUID(self.id))
288-
# )
289-
290-
# updated_at_list = [
291-
# r.updated_at for r in action_result.revisions if r.updated_at is not None
292-
# ]
293-
# return max(updated_at_list) if updated_at_list else None
294-
295283

296284
@strawberry.type(description="Added in 25.14.0")
297285
class ArtifactRevision(Node):

src/ai/backend/manager/data/artifact/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class ArtifactOrderField(enum.StrEnum):
6565
NAME = "NAME"
6666
TYPE = "TYPE"
6767
SIZE = "SIZE"
68+
SCANNED_AT = "SCANNED_AT"
69+
UPDATED_AT = "UPDATED_AT"
6870

6971

7072
class ArtifactRevisionOrderField(enum.StrEnum):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Make artifact scanned_at, updated_at columns timezone aware
2+
3+
Revision ID: 982a7c19a06c
4+
Revises: 5b171528a6f5
5+
Create Date: 2025-09-04 03:23:38.648535
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "982a7c19a06c"
15+
down_revision = "5b171528a6f5"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.alter_column(
23+
"artifacts",
24+
"scanned_at",
25+
existing_type=postgresql.TIMESTAMP(),
26+
type_=sa.DateTime(timezone=True),
27+
existing_nullable=False,
28+
existing_server_default=sa.text("now()"),
29+
)
30+
op.alter_column(
31+
"artifacts",
32+
"updated_at",
33+
existing_type=postgresql.TIMESTAMP(),
34+
type_=sa.DateTime(timezone=True),
35+
existing_nullable=False,
36+
existing_server_default=sa.text("now()"),
37+
)
38+
# ### end Alembic commands ###
39+
40+
41+
def downgrade() -> None:
42+
# ### commands auto generated by Alembic - please adjust! ###
43+
op.alter_column(
44+
"artifacts",
45+
"updated_at",
46+
existing_type=sa.DateTime(timezone=False),
47+
type_=postgresql.TIMESTAMP(),
48+
existing_nullable=False,
49+
existing_server_default=sa.text("now()"),
50+
)
51+
op.alter_column(
52+
"artifacts",
53+
"scanned_at",
54+
existing_type=sa.DateTime(timezone=False),
55+
type_=postgresql.TIMESTAMP(),
56+
existing_nullable=False,
57+
existing_server_default=sa.text("now()"),
58+
)
59+
# ### end Alembic commands ###

src/ai/backend/manager/models/artifact.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,12 @@ class ArtifactRow(Base):
4949
source_registry_type = sa.Column("source_registry_type", sa.String, nullable=False, index=True)
5050
description = sa.Column("description", sa.String, nullable=True)
5151
readonly = sa.Column("readonly", sa.Boolean, default=False, nullable=False)
52-
scanned_at = sa.Column("scanned_at", sa.DateTime, nullable=False, server_default=sa.func.now())
53-
updated_at = sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now())
52+
scanned_at = sa.Column(
53+
"scanned_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()
54+
)
55+
updated_at = sa.Column(
56+
"updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()
57+
)
5458

5559
huggingface_registry = relationship(
5660
"HuggingFaceRegistryRow",

src/ai/backend/manager/repositories/artifact/repository.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import uuid
2-
from datetime import datetime
2+
from datetime import datetime, timezone
33
from typing import Any, Optional, override
44

55
import sqlalchemy as sa
@@ -122,7 +122,7 @@ class ArtifactOrderingApplier(BaseOrderingApplier[ArtifactOrderingOptions]):
122122
@override
123123
def get_order_column(self, field) -> sa.Column:
124124
"""Get the SQLAlchemy column for the given artifact field"""
125-
return getattr(ArtifactRow, field.value, ArtifactRow.name)
125+
return getattr(ArtifactRow, field.value.lower(), ArtifactRow.name)
126126

127127

128128
class ArtifactModelConverter:
@@ -323,7 +323,7 @@ async def upsert_artifacts(
323323
has_changes = existing_artifact.description != artifact_data.description
324324
if has_changes:
325325
existing_artifact.description = artifact_data.description
326-
existing_artifact.updated_at = datetime.now()
326+
existing_artifact.updated_at = datetime.now(timezone.utc)
327327

328328
await db_sess.flush()
329329
await db_sess.refresh(
@@ -745,7 +745,10 @@ async def list_artifacts_paginated(
745745
result = await db_sess.execute(querybuild_result.data_query)
746746
rows = result.scalars().all()
747747

748+
# Build count query with same filters applied
748749
count_stmt = sa.select(sa.func.count()).select_from(ArtifactRow)
750+
if filters is not None:
751+
count_stmt = artifact_paginator.filter_applier.apply_filters(count_stmt, filters)
749752
count_result = await db_sess.execute(count_stmt)
750753
total_count = count_result.scalar()
751754

@@ -803,7 +806,10 @@ async def list_artifacts_with_revisions_paginated(
803806
result = await db_sess.execute(querybuild_result.data_query)
804807
rows = result.scalars().all()
805808

809+
# Build count query with same filters applied
806810
count_stmt = sa.select(sa.func.count()).select_from(ArtifactRow)
811+
if filters is not None:
812+
count_stmt = artifact_paginator.filter_applier.apply_filters(count_stmt, filters)
807813
count_result = await db_sess.execute(count_stmt)
808814
total_count = count_result.scalar()
809815

src/ai/backend/manager/repositories/artifact/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ArtifactOrderingOptions:
1818
"""Ordering options for artifact queries."""
1919

2020
order_by: list[tuple[ArtifactOrderField, bool]] = field(
21-
default_factory=lambda: [(ArtifactOrderField.NAME, True)]
21+
default_factory=lambda: [(ArtifactOrderField.NAME, False)]
2222
) # (field, desc)
2323

2424

@@ -27,7 +27,7 @@ class ArtifactRevisionOrderingOptions:
2727
"""Ordering options for artifact revision queries."""
2828

2929
order_by: list[tuple[ArtifactRevisionOrderField, bool]] = field(
30-
default_factory=lambda: [(ArtifactRevisionOrderField.CREATED_AT, True)]
30+
default_factory=lambda: [(ArtifactRevisionOrderField.CREATED_AT, False)]
3131
) # (field, desc)
3232

3333

0 commit comments

Comments
 (0)