Skip to content

Commit 1f8cf0a

Browse files
committed
feat(db): Update Models
- add state machine db model - rearrange shared types and serde helpers - support arbitrary table schema in new db models and tests
1 parent fb887ed commit 1f8cf0a

8 files changed

Lines changed: 139 additions & 42 deletions

File tree

src/lsst/cmservice/common/types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
1+
from typing import Annotated
2+
13
from sqlalchemy.ext.asyncio import AsyncSession as AsyncSessionSA
24
from sqlalchemy.ext.asyncio import async_scoped_session
35
from sqlmodel.ext.asyncio.session import AsyncSession
46

57
from .. import models
8+
from ..models.serde import EnumSerializer, ManifestKindEnumValidator, StatusEnumValidator
9+
from .enums import ManifestKind, StatusEnum
610

711
type AnyAsyncSession = AsyncSession | AsyncSessionSA | async_scoped_session
812
"""A type union of async database sessions the application may use"""
913

1014

1115
type AnyCampaignElement = models.Group | models.Campaign | models.Step | models.Job
1216
"""A type union of Campaign elements"""
17+
18+
19+
type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer]
20+
"""A type for fields representing a Status with a custom validator tuned for
21+
enums operations.
22+
"""
23+
24+
25+
type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer]
26+
"""A type for fields representing a Kind with a custom validator tuned for
27+
enums operations.
28+
"""

src/lsst/cmservice/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,8 +546,8 @@ class DatabaseConfiguration(BaseModel):
546546
description="The password for the cm-service database",
547547
)
548548

549-
table_schema: str | None = Field(
550-
default=None,
549+
table_schema: str = Field(
550+
default="public",
551551
description="Schema to use for cm-service database",
552552
)
553553

src/lsst/cmservice/db/campaigns_v2.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
from datetime import datetime
2-
from typing import Annotated, Any
2+
from typing import Any
33
from uuid import NAMESPACE_DNS, UUID, uuid5
44

5-
from pydantic import AliasChoices, PlainSerializer, PlainValidator, ValidationInfo, model_validator
5+
from pydantic import AliasChoices, ValidationInfo, model_validator
66
from sqlalchemy.dialects import postgresql
7-
from sqlalchemy.ext.mutable import MutableDict
7+
from sqlalchemy.ext.mutable import MutableDict, MutableList
88
from sqlalchemy.types import PickleType
9-
from sqlmodel import Column, Enum, Field, SQLModel, String
9+
from sqlmodel import Column, Enum, Field, MetaData, SQLModel, String
1010

1111
from ..common.enums import ManifestKind, StatusEnum
12+
from ..common.types import KindField, StatusField
1213
from ..config import config
1314

1415
_default_campaign_namespace = uuid5(namespace=NAMESPACE_DNS, name="io.lsst.cmservice")
1516
"""Default UUID5 namespace for campaigns"""
1617

18+
metadata: MetaData = MetaData(schema=config.db.table_schema)
19+
"""SQLModel metadata for table models"""
20+
1721

1822
def jsonb_column(name: str, aliases: list[str] | None = None) -> Any:
1923
"""Constructor for a Field based on a JSONB database column.
@@ -47,34 +51,10 @@ def jsonb_column(name: str, aliases: list[str] | None = None) -> Any:
4751
# 3. the model of the manifest when updating an object
4852
# 4. a response model for APIs related to the object
4953

50-
EnumSerializer = PlainSerializer(
51-
lambda x: x.name,
52-
return_type="str",
53-
when_used="always",
54-
)
55-
"""A serializer for enums that produces its name, not the value."""
56-
57-
58-
StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x))
59-
"""A validator for the StatusEnum that can parse the enum from either a name
60-
or a value.
61-
"""
62-
63-
64-
ManifestKindEnumValidator = PlainValidator(
65-
lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x)
66-
)
67-
"""A validator for the ManifestKindEnum that can parse the enum from a name
68-
or a value.
69-
"""
70-
71-
72-
type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer]
73-
type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer]
74-
7554

7655
class BaseSQLModel(SQLModel):
7756
__table_args__ = {"schema": config.db.table_schema}
57+
metadata = metadata
7858

7959

8060
class CampaignBase(BaseSQLModel):
@@ -116,10 +96,10 @@ class Campaign(CampaignModel, table=True):
11696

11797
__tablename__: str = "campaigns_v2" # type: ignore[misc]
11898

119-
machine: UUID | None
99+
machine: UUID | None = None
120100

121101

122-
class CampaignUpdate(SQLModel):
102+
class CampaignUpdate(BaseSQLModel):
123103
"""Model representing updatable fields for a PATCH operation on a Campaign
124104
using RFC7396.
125105
"""
@@ -168,7 +148,7 @@ def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
168148
class Node(NodeModel, table=True):
169149
__tablename__: str = "nodes_v2" # type: ignore[misc]
170150

171-
machine: UUID | None
151+
machine: UUID | None = None
172152

173153

174154
class EdgeBase(BaseSQLModel):
@@ -215,6 +195,12 @@ class MachineBase(BaseSQLModel):
215195
state: Any | None = Field(sa_column=Column("state", PickleType))
216196

217197

198+
class Machine(MachineBase, table=True):
199+
"""machines_v2 db table."""
200+
201+
__tablename__: str = "machines_v2" # type: ignore[misc]
202+
203+
218204
class ManifestBase(BaseSQLModel):
219205
"""manifests_v2 db table"""
220206

@@ -252,7 +238,7 @@ class Manifest(ManifestBase, table=True):
252238
__tablename__: str = "manifests_v2" # type: ignore[misc]
253239

254240

255-
class Task(SQLModel, table=True):
241+
class Task(BaseSQLModel, table=True):
256242
"""tasks_v2 db table"""
257243

258244
__tablename__: str = "tasks_v2" # type: ignore[misc]

src/lsst/cmservice/db/manifests_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlmodel import Field, SQLModel
55

66
from ..common.enums import ManifestKind
7-
from .campaigns_v2 import EnumSerializer, ManifestKindEnumValidator
7+
from ..models.serde import EnumSerializer, ManifestKindEnumValidator
88

99

1010
# this can probably be a BaseModel since this is not a db relation, but the

src/lsst/cmservice/models/serde.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Module for serialization and deserialization support for pydantic and
2+
other derivative models.
3+
"""
4+
5+
from pydantic import PlainSerializer, PlainValidator
6+
7+
from ..common.enums import ManifestKind, StatusEnum
8+
9+
EnumSerializer = PlainSerializer(
10+
lambda x: x.name,
11+
return_type="str",
12+
when_used="always",
13+
)
14+
"""A serializer for enums that produces its name, not the value."""
15+
16+
17+
StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x))
18+
"""A validator for the StatusEnum that can parse the enum from either a name
19+
or a value.
20+
"""
21+
22+
23+
ManifestKindEnumValidator = PlainValidator(
24+
lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x)
25+
)
26+
"""A validator for the ManifestKindEnum that can parse the enum from a name
27+
or a value.
28+
"""

tests/v2/conftest.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
import os
55
from collections.abc import AsyncGenerator, Generator
66
from typing import TYPE_CHECKING
7+
from uuid import uuid4
78

89
import pytest
910
import pytest_asyncio
1011
from fastapi.testclient import TestClient
1112
from httpx import ASGITransport, AsyncClient
1213
from sqlalchemy.pool import NullPool
13-
from sqlmodel import SQLModel
14+
from sqlalchemy.schema import CreateSchema, DropSchema
1415
from testcontainers.postgres import PostgresContainer
1516

1617
from lsst.cmservice.common.types import AnyAsyncSession
1718
from lsst.cmservice.config import config
19+
from lsst.cmservice.db.campaigns_v2 import metadata
1820
from lsst.cmservice.db.session import DatabaseSessionDependency, db_session_dependency
1921

2022
if TYPE_CHECKING:
@@ -37,9 +39,13 @@ async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[Databa
3739
`TEST__LOCAL_DB` is not set; otherwise the fixture will assume that the
3840
correct database is available to the test environment through ordinary
3941
configuration parameters.
42+
43+
The tests are performed within a random temporary schema that is created
44+
and dropped along with the tables.
4045
"""
4146

4247
monkeypatch_module.setattr(target=config.asgi, name="enable_frontend", value=False)
48+
monkeypatch_module.setattr(target=config.db, name="table_schema", value=uuid4().hex[:8])
4349

4450
if os.getenv("TEST__LOCAL_DB") is not None:
4551
db_session_dependency.pool_class = NullPool
@@ -77,11 +83,15 @@ async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSes
7783
# v2 objects are created from the SQLModel metadata.
7884
assert rawdb.engine is not None
7985
async with rawdb.engine.begin() as aconn:
80-
await aconn.run_sync(SQLModel.metadata.drop_all)
81-
await aconn.run_sync(SQLModel.metadata.create_all)
86+
await aconn.run_sync(metadata.drop_all)
87+
await aconn.execute(CreateSchema(config.db.table_schema, if_not_exists=True))
88+
await aconn.run_sync(metadata.create_all)
89+
await aconn.commit()
8290
yield rawdb
8391
async with rawdb.engine.begin() as aconn:
84-
await aconn.run_sync(SQLModel.metadata.drop_all)
92+
await aconn.run_sync(metadata.drop_all)
93+
await aconn.execute(DropSchema(config.db.table_schema, if_exists=True))
94+
await aconn.commit()
8595

8696

8797
@pytest_asyncio.fixture(name="session", scope="module", loop_scope="module")

tests/v2/test_db.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Tests v2 database operations"""
22

3-
from uuid import uuid5
3+
from uuid import uuid4, uuid5
44

55
import pytest
66
from sqlmodel import select
77

8-
from lsst.cmservice.db.campaigns_v2 import Campaign, _default_campaign_namespace
8+
from lsst.cmservice.db.campaigns_v2 import Campaign, Machine, _default_campaign_namespace
99
from lsst.cmservice.db.session import DatabaseSessionDependency
1010

1111

1212
@pytest.mark.asyncio
1313
async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None:
14+
"""Tests the campaigns_v2 table by creating and updating a Campaign."""
15+
1416
assert testdb.sessionmaker is not None
1517

1618
campaign_name = "test_campaign"
@@ -50,3 +52,26 @@ async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None:
5052
assert campaign.configuration["crtime"] == 0
5153
assert campaign.metadata_["crtime"] == 1750107719
5254
assert campaign.metadata_["mtime"] == 0
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_create_machines_v2(testdb: DatabaseSessionDependency) -> None:
59+
"""Tests the machines_v2 table by storing + retrieving a pickled object."""
60+
61+
assert testdb.sessionmaker is not None
62+
63+
# the machines table is a PickleType so it doesn't really matter for this
64+
# test what kind of object is being pickled.
65+
o = {"a": [1, 2, 3, 4, {"aa": [[0, 1], [2, 3]]}]}
66+
67+
machine_id = uuid4()
68+
machine = Machine(id=machine_id, state=o)
69+
async with testdb.sessionmaker() as session:
70+
session.add(machine)
71+
await session.commit()
72+
73+
async with testdb.sessionmaker() as session:
74+
s = select(Machine).where(Machine.id == machine_id).limit(1)
75+
unpickled = (await session.exec(s)).one()
76+
77+
assert unpickled.state == o

tests/v2/test_manifest_routes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,44 @@ async def test_async_patch_manifest(aclient: AsyncClient) -> None:
217217
assert patched_manifest["metadata"]["owner"] == "bob_loblaw"
218218
assert "owner" not in patched_manifest["spec"]
219219

220+
# Using the "test" operator as a gating function, try but fail to update
221+
# the previously moved owner field
222+
x = await aclient.patch(
223+
f"/cm-service/v2/manifests/{manifest_name}",
224+
headers={"Content-Type": "application/json-patch+json"},
225+
json=[
226+
{
227+
"op": "test",
228+
"path": "/spec/owner",
229+
"value": "bob_loblaw",
230+
},
231+
{
232+
"op": "replace",
233+
"path": "/spec/owner",
234+
"value": "lob_boblaw",
235+
},
236+
{
237+
"op": "add",
238+
"path": "/metadata/scope",
239+
"value": "drp",
240+
},
241+
],
242+
)
243+
assert x.is_client_error
244+
220245
# Get the manifest with multiple versions
221246
# First, make sure when not indicated, the most recent version is returned
247+
# Note: the previous patch with a failing test op must not have created any
248+
# new version.
222249
x = await aclient.get(f"/cm-service/v2/manifests/{manifest_name}")
223250
assert x.is_success
224251
assert x.json()["version"] == 3
225252

253+
# RFC6902 prescribes an all-or-nothing patch operation, so the previous op
254+
# with a failing test assertion must not have otherwise completed, e.g.,
255+
# the addition of a "scope" key to the manifest's metadata
256+
assert "scope" not in x.json().get("metadata")
257+
226258
# Next, get a specific version of the manifest
227259
x = await aclient.get(f"/cm-service/v2/manifests/{manifest_name}?version=2")
228260
assert x.is_success

0 commit comments

Comments
 (0)