Skip to content

Commit d6b721a

Browse files
committed
Implement UnionEle
Deduplicate logic with `BeamLine` for allowed types and (de)serialization.
1 parent 3e76c0a commit d6b721a

File tree

6 files changed

+159
-147
lines changed

6 files changed

+159
-147
lines changed

src/pals/kinds/BeamLine.py

Lines changed: 12 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,27 @@
1-
from pydantic import ConfigDict, Field, model_validator
2-
from typing import Annotated, List, Literal, Union
1+
from pydantic import model_validator
2+
from typing import List, Literal
33

4-
from .BaseElement import BaseElement
5-
from .ThickElement import ThickElement
6-
7-
from .ACKicker import ACKicker
8-
from .BeamBeam import BeamBeam
9-
from .BeginningEle import BeginningEle
10-
from .Converter import Converter
11-
from .CrabCavity import CrabCavity
12-
from .Drift import Drift
13-
from .EGun import EGun
14-
from .Feedback import Feedback
15-
from .Fiducial import Fiducial
16-
from .FloorShift import FloorShift
17-
from .Foil import Foil
18-
from .Fork import Fork
19-
from .Girder import Girder
20-
from .Instrument import Instrument
21-
from .Kicker import Kicker
22-
from .Marker import Marker
23-
from .Mask import Mask
24-
from .Match import Match
25-
from .Multipole import Multipole
26-
from .NullEle import NullEle
27-
from .Octupole import Octupole
28-
from .Patch import Patch
29-
from .Quadrupole import Quadrupole
30-
from .RBend import RBend
31-
from .RFCavity import RFCavity
32-
from .SBend import SBend
33-
from .Sextupole import Sextupole
34-
from .Solenoid import Solenoid
35-
from .Taylor import Taylor
36-
from .UnionEle import UnionEle
37-
from .Wiggler import Wiggler
4+
from .all_elements import get_all_elements_as_annotation
5+
from .mixin import BaseElement
386

397

408
class BeamLine(BaseElement):
419
"""A line of elements and/or other lines"""
4210

43-
# Validate every time a new value is assigned to an attribute,
44-
# not only when an instance of BeamLine is created
45-
model_config = ConfigDict(validate_assignment=True)
46-
4711
kind: Literal["BeamLine"] = "BeamLine"
4812

49-
line: List[
50-
Annotated[
51-
Union[
52-
# Base classes (for testing compatibility)
53-
BaseElement,
54-
ThickElement,
55-
# User-Facing element kinds
56-
"BeamLine",
57-
ACKicker,
58-
BeamBeam,
59-
BeginningEle,
60-
Converter,
61-
CrabCavity,
62-
Drift,
63-
EGun,
64-
Feedback,
65-
Fiducial,
66-
FloorShift,
67-
Foil,
68-
Fork,
69-
Girder,
70-
Instrument,
71-
Kicker,
72-
Marker,
73-
Mask,
74-
Match,
75-
Multipole,
76-
NullEle,
77-
Octupole,
78-
Patch,
79-
Quadrupole,
80-
RBend,
81-
RFCavity,
82-
SBend,
83-
Sextupole,
84-
Solenoid,
85-
Taylor,
86-
UnionEle,
87-
Wiggler,
88-
],
89-
Field(discriminator="kind"),
90-
]
91-
]
13+
line: List[get_all_elements_as_annotation()]
9214

9315
@model_validator(mode="before")
9416
@classmethod
9517
def unpack_yaml_structure(cls, data):
96-
# Handle the top-level one-key dict: unpack the line's name
97-
if isinstance(data, dict) and len(data) == 1:
98-
name, value = list(data.items())[0]
99-
if not isinstance(value, dict):
100-
raise TypeError(
101-
f"Value for line key {name!r} must be a dict, but we got {value!r}"
102-
)
103-
value["name"] = name
104-
data = value
105-
# Handle the 'line' field: unpack each element's name
106-
if "line" not in data:
107-
raise ValueError("'line' field is missing")
108-
if not isinstance(data["line"], list):
109-
raise TypeError("'line' must be a list")
110-
new_line = []
111-
# Loop over all elements in the line
112-
for item in data["line"]:
113-
# An element can be a string that refers to another element
114-
if isinstance(item, str):
115-
raise RuntimeError("Reference/alias elements not yet implemented")
116-
# An element can be a dict
117-
elif isinstance(item, dict):
118-
if not (len(item) == 1):
119-
raise ValueError(
120-
f"Each element must be a dict with exactly one key (the element's name), but we got {item!r}"
121-
)
122-
name, fields = list(item.items())[0]
123-
if not isinstance(fields, dict):
124-
raise TypeError(
125-
f"Value for element key {name!r} must be a dict (the element's properties), but we got {fields!r}"
126-
)
127-
fields["name"] = name
128-
new_line.append(fields)
129-
# An element can be an instance of an existing model
130-
elif isinstance(item, BaseElement):
131-
# Nothing to do, keep the element as is
132-
new_line.append(item)
133-
else:
134-
raise TypeError(
135-
f"Value for element key {name!r} must be a reference string or a dict, but we got {item!r}"
136-
)
137-
data["line"] = new_line
138-
return data
18+
"""Unpack YAML/JSON/...-like structure for BeamLine elements"""
19+
from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure
13920

140-
def model_dump(self, *args, **kwargs):
141-
"""This makes sure the element name property is moved out and up to a one-key dictionary"""
142-
# Use base element dump first and return a dict {key: value}, where 'key'
143-
# is the name of the line and 'value' is a dict with all other properties
144-
data = super().model_dump(*args, **kwargs)
145-
# Reformat 'line' field as list of element dicts
146-
new_line = []
147-
for elem in self.line:
148-
# Use custom dump for each line element, which now returns a dict
149-
elem_dict = elem.model_dump(**kwargs)
150-
new_line.append(elem_dict)
151-
data[self.name]["line"] = new_line
152-
return data
21+
return unpack_element_list_structure(data, "line", "line")
15322

23+
def model_dump(self, *args, **kwargs):
24+
"""Custom model dump for BeamLine to handle element list formatting"""
25+
from pals.kinds.mixin.all_element_mixin import dump_element_list
15426

155-
# Avoid circular import issues
156-
BeamLine.model_rebuild()
27+
return dump_element_list(self, "line", *args, **kwargs)

src/pals/kinds/UnionEle.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import model_validator
1+
from pydantic import model_validator # noqa
22
from typing import List, Literal
33

44
from .all_elements import get_all_elements_as_annotation
@@ -11,6 +11,19 @@ class UnionEle(BaseElement):
1111
# Discriminator field
1212
kind: Literal["UnionEle"] = "UnionEle"
1313

14-
# Elements in the union
15-
# Note: https://github.com/campa-consortium/pals/issues/89
16-
elements: List[BaseElement] = []
14+
# Elements in the union - uses the same union type as BeamLine
15+
elements: List[get_all_elements_as_annotation()] = []
16+
17+
@model_validator(mode="before")
18+
@classmethod
19+
def unpack_yaml_structure(cls, data):
20+
"""Unpack YAML/JSON/...-like structure for UnionEle elements"""
21+
from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure
22+
23+
return unpack_element_list_structure(data, "elements", "union")
24+
25+
def model_dump(self, *args, **kwargs):
26+
"""Custom model dump for UnionEle to handle element list formatting"""
27+
from pals.kinds.mixin.all_element_mixin import dump_element_list
28+
29+
return dump_element_list(self, "elements", *args, **kwargs)

src/pals/kinds/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,8 @@
3434
from .Taylor import Taylor # noqa: F401
3535
from .UnionEle import UnionEle # noqa: F401
3636
from .Wiggler import Wiggler # noqa: F401
37+
38+
39+
# Rebuild pydantic models that depend on other classes
40+
BeamLine.model_rebuild()
41+
UnionEle.model_rebuild()

src/pals/kinds/all_elements.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Helper module to define the union of all allowed element types.
2+
3+
This module creates a helper function that returns the element union type,
4+
avoiding duplication between BeamLine.line and UnionEle.elements.
5+
"""
6+
7+
from typing import Annotated, Union
8+
9+
from pydantic import Field
10+
11+
from .ACKicker import ACKicker
12+
from .BeamBeam import BeamBeam
13+
from .BeginningEle import BeginningEle
14+
from .Converter import Converter
15+
from .CrabCavity import CrabCavity
16+
from .Drift import Drift
17+
from .EGun import EGun
18+
from .Feedback import Feedback
19+
from .Fiducial import Fiducial
20+
from .FloorShift import FloorShift
21+
from .Foil import Foil
22+
from .Fork import Fork
23+
from .Girder import Girder
24+
from .Instrument import Instrument
25+
from .Kicker import Kicker
26+
from .Marker import Marker
27+
from .Mask import Mask
28+
from .Match import Match
29+
from .Multipole import Multipole
30+
from .NullEle import NullEle
31+
from .Octupole import Octupole
32+
from .Patch import Patch
33+
from .Quadrupole import Quadrupole
34+
from .RBend import RBend
35+
from .RFCavity import RFCavity
36+
from .SBend import SBend
37+
from .Sextupole import Sextupole
38+
from .Solenoid import Solenoid
39+
from .Taylor import Taylor
40+
from .Wiggler import Wiggler
41+
42+
43+
def get_all_element_types(extra_types: tuple = None):
44+
"""Return a tuple of all element types that can be used in BeamLine or UnionEle."""
45+
element_types = (
46+
"BeamLine", # Forward reference to handle circular import
47+
"UnionEle", # Forward reference to handle circular import
48+
ACKicker,
49+
BeamBeam,
50+
BeginningEle,
51+
Converter,
52+
CrabCavity,
53+
Drift,
54+
EGun,
55+
Feedback,
56+
Fiducial,
57+
FloorShift,
58+
Foil,
59+
Fork,
60+
Girder,
61+
Instrument,
62+
Kicker,
63+
Marker,
64+
Mask,
65+
Match,
66+
Multipole,
67+
NullEle,
68+
Octupole,
69+
Patch,
70+
Quadrupole,
71+
RBend,
72+
RFCavity,
73+
SBend,
74+
Sextupole,
75+
Solenoid,
76+
Taylor,
77+
Wiggler,
78+
)
79+
if extra_types is not None:
80+
element_types += extra_types
81+
return element_types
82+
83+
84+
def get_all_elements_as_annotation(extra_types: tuple = None):
85+
"""Return the Union type of all allowed elements with their name as the discriminator field."""
86+
types = get_all_element_types(extra_types)
87+
return Annotated[Union[types], Field(discriminator="kind")]

tests/test_elements.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,17 @@ def test_Foil():
466466

467467
def test_UnionEle():
468468
"""Test UnionEle element"""
469+
# Test empty union
469470
element = pals.UnionEle(name="union1", elements=[])
470471
assert element.name == "union1"
471472
assert element.kind == "UnionEle"
472473
assert element.elements == []
474+
475+
# Test union with elements
476+
marker = pals.Marker(name="m1")
477+
drift = pals.Drift(name="d1", length=1.0)
478+
element_with_children = pals.UnionEle(name="union2", elements=[marker, drift])
479+
assert element_with_children.name == "union2"
480+
assert len(element_with_children.elements) == 2
481+
assert element_with_children.elements[0].name == "m1"
482+
assert element_with_children.elements[1].name == "d1"

tests/test_serialization.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ def test_comprehensive_lattice():
178178
# Taylor
179179
taylor = pals.Taylor(name="taylor1")
180180

181-
# UnionEle
182-
unionele = pals.UnionEle(name="unionele1", elements=[])
181+
# UnionEle - with nested elements
182+
union_marker = pals.Marker(name="union_marker")
183+
union_drift = pals.Drift(name="union_drift", length=0.1)
184+
unionele = pals.UnionEle(name="unionele1", elements=[union_marker, union_drift])
183185

184186
# Wiggler
185187
wiggler = pals.Wiggler(name="wiggler1", length=2.0)
@@ -246,6 +248,7 @@ def test_comprehensive_lattice():
246248
octupole_loaded = None
247249
rbend_loaded = None
248250
rfcavity_loaded = None
251+
unionele_loaded = None
249252

250253
for elem in loaded_lattice.line:
251254
if elem.name == "sextupole1":
@@ -256,6 +259,8 @@ def test_comprehensive_lattice():
256259
rbend_loaded = elem
257260
elif elem.name == "rfcavity1":
258261
rfcavity_loaded = elem
262+
elif elem.name == "unionele1":
263+
unionele_loaded = elem
259264

260265
# Test that parameter groups are correctly deserialized
261266
assert sextupole_loaded.MagneticMultipoleP.Bn2 == 1.0
@@ -270,6 +275,15 @@ def test_comprehensive_lattice():
270275
assert rfcavity_loaded.RFP.frequency == 1e9
271276
assert rfcavity_loaded.SolenoidP.Ksol == 0.05
272277

278+
# Test that UnionEle elements are correctly deserialized
279+
assert unionele_loaded is not None
280+
assert len(unionele_loaded.elements) == 2
281+
assert unionele_loaded.elements[0].name == "union_marker"
282+
assert unionele_loaded.elements[0].kind == "Marker"
283+
assert unionele_loaded.elements[1].name == "union_drift"
284+
assert unionele_loaded.elements[1].kind == "Drift"
285+
assert unionele_loaded.elements[1].length == 0.1
286+
273287
# Test serialization to JSON
274288
json_data = json.dumps(lattice.model_dump(), sort_keys=True, indent=2)
275289
print(f"\nComprehensive lattice JSON:\n{json_data}")
@@ -294,6 +308,7 @@ def test_comprehensive_lattice():
294308
octupole_loaded_json = None
295309
rbend_loaded_json = None
296310
rfcavity_loaded_json = None
311+
unionele_loaded_json = None
297312

298313
for elem in loaded_lattice_json.line:
299314
if elem.name == "sextupole1":
@@ -304,6 +319,8 @@ def test_comprehensive_lattice():
304319
rbend_loaded_json = elem
305320
elif elem.name == "rfcavity1":
306321
rfcavity_loaded_json = elem
322+
elif elem.name == "unionele1":
323+
unionele_loaded_json = elem
307324

308325
# Test that parameter groups are correctly deserialized
309326
assert sextupole_loaded_json.MagneticMultipoleP.Bn2 == 1.0
@@ -318,6 +335,15 @@ def test_comprehensive_lattice():
318335
assert rfcavity_loaded_json.RFP.frequency == 1e9
319336
assert rfcavity_loaded_json.SolenoidP.Ksol == 0.05
320337

338+
# Test that UnionEle elements are correctly deserialized from JSON
339+
assert unionele_loaded_json is not None
340+
assert len(unionele_loaded_json.elements) == 2
341+
assert unionele_loaded_json.elements[0].name == "union_marker"
342+
assert unionele_loaded_json.elements[0].kind == "Marker"
343+
assert unionele_loaded_json.elements[1].name == "union_drift"
344+
assert unionele_loaded_json.elements[1].kind == "Drift"
345+
assert unionele_loaded_json.elements[1].length == 0.1
346+
321347
# Clean up temporary files
322348
os.remove(yaml_file)
323349
os.remove(json_file)

0 commit comments

Comments
 (0)