Skip to content

Commit 162d4df

Browse files
committed
More Parameter Validations
And more limits on ints and string literal values.
1 parent e5c7ac3 commit 162d4df

File tree

8 files changed

+55
-29
lines changed

8 files changed

+55
-29
lines changed
Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
from typing import List
2-
from pydantic import BaseModel
1+
from typing import Literal
2+
from pydantic import BaseModel, Field
33

44

55
class ApertureParameters(BaseModel):
66
"""Aperture parameters"""
77

8-
x_limits: List[float] = [float("nan"), float("nan")]
9-
y_limits: List[float] = [float("nan"), float("nan")]
10-
shape: str = "RECTANGULAR"
11-
location: str = "ENTRANCE_END"
8+
x_limits: list[float | None, float | None] = Field(
9+
default=[None, None],
10+
validate_args=lambda x: (x[0] is None or x[1] is None or x[0] < x[1]),
11+
)
12+
y_limits: list[float | None, float | None] = Field(
13+
default=[None, None],
14+
validate_args=lambda x: (x[0] is None or x[1] is None or x[0] < x[1]),
15+
)
16+
shape: Literal["RECTANGULAR", "ELLIPTICAL", "VERTICES", "CUSTOM_SHAPE"] = (
17+
"RECTANGULAR"
18+
)
19+
location: Literal[
20+
"ENTRANCE_END", "CENTER", "EXIT_END", "BOTH_ENDS", "NOWHERE", "EVERYWHERE"
21+
] = "ENTRANCE_END"
1222
material: str = ""
13-
thickness: float = 0.0
23+
thickness: float = Field(default=0.0, ge=0.0)
1424
aperture_shifts_with_body: bool = False
1525
aperture_active: bool = True

src/pals/parameters/ElectricMultipoleParameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
class ElectricMultipoleParameters(BaseModel):
55
"""Electric multipole parameters"""
66

7-
# Allow arbitrary fields
7+
# Allow arbitrary fields (TODO: remove this)
88
model_config = ConfigDict(extra="allow")
99

1010
# TODO: add ElectricMultipoleParameters in a follow-up RP
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Literal
12
from pydantic import BaseModel
23

34

@@ -6,5 +7,5 @@ class ForkParameters(BaseModel):
67

78
to_line: str = ""
89
to_ele: str = ""
9-
direction: str = "FORWARDS" # "FORWARDS" or "BACKWARDS"
10+
direction: Literal["FORWARDS", "BACKWARDS"] = "FORWARDS"
1011
propagate_reference: bool = True

src/pals/parameters/PatchParameters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Literal
12
from pydantic import BaseModel
23

34

@@ -11,5 +12,5 @@ class PatchParameters(BaseModel):
1112
y_rot: float = 0.0
1213
z_rot: float = 0.0
1314
flexible: bool = False
14-
ref_coords: str = "exit_end" # "entrance_end" or "exit_end"
15+
ref_coords: Literal["entrance_end", "exit_end"] = "exit_end"
1516
user_sets_length: bool = False

src/pals/parameters/RFParameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import BaseModel
1+
from pydantic import BaseModel, Field
22

33

44
class RFParameters(BaseModel):
@@ -11,4 +11,4 @@ class RFParameters(BaseModel):
1111
phase: float = 0.0 # [unitless] RF phase in 0 to 2*pi
1212
multipass_phase: float = 0.0 # [unitless] RF Phase added to multipass elements
1313
cavity_type: str = "STANDING_WAVE" # [string] Cavity type
14-
n_cell: int = 1 # [unitless] Number of cavity cells
14+
n_cell: int = Field(default=1, gt=0) # [unitless] Number of cavity cells

src/pals/parameters/ReferenceParameters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Literal
12
from pydantic import BaseModel
23

34

@@ -9,3 +10,6 @@ class ReferenceParameters(BaseModel):
910
E_tot_ref: float = 0.0 # [eV] Reference total energy
1011
time_ref: float = 0.0 # [s] Reference time
1112
location: str = "" # Where reference parameters are evaluated
13+
location: Literal[
14+
"UPSTREAM_END", "DOWNSTREAM_END"
15+
] # TODO: undefined default in PALS?

tests/test_elements.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from pydantic import ValidationError
23

34
import pals
@@ -23,13 +24,8 @@ def test_ThickElement():
2324
# Try to assign negative length and
2425
# detect validation error without breaking pytest
2526
element_length = -1.0
26-
passed = True
27-
try:
27+
with pytest.raises(ValidationError):
2828
element.length = element_length
29-
except ValidationError as e:
30-
print(e)
31-
passed = False
32-
assert not passed
3329

3430

3531
def test_Drift():
@@ -45,13 +41,8 @@ def test_Drift():
4541
# Try to assign negative length and
4642
# detect validation error without breaking pytest
4743
element_length = -1.0
48-
passed = True
49-
try:
44+
with pytest.raises(ValidationError):
5045
element.length = element_length
51-
except ValidationError as e:
52-
print(e)
53-
passed = False
54-
assert not passed
5546

5647

5748
def test_Quadrupole():

tests/test_parameters.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
14
from pals import (
25
ApertureParameters,
36
BeamBeamParameters,
47
BendParameters,
58
BodyShiftParameters,
6-
ElectricMultipoleParameters,
7-
# FloorParameters, # not yet tested
89
FloorShiftParameters,
910
ForkParameters,
1011
MagneticMultipoleParameters,
@@ -24,6 +25,11 @@ def test_ParameterClasses():
2425
aperture = ApertureParameters(x_limits=[-0.1, 0.1], y_limits=[-0.05, 0.05])
2526
assert aperture.x_limits == [-0.1, 0.1]
2627

28+
with pytest.raises(ValidationError):
29+
_ = ApertureParameters(
30+
x_limits=[-0.1, 0.1], y_limits=[-0.05, 0.05, 0.1], shape="wrong"
31+
)
32+
2733
# Test BodyShiftParameters
2834
body_shift = BodyShiftParameters(x_offset=0.01, y_rot=0.02)
2935
assert body_shift.x_offset == 0.01
@@ -32,15 +38,23 @@ def test_ParameterClasses():
3238
meta = MetaParameters(alias="test", description="test element")
3339
assert meta.alias == "test"
3440

35-
# Test ElectricMultipoleParameters
36-
emp = ElectricMultipoleParameters(En1=1.0, Es1=0.5)
37-
assert emp.En1 == 1.0
41+
# Test ElectricMultipoleParameters (TODO)
42+
# emp = ElectricMultipoleParameters(En1=1.0, Es1=0.5)
43+
# assert emp.En1 == 1.0
3844

3945
# Test MagneticMultipoleParameters
4046
mmp = MagneticMultipoleParameters(Bn1=1.0, Bs1=0.5)
4147
assert mmp.Bn1 == 1.0
4248
assert mmp.Bs1 == 0.5
4349

50+
# catch typos
51+
with pytest.raises(ValidationError):
52+
_ = MagneticMultipoleParameters(Bm1=1.0, Bs1=0.5)
53+
with pytest.raises(ValidationError):
54+
_ = MagneticMultipoleParameters(Bn1=1.0, Bv1=0.5)
55+
with pytest.raises(ValidationError):
56+
_ = MagneticMultipoleParameters(Bn01=1.0, Bs01=0.5)
57+
4458
# Test SolenoidParameters
4559
sol = SolenoidParameters(Ksol=0.1, Bsol=0.2)
4660
assert sol.Ksol == 0.1
@@ -49,6 +63,11 @@ def test_ParameterClasses():
4963
rf = RFParameters(frequency=1e9, voltage=1e6)
5064
assert rf.frequency == 1e9
5165

66+
with pytest.raises(ValidationError):
67+
_ = RFParameters(frequency=1e9, voltage=1e6, n_cell=0)
68+
with pytest.raises(ValidationError):
69+
_ = RFParameters(frequency=1e9, voltage=1e6, n_cell=-1)
70+
5271
# Test BendParameters
5372
bend = BendParameters(rho_ref=1.0, bend_field_ref=2.0)
5473
assert bend.rho_ref == 1.0

0 commit comments

Comments
 (0)