Skip to content

Commit c028de5

Browse files
committed
Fix line validation
1 parent c391e43 commit c028de5

File tree

3 files changed

+67
-44
lines changed

3 files changed

+67
-44
lines changed

examples/fodo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def main():
6464
with open(yaml_file, "r") as file:
6565
yaml_data = yaml.safe_load(file)
6666
# Parse YAML data
67-
loaded_line = Line(**yaml_data)
67+
loaded_line = Line(**yaml_data[0])
6868
# Validate loaded data
6969
assert line == loaded_line
7070
# Serialize to JSON
@@ -79,7 +79,7 @@ def main():
7979
with open(json_file, "r") as file:
8080
json_data = json.loads(file.read())
8181
# Parse JSON data
82-
loaded_line = Line(**json_data)
82+
loaded_line = Line(**json_data[0])
8383
# Validate loaded data
8484
assert line == loaded_line
8585

schema/Line.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import BaseModel, ConfigDict, Field, field_validator
1+
from pydantic import ConfigDict, Field, model_validator
22
from typing import Annotated, List, Literal, Union
33

44
from schema.BaseElement import BaseElement
@@ -29,47 +29,70 @@ class Line(BaseElement):
2929
]
3030
]
3131

32-
@field_validator("line", mode="before")
32+
# @field_validator("line", mode="before")
33+
# @classmethod
34+
# def parse_list_of_dicts(cls, value):
35+
# """This method inserts the key of the one-key dictionary into
36+
# the name attribute of the elements"""
37+
# if not isinstance(value, list):
38+
# raise TypeError("line must be a list")
39+
40+
# if value and isinstance(value[0], BaseModel):
41+
# # Already a list of models; nothing to do
42+
# return value
43+
44+
# # we expect a list of dicts or strings
45+
# elements = []
46+
# for item_dict in value:
47+
# # an element is either a reference string to another element or a dict
48+
# if isinstance(item_dict, str):
49+
# raise RuntimeError("Reference/alias elements not yet implemented")
50+
51+
# elif isinstance(item_dict, dict):
52+
# if not (isinstance(item_dict, dict) and len(item_dict) == 1):
53+
# raise ValueError(
54+
# f"Each line element must be a dict with exactly one key, the name of the element, but we got: {item_dict!r}"
55+
# )
56+
# [(name, fields)] = item_dict.items()
57+
58+
# if not isinstance(fields, dict):
59+
# raise ValueError(
60+
# f"Value for element key '{name}' must be a dict (got {fields!r})"
61+
# )
62+
63+
# # Insert the name into the fields dict
64+
# fields["name"] = name
65+
# elements.append(fields)
66+
# return elements
67+
68+
@model_validator(mode="before")
3369
@classmethod
34-
def parse_list_of_dicts(cls, value):
35-
"""This method inserts the key of the one-key dictionary into
36-
the name attribute of the elements"""
37-
if not isinstance(value, list):
38-
raise TypeError("line must be a list")
39-
40-
if value and isinstance(value[0], BaseModel):
41-
# Already a list of models; nothing to do
42-
return value
43-
44-
# we expect a list of dicts or strings
45-
elements = []
46-
for item_dict in value:
47-
# an element is either a reference string to another element or a dict
48-
if isinstance(item_dict, str):
49-
raise RuntimeError("Reference/alias elements not yet implemented")
50-
51-
elif isinstance(item_dict, dict):
52-
if not (isinstance(item_dict, dict) and len(item_dict) == 1):
53-
raise ValueError(
54-
f"Each line element must be a dict with exactly one key, the name of the element, but we got: {item_dict!r}"
55-
)
56-
[(name, fields)] = item_dict.items()
57-
58-
if not isinstance(fields, dict):
59-
raise ValueError(
60-
f"Value for element key '{name}' must be a dict (got {fields!r})"
61-
)
62-
63-
# Insert the name into the fields dict
64-
fields["name"] = name
65-
elements.append(fields)
66-
return elements
70+
def unpack_yaml_structure(cls, data):
71+
# Handle the top-level dict
72+
if isinstance(data, dict) and len(data) == 1:
73+
name, value = list(data.items())[0]
74+
value = dict(value)
75+
value["name"] = name
76+
data = value
77+
# Now handle the 'line' field if present: unpack each element's name
78+
if "line" in data and isinstance(data["line"], list):
79+
new_line = []
80+
for item in data["line"]:
81+
if isinstance(item, dict) and len(item) == 1:
82+
elem_name, elem_fields = list(item.items())[0]
83+
elem_fields = dict(elem_fields)
84+
elem_fields["name"] = elem_name
85+
new_line.append(elem_fields)
86+
else:
87+
new_line.append(item)
88+
data["line"] = new_line
89+
return data
6790

6891
def model_dump(self, *args, **kwargs):
6992
"""This makes sure the element name property is moved out and up to a one-key dictionary"""
7093
# Use base element dump first and return a one-element list of the form
7194
# [{key: value}], where 'key' is the name of the line and 'value' is a
72-
# dictionary with all other properties
95+
# dict with all other properties
7396
data = super().model_dump(*args, **kwargs)
7497
# Reformat 'line' field as list of single-key dicts
7598
new_line = []

tests/test_schema.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ def test_QuadrupoleElement():
104104
def test_Line():
105105
# Create first line with one base element
106106
element1 = BaseElement(name="element1")
107-
line1 = Line(line=[element1])
107+
line1 = Line(name="line1", line=[element1])
108108
assert line1.line == [element1]
109109
# Extend first line with one thick element
110110
element2 = ThickElement(name="element2", length=2.0)
111111
line1.line.extend([element2])
112112
assert line1.line == [element1, element2]
113113
# Create second line with one drift element
114114
element3 = DriftElement(name="element3", length=3.0)
115-
line2 = Line(line=[element3])
115+
line2 = Line(name="line2", line=[element3])
116116
# Extend first line with second line
117117
line1.line.extend(line2.line)
118118
assert line1.line == [element1, element2, element3]
@@ -124,7 +124,7 @@ def test_yaml():
124124
# Create one thick element
125125
element2 = ThickElement(name="element2", length=2.0)
126126
# Create line with both elements
127-
line = Line(line=[element1, element2])
127+
line = Line(name="line", line=[element1, element2])
128128
# Serialize the Line object to YAML
129129
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
130130
print(f"\n{yaml_data}")
@@ -136,7 +136,7 @@ def test_yaml():
136136
with open(test_file, "r") as file:
137137
yaml_data = yaml.safe_load(file)
138138
# Parse the YAML data back into a Line object
139-
loaded_line = Line(**yaml_data)
139+
loaded_line = Line(**yaml_data[0])
140140
# Remove the test file
141141
os.remove(test_file)
142142
# Validate loaded Line object
@@ -149,7 +149,7 @@ def test_json():
149149
# Create one thick element
150150
element2 = ThickElement(name="element2", length=2.0)
151151
# Create line with both elements
152-
line = Line(line=[element1, element2])
152+
line = Line(name="line", line=[element1, element2])
153153
# Serialize the Line object to JSON
154154
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
155155
print(f"\n{json_data}")
@@ -161,7 +161,7 @@ def test_json():
161161
with open(test_file, "r") as file:
162162
json_data = json.loads(file.read())
163163
# Parse the JSON data back into a Line object
164-
loaded_line = Line(**json_data)
164+
loaded_line = Line(**json_data[0])
165165
# Remove the test file
166166
os.remove(test_file)
167167
# Validate loaded Line object

0 commit comments

Comments
 (0)