Skip to content

Commit 9102836

Browse files
committed
feat(db): Update Models
- add state machine db model - rearrange shared types and serde helpers - support arbitrary table schema in new db models and tests - Refactor validation function for enum fields - Add FK relations to model fields
1 parent 2b0c852 commit 9102836

9 files changed

Lines changed: 195 additions & 53 deletions

File tree

src/lsst/cmservice/common/types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
1+
from typing import Annotated
2+
13
from sqlalchemy.ext.asyncio import AsyncSession as AsyncSessionSA
24
from sqlalchemy.ext.asyncio import async_scoped_session
35
from sqlmodel.ext.asyncio.session import AsyncSession
46

57
from .. import models
8+
from ..models.serde import EnumSerializer, ManifestKindEnumValidator, StatusEnumValidator
9+
from .enums import ManifestKind, StatusEnum
610

711
type AnyAsyncSession = AsyncSession | AsyncSessionSA | async_scoped_session
812
"""A type union of async database sessions the application may use"""
913

1014

1115
type AnyCampaignElement = models.Group | models.Campaign | models.Step | models.Job
1216
"""A type union of Campaign elements"""
17+
18+
19+
type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer]
20+
"""A type for fields representing a Status with a custom validator tuned for
21+
enums operations.
22+
"""
23+
24+
25+
type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer]
26+
"""A type for fields representing a Kind with a custom validator tuned for
27+
enums operations.
28+
"""

src/lsst/cmservice/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,8 @@ class DatabaseConfiguration(BaseModel):
551551
description="The password for the cm-service database",
552552
)
553553

554-
table_schema: str | None = Field(
555-
default=None,
554+
table_schema: str = Field(
555+
default="public",
556556
description="Schema to use for cm-service database",
557557
)
558558

src/lsst/cmservice/db/campaigns_v2.py

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
from datetime import datetime
2-
from typing import Annotated, Any
2+
from typing import Any
33
from uuid import NAMESPACE_DNS, UUID, uuid5
44

5-
from pydantic import AliasChoices, PlainSerializer, PlainValidator, ValidationInfo, model_validator
5+
from pydantic import AliasChoices, ValidationInfo, model_validator
66
from sqlalchemy.dialects import postgresql
77
from sqlalchemy.ext.mutable import MutableDict, MutableList
88
from sqlalchemy.types import PickleType
9-
from sqlmodel import Column, Enum, Field, SQLModel, String
9+
from sqlmodel import Column, Enum, Field, MetaData, SQLModel, String
1010

1111
from ..common.enums import ManifestKind, StatusEnum
12+
from ..common.types import KindField, StatusField
1213
from ..config import config
1314

1415
_default_campaign_namespace = uuid5(namespace=NAMESPACE_DNS, name="io.lsst.cmservice")
1516
"""Default UUID5 namespace for campaigns"""
1617

18+
metadata: MetaData = MetaData(schema=config.db.table_schema)
19+
"""SQLModel metadata for table models"""
20+
1721

1822
def jsonb_column(name: str, aliases: list[str] | None = None) -> Any:
1923
"""Constructor for a Field based on a JSONB database column.
@@ -47,34 +51,10 @@ def jsonb_column(name: str, aliases: list[str] | None = None) -> Any:
4751
# 3. the model of the manifest when updating an object
4852
# 4. a response model for APIs related to the object
4953

50-
EnumSerializer = PlainSerializer(
51-
lambda x: x.name,
52-
return_type="str",
53-
when_used="always",
54-
)
55-
"""A serializer for enums that produces its name, not the value."""
56-
57-
58-
StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x))
59-
"""A validator for the StatusEnum that can parse the enum from either a name
60-
or a value.
61-
"""
62-
63-
64-
ManifestKindEnumValidator = PlainValidator(
65-
lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x)
66-
)
67-
"""A validator for the ManifestKindEnum that can parse the enum from a name
68-
or a value.
69-
"""
70-
71-
72-
type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer]
73-
type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer]
74-
7554

7655
class BaseSQLModel(SQLModel):
7756
__table_args__ = {"schema": config.db.table_schema}
57+
metadata = metadata
7858

7959

8060
class CampaignBase(BaseSQLModel):
@@ -116,10 +96,10 @@ class Campaign(CampaignModel, table=True):
11696

11797
__tablename__: str = "campaigns_v2" # type: ignore[misc]
11898

119-
machine: UUID | None
99+
machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE")
120100

121101

122-
class CampaignUpdate(SQLModel):
102+
class CampaignUpdate(BaseSQLModel):
123103
"""Model representing updatable fields for a PATCH operation on a Campaign
124104
using RFC7396.
125105
"""
@@ -168,7 +148,7 @@ def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
168148
class Node(NodeModel, table=True):
169149
__tablename__: str = "nodes_v2" # type: ignore[misc]
170150

171-
machine: UUID | None
151+
machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE")
172152

173153

174154
class EdgeBase(BaseSQLModel):
@@ -215,13 +195,19 @@ class MachineBase(BaseSQLModel):
215195
state: Any | None = Field(sa_column=Column("state", PickleType))
216196

217197

198+
class Machine(MachineBase, table=True):
199+
"""machines_v2 db table."""
200+
201+
__tablename__: str = "machines_v2" # type: ignore[misc]
202+
203+
218204
class ManifestBase(BaseSQLModel):
219205
"""manifests_v2 db table"""
220206

221207
id: UUID = Field(primary_key=True)
222208
name: str
223209
version: int
224-
namespace: UUID
210+
namespace: UUID = Field(foreign_key="campaigns_v2.id")
225211
kind: KindField = Field(
226212
default=ManifestKind.other,
227213
sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)),
@@ -252,14 +238,14 @@ class Manifest(ManifestBase, table=True):
252238
__tablename__: str = "manifests_v2" # type: ignore[misc]
253239

254240

255-
class Task(SQLModel, table=True):
241+
class Task(BaseSQLModel, table=True):
256242
"""tasks_v2 db table"""
257243

258244
__tablename__: str = "tasks_v2" # type: ignore[misc]
259245

260246
id: UUID = Field(primary_key=True)
261-
namespace: UUID
262-
node: UUID
247+
namespace: UUID = Field(foreign_key="campaigns_v2.id")
248+
node: UUID = Field(foreign_key="nodes_v2.id")
263249
priority: int
264250
created_at: datetime
265251
last_processed_at: datetime
@@ -280,8 +266,8 @@ class Task(SQLModel, table=True):
280266

281267
class ActivityLogBase(BaseSQLModel):
282268
id: UUID = Field(primary_key=True)
283-
namespace: UUID
284-
node: UUID
269+
namespace: UUID = Field(foreign_key="campaigns_v2.id")
270+
node: UUID = Field(foreign_key="nodes_v2.id")
285271
operator: str
286272
to_status: StatusField = Field(
287273
sa_column=Column(

src/lsst/cmservice/db/manifests_v2.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
1-
from typing import Annotated
1+
"""Module for models representing generic CM Service manifests."""
22

33
from pydantic import AliasChoices
44
from sqlmodel import Field, SQLModel
55

66
from ..common.enums import ManifestKind
7-
from .campaigns_v2 import EnumSerializer, ManifestKindEnumValidator
7+
from ..common.types import KindField
88

99

10-
# this can probably be a BaseModel since this is not a db relation, but the
11-
# distinction probably doesn't matter
1210
class ManifestWrapper(SQLModel):
13-
"""a model for an object's Manifest wrapper, used by APIs where the `spec`
11+
"""A model for an object's Manifest wrapper, used by APIs where the `spec`
1412
should be the kind's table model, more or less.
1513
"""
1614

1715
apiversion: str = Field(default="io.lsst.cmservice/v1")
18-
kind: Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] = Field(
19-
default=ManifestKind.other,
20-
)
16+
kind: KindField = Field(default=ManifestKind.other)
2117
metadata_: dict = Field(
2218
default_factory=dict,
2319
schema_extra={

src/lsst/cmservice/models/serde.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Module for serialization and deserialization support for pydantic and
2+
other derivative models.
3+
"""
4+
5+
from enum import EnumType
6+
from functools import partial
7+
from typing import Any
8+
9+
from pydantic import PlainSerializer, PlainValidator
10+
11+
from ..common.enums import ManifestKind, StatusEnum
12+
13+
14+
def EnumValidator[T: EnumType](value: Any, enum_: T) -> T:
15+
"""Create an enum from the input value. The input can be either the
16+
enum name or its value.
17+
18+
Used as a Validator for a pydantic field.
19+
"""
20+
try:
21+
new_enum: T = enum_[value] if value in enum_.__members__ else enum_(value)
22+
except (KeyError, ValueError):
23+
raise ValueError(f"Value must be a member of {enum_.__qualname__}")
24+
return new_enum
25+
26+
27+
EnumSerializer = PlainSerializer(
28+
lambda x: x.name,
29+
return_type="str",
30+
when_used="always",
31+
)
32+
"""A serializer for enums that produces its name, not the value."""
33+
34+
35+
StatusEnumValidator = PlainValidator(partial(EnumValidator, enum_=StatusEnum))
36+
"""A validator for the StatusEnum that can parse the enum from either a name
37+
or a value.
38+
"""
39+
40+
41+
ManifestKindEnumValidator = PlainValidator(partial(EnumValidator, enum_=ManifestKind))
42+
"""A validator for the ManifestKindEnum that can parse the enum from a name
43+
or a value.
44+
"""

tests/models/test_serde.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
from pydantic import BaseModel, ValidationError
3+
4+
from lsst.cmservice.common.enums import ManifestKind, StatusEnum
5+
from lsst.cmservice.common.types import KindField, StatusField
6+
7+
8+
class TestModel(BaseModel):
9+
status: StatusField
10+
kind: KindField
11+
12+
13+
def test_validators() -> None:
14+
"""Test model field enum validators."""
15+
# test enum validation by name and value
16+
x = TestModel(status=0, kind="campaign")
17+
assert x.status is StatusEnum.waiting
18+
assert x.kind is ManifestKind.campaign
19+
20+
# test bad input (wrong name)
21+
with pytest.raises(ValidationError):
22+
x = TestModel(status="bad", kind="edge")
23+
24+
# test bad input (bad value)
25+
with pytest.raises(ValidationError):
26+
x = TestModel(status="waiting", kind=99)
27+
28+
29+
def test_serializers() -> None:
30+
x = TestModel(status="accepted", kind="node")
31+
y = x.model_dump()
32+
assert y["status"] == "accepted"
33+
assert y["kind"] == "node"

tests/v2/conftest.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
import os
55
from collections.abc import AsyncGenerator, Generator
66
from typing import TYPE_CHECKING
7+
from uuid import uuid4
78

89
import pytest
910
import pytest_asyncio
1011
from fastapi.testclient import TestClient
1112
from httpx import ASGITransport, AsyncClient
1213
from sqlalchemy.pool import NullPool
13-
from sqlmodel import SQLModel
14+
from sqlalchemy.schema import CreateSchema, DropSchema
1415
from testcontainers.postgres import PostgresContainer
1516

1617
from lsst.cmservice.common.types import AnyAsyncSession
1718
from lsst.cmservice.config import config
19+
from lsst.cmservice.db.campaigns_v2 import metadata
1820
from lsst.cmservice.db.session import DatabaseSessionDependency, db_session_dependency
1921

2022
if TYPE_CHECKING:
@@ -37,9 +39,13 @@ async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[Databa
3739
`TEST__LOCAL_DB` is not set; otherwise the fixture will assume that the
3840
correct database is available to the test environment through ordinary
3941
configuration parameters.
42+
43+
The tests are performed within a random temporary schema that is created
44+
and dropped along with the tables.
4045
"""
4146

4247
monkeypatch_module.setattr(target=config.asgi, name="enable_frontend", value=False)
48+
monkeypatch_module.setattr(target=config.db, name="table_schema", value=uuid4().hex[:8])
4349

4450
if os.getenv("TEST__LOCAL_DB") is not None:
4551
db_session_dependency.pool_class = NullPool
@@ -77,11 +83,15 @@ async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSes
7783
# v2 objects are created from the SQLModel metadata.
7884
assert rawdb.engine is not None
7985
async with rawdb.engine.begin() as aconn:
80-
await aconn.run_sync(SQLModel.metadata.drop_all)
81-
await aconn.run_sync(SQLModel.metadata.create_all)
86+
await aconn.run_sync(metadata.drop_all)
87+
await aconn.execute(CreateSchema(config.db.table_schema, if_not_exists=True))
88+
await aconn.run_sync(metadata.create_all)
89+
await aconn.commit()
8290
yield rawdb
8391
async with rawdb.engine.begin() as aconn:
84-
await aconn.run_sync(SQLModel.metadata.drop_all)
92+
await aconn.run_sync(metadata.drop_all)
93+
await aconn.execute(DropSchema(config.db.table_schema, if_exists=True))
94+
await aconn.commit()
8595

8696

8797
@pytest_asyncio.fixture(name="session", scope="module", loop_scope="module")

tests/v2/test_db.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Tests v2 database operations"""
22

3-
from uuid import uuid5
3+
from uuid import uuid4, uuid5
44

55
import pytest
66
from sqlmodel import select
77

8-
from lsst.cmservice.db.campaigns_v2 import Campaign, _default_campaign_namespace
8+
from lsst.cmservice.db.campaigns_v2 import Campaign, Machine, _default_campaign_namespace
99
from lsst.cmservice.db.session import DatabaseSessionDependency
1010

1111

1212
@pytest.mark.asyncio
1313
async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None:
14+
"""Tests the campaigns_v2 table by creating and updating a Campaign."""
15+
1416
assert testdb.sessionmaker is not None
1517

1618
campaign_name = "test_campaign"
@@ -50,3 +52,26 @@ async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None:
5052
assert campaign.configuration["crtime"] == 0
5153
assert campaign.metadata_["crtime"] == 1750107719
5254
assert campaign.metadata_["mtime"] == 0
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_create_machines_v2(testdb: DatabaseSessionDependency) -> None:
59+
"""Tests the machines_v2 table by storing + retrieving a pickled object."""
60+
61+
assert testdb.sessionmaker is not None
62+
63+
# the machines table is a PickleType so it doesn't really matter for this
64+
# test what kind of object is being pickled.
65+
o = {"a": [1, 2, 3, 4, {"aa": [[0, 1], [2, 3]]}]}
66+
67+
machine_id = uuid4()
68+
machine = Machine(id=machine_id, state=o)
69+
async with testdb.sessionmaker() as session:
70+
session.add(machine)
71+
await session.commit()
72+
73+
async with testdb.sessionmaker() as session:
74+
s = select(Machine).where(Machine.id == machine_id).limit(1)
75+
unpickled = (await session.exec(s)).one()
76+
77+
assert unpickled.state == o

0 commit comments

Comments
 (0)