Skip to content

Commit 4cfe1b3

Browse files
committed
bugfix(serde): Refactor validation function for enum fields
1 parent 1f8cf0a commit 4cfe1b3

3 files changed

Lines changed: 56 additions & 12 deletions

File tree

src/lsst/cmservice/db/manifests_v2.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
1-
from typing import Annotated
1+
"""Module for models representing generic CM Service manifests."""
22

33
from pydantic import AliasChoices
44
from sqlmodel import Field, SQLModel
55

66
from ..common.enums import ManifestKind
7-
from ..models.serde import EnumSerializer, ManifestKindEnumValidator
7+
from ..common.types import KindField
88

99

10-
# this can probably be a BaseModel since this is not a db relation, but the
11-
# distinction probably doesn't matter
1210
class ManifestWrapper(SQLModel):
13-
"""a model for an object's Manifest wrapper, used by APIs where the `spec`
11+
"""A model for an object's Manifest wrapper, used by APIs where the `spec`
1412
should be the kind's table model, more or less.
1513
"""
1614

1715
apiversion: str = Field(default="io.lsst.cmservice/v1")
18-
kind: Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] = Field(
19-
default=ManifestKind.other,
20-
)
16+
kind: KindField = Field(default=ManifestKind.other)
2117
metadata_: dict = Field(
2218
default_factory=dict,
2319
schema_extra={

src/lsst/cmservice/models/serde.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,27 @@
22
other derivative models.
33
"""
44

5+
from enum import EnumType
6+
from functools import partial
7+
58
from pydantic import PlainSerializer, PlainValidator
69

710
from ..common.enums import ManifestKind, StatusEnum
811

12+
13+
def EnumValidator[T: EnumType](value: str | int, enum_: T) -> T:
14+
"""Create an enum from the input value. The input can be either the
15+
enum name or its value.
16+
17+
Used as a Validator for a pydantic field.
18+
"""
19+
try:
20+
new_enum: T = enum_[value] if value in enum_.__members__ else enum_(value)
21+
except (KeyError, ValueError):
22+
raise ValueError(f"Value must be a member of {enum_.__qualname__}")
23+
return new_enum
24+
25+
926
EnumSerializer = PlainSerializer(
1027
lambda x: x.name,
1128
return_type="str",
@@ -14,15 +31,13 @@
1431
"""A serializer for enums that produces its name, not the value."""
1532

1633

17-
StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x))
34+
StatusEnumValidator = PlainValidator(partial(EnumValidator, enum_=StatusEnum))
1835
"""A validator for the StatusEnum that can parse the enum from either a name
1936
or a value.
2037
"""
2138

2239

23-
ManifestKindEnumValidator = PlainValidator(
24-
lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x)
25-
)
40+
ManifestKindEnumValidator = PlainValidator(partial(EnumValidator, enum_=ManifestKind))
2641
"""A validator for the ManifestKindEnum that can parse the enum from a name
2742
or a value.
2843
"""

tests/models/test_serde.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
from pydantic import BaseModel, ValidationError
3+
4+
from lsst.cmservice.common.enums import ManifestKind, StatusEnum
5+
from lsst.cmservice.common.types import KindField, StatusField
6+
7+
8+
class TestModel(BaseModel):
9+
status: StatusField
10+
kind: KindField
11+
12+
13+
def test_validators() -> None:
14+
"""Test model field enum validators."""
15+
# test enum validation by name and value
16+
x = TestModel(status=0, kind="campaign")
17+
assert x.status is StatusEnum.waiting
18+
assert x.kind is ManifestKind.campaign
19+
20+
# test bad input (wrong name)
21+
with pytest.raises(ValidationError):
22+
x = TestModel(status="bad", kind="edge")
23+
24+
# test bad input (bad value)
25+
with pytest.raises(ValidationError):
26+
x = TestModel(status="waiting", kind=99)
27+
28+
29+
def test_serializers() -> None:
30+
x = TestModel(status="accepted", kind="node")
31+
y = x.model_dump()
32+
assert y["status"] == "accepted"
33+
assert y["kind"] == "node"

0 commit comments

Comments
 (0)