Skip to content

Commit 2759552

Browse files
committed
bugfix(serde): Refactor validation function for enum fields
1 parent d964b81 commit 2759552

2 files changed

Lines changed: 52 additions & 4 deletions

File tree

src/lsst/cmservice/models/serde.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,27 @@
22
other derivative models.
33
"""
44

5+
from enum import EnumType
6+
from functools import partial
7+
58
from pydantic import PlainSerializer, PlainValidator
69

710
from ..common.enums import ManifestKind, StatusEnum
811

12+
13+
def EnumValidator[T: EnumType](value: str | int, enum_: T) -> T:
14+
"""Create an enum from the input value. The input can be either the
15+
enum name or its value.
16+
17+
Used as a Validator for a pydantic field.
18+
"""
19+
try:
20+
new_enum: T = enum_[value] if value in enum_.__members__ else enum_(value)
21+
except (KeyError, ValueError):
22+
raise ValueError(f"Value must be a member of {enum_.__qualname__}")
23+
return new_enum
24+
25+
926
EnumSerializer = PlainSerializer(
1027
lambda x: x.name,
1128
return_type="str",
@@ -14,15 +31,13 @@
1431
"""A serializer for enums that produces its name, not the value."""
1532

1633

17-
StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x))
34+
StatusEnumValidator = PlainValidator(partial(EnumValidator, enum_=StatusEnum))
1835
"""A validator for the StatusEnum that can parse the enum from either a name
1936
or a value.
2037
"""
2138

2239

23-
ManifestKindEnumValidator = PlainValidator(
24-
lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x)
25-
)
40+
ManifestKindEnumValidator = PlainValidator(partial(EnumValidator, enum_=ManifestKind))
2641
"""A validator for the ManifestKindEnum that can parse the enum from a name
2742
or a value.
2843
"""

tests/models/test_serde.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
from pydantic import BaseModel, ValidationError
3+
4+
from lsst.cmservice.common.enums import ManifestKind, StatusEnum
5+
from lsst.cmservice.common.types import KindField, StatusField
6+
7+
8+
class TestModel(BaseModel):
9+
status: StatusField
10+
kind: KindField
11+
12+
13+
def test_validators() -> None:
14+
"""Test model field enum validators."""
15+
# test enum validation by name and value
16+
x = TestModel(status=0, kind="campaign")
17+
assert x.status is StatusEnum.waiting
18+
assert x.kind is ManifestKind.campaign
19+
20+
# test bad input (wrong name)
21+
with pytest.raises(ValidationError):
22+
x = TestModel(status="bad", kind="edge")
23+
24+
# test bad input (bad value)
25+
with pytest.raises(ValidationError):
26+
x = TestModel(status="waiting", kind=99)
27+
28+
29+
def test_serializers() -> None:
30+
x = TestModel(status="accepted", kind="node")
31+
y = x.model_dump()
32+
assert y["status"] == "accepted"
33+
assert y["kind"] == "node"

0 commit comments

Comments
 (0)