Skip to content

Commit 24f173b

Browse files
committed
full tests
1 parent 58439d3 commit 24f173b

File tree

7 files changed

+242
-56
lines changed

7 files changed

+242
-56
lines changed

examples/commissioning_2022.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from typing import TYPE_CHECKING
2727

2828
from ir_amplitude_detuning.detuning.calculations import Method
29-
from ir_amplitude_detuning.detuning.measurements import scaled_detuningmeasurement
29+
from ir_amplitude_detuning.detuning.measurements import DetuningMeasurement
3030
from ir_amplitude_detuning.detuning.targets import (
3131
Target,
3232
TargetData,
@@ -77,12 +77,12 @@ class MeasuredDetuning(Container):
7777
Note: Keys are beam numbers, 2 and 4 can be used interchangeably (but consistently) here.
7878
"""
7979
flat: DetuningPerBeam = BeamDict({
80-
1: scaled_detuningmeasurement(X10=(-15.4, 0.9), X01=(33.7, 1), Y01=(-8.4, 0.5)),
81-
2: scaled_detuningmeasurement(X10=(-8.7, 0.7), X01=(13, 2), Y01=(10, 0.9)),
80+
1: DetuningMeasurement(X10=(-15.4, 0.9), X01=(33.7, 1), Y01=(-8.4, 0.5), scale=1e3),
81+
2: DetuningMeasurement(X10=(-8.7, 0.7), X01=(13, 2), Y01=(10, 0.9), scale=1e3),
8282
})
8383
full: DetuningPerBeam = BeamDict({
84-
1: scaled_detuningmeasurement(X10=(20, 4), X01=(43, 4), Y01=(-10, 3)),
85-
2: scaled_detuningmeasurement(X10=(26, 0.8), X01=(-27, 4), Y01=(18, 7)),
84+
1: DetuningMeasurement(X10=(20, 4), X01=(43, 4), Y01=(-10, 3), scale=1e3),
85+
2: DetuningMeasurement(X10=(26, 0.8), X01=(-27, 4), Y01=(18, 7), scale=1e3),
8686
})
8787

8888
# Steps of calculations --------------------------------------------------------

examples/md3311.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from ir_amplitude_detuning.detuning.calculations import Method
2121
from ir_amplitude_detuning.detuning.measurements import (
2222
Constraints,
23+
DetuningMeasurement,
2324
FirstOrderTerm,
24-
scaled_detuningmeasurement,
2525
)
2626
from ir_amplitude_detuning.detuning.targets import (
2727
Target,
@@ -69,16 +69,16 @@ class MeasuredDetuning(Container):
6969
Note: Keys are beam numbers, 2 and 4 can be used interchangeably (but consistently) here.
7070
"""
7171
flat: DetuningPerBeam = BeamDict({
72-
1: scaled_detuningmeasurement(X10=(0.8, 0.5), Y01=(-3, 1)),
73-
2: scaled_detuningmeasurement(X10=(-7.5, 0.5), Y01=(6, 1)),
72+
1: DetuningMeasurement(X10=(0.8, 0.5), Y01=(-3, 1), scale=1e3),
73+
2: DetuningMeasurement(X10=(-7.5, 0.5), Y01=(6, 1), scale=1e3),
7474
})
7575
full: DetuningPerBeam = BeamDict({
76-
1: scaled_detuningmeasurement(X10=(34, 1), Y01=(-38, 1)),
77-
2: scaled_detuningmeasurement(X10=(-3, 1), Y01=(13, 3)),
76+
1: DetuningMeasurement(X10=(34, 1), Y01=(-38, 1), scale=1e3),
77+
2: DetuningMeasurement(X10=(-3, 1), Y01=(13, 3), scale=1e3),
7878
})
7979
ip5: DetuningPerBeam = BeamDict({
80-
1: scaled_detuningmeasurement(X10=(56, 6), Y01=(3, 2)),
81-
2: scaled_detuningmeasurement(X10=(1.5, 0.5), Y01=(12, 1)),
80+
1: DetuningMeasurement(X10=(56, 6), Y01=(3, 2), scale=1e3),
81+
2: DetuningMeasurement(X10=(1.5, 0.5), Y01=(12, 1), scale=1e3),
8282
})
8383
ip1: DetuningPerBeam = None # IP1 was not measured, calculated below
8484

ir_amplitude_detuning/detuning/measurements.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,3 @@ def get_leq(self, item: str) -> tuple[int, float]:
430430
value *= self.scale
431431

432432
return sign, sign*value
433-
434-
435-
# Default scaling is 1E3 as measurements are usually given in 1E3 m^-1
436-
scaled_detuning = partial(Detuning, scale=1e3)
437-
scaled_contraints = partial(Constraints, scale=1e3)
438-
scaled_detuningmeasurement = partial(DetuningMeasurement, scale=1e3)

ir_amplitude_detuning/plotting/utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,34 +63,46 @@ class OtherColors:
6363
flat = '#17becf' # blue-teal
6464

6565

66-
def get_full_target_labels(targets: Sequence[Target], suffixes: Sequence[str] | None = None, scale_exponent: float = 3) -> dict[str, str]:
67-
"""Get a label that includes all detuning terms so that they can be easily compared.
68-
To save space only the first target_data is used.
66+
def get_full_target_labels(
67+
targets: Sequence[Target],
68+
suffixes: Sequence[str] | None = None,
69+
rescale: float = 3
70+
) -> dict[str, str]:
71+
"""Get a latex label that includes values of all detuning terms, so that they can be easily compared.
72+
This is useful to plot the results of multiple targets on the same figure, without having to invent confusing
73+
labels. Instead you can just use the target detuning values that went into the correction.
74+
It ignores constraints and only the first target_data is used - otherwise the labels would be too long.
75+
Extra information can be added via the suffixes.
6976
7077
Args:
7178
targets (Sequence[Target]): List of Target objects to get labels for.
7279
suffixes (Sequence[str] | None): List of suffixes to add to the labels.
73-
scale (float): Scaling factor for the detuning values.
80+
rescale (float): Exponent of the scaling factor.
81+
(e.g. 3 to give data in units of 10^3, which multiplies the data by 10^-3)
82+
Default: 3.
7483
7584
Returns:
7685
dict[str, str]: Dictionary of labels for each target identified by its name.
7786
"""
7887
if suffixes is not None and len(suffixes) != len(targets):
7988
raise ValueError("Number of suffixes must match number of targets.")
8089

81-
scaling = 10**-scale_exponent
90+
scaling = 10**-rescale
8291

8392
names = [target.name for target in targets]
8493
labels = [None for _ in targets]
8594
for idx_target, target in enumerate(targets):
8695
target_data: TargetData = target.data[0]
8796
scaled_values = {
88-
term: (target_data.detuning[1][term] * scaling, target_data.detuning[2][term] * scaling)
89-
for term in target_data.detuning[1].terms()
97+
term: tuple("--".center(6) if val is None else f"{getattr(val, 'value', val) * scaling: 5.1f}"
98+
for beam in [1, 2]
99+
for val in [getattr(target_data.detuning[beam], term)])
100+
for term in set(target_data.detuning[1].terms()) | set(target_data.detuning[2].terms())
90101
}
102+
91103
label = "\n".join(
92104
[
93-
f"${latex.term2dqdj(term)}$ = {f'{values[0].value: 5.1f} | {values[1].value: 5.1f}'.center(15)}"
105+
f"${latex.term2dqdj(term)}$ = {f'{values[0]} | {values[1]}'.center(15)}"
94106
for term, values in scaled_values.items()
95107
]
96108
)

ir_amplitude_detuning/setup_template.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import TYPE_CHECKING
2424

2525
from ir_amplitude_detuning.detuning.calculations import Method
26-
from ir_amplitude_detuning.detuning.measurements import Constraints, scaled_detuningmeasurement
26+
from ir_amplitude_detuning.detuning.measurements import Constraints, DetuningMeasurement
2727
from ir_amplitude_detuning.detuning.targets import (
2828
Target,
2929
TargetData,
@@ -79,12 +79,12 @@ class MeasuredDetuning(Container):
7979
TODO: Fill in what you measured!
8080
"""
8181
flat: DetuningPerBeam = BeamDict({
82-
1: scaled_detuningmeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0)),
83-
2: scaled_detuningmeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0)),
82+
1: DetuningMeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0), scale=1e3),
83+
2: DetuningMeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0), scale=1e3),
8484
})
8585
full: DetuningPerBeam = BeamDict({
86-
1: scaled_detuningmeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0)),
87-
2: scaled_detuningmeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0)),
86+
1: DetuningMeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0), scale=1e3),
87+
2: DetuningMeasurement(X10=(0, 0), X01=(0, 0), Y01=(0, 0), scale=1e3),
8888
})
8989

9090
class CorrectionConstraints(Container):

tests/unit/test_detuning_measurements.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
Detuning,
99
DetuningMeasurement,
1010
MeasureValue,
11-
scaled_contraints,
12-
scaled_detuning,
13-
scaled_detuningmeasurement,
1411
)
1512
from ir_amplitude_detuning.detuning.terms import FirstOrderTerm
1613

@@ -699,28 +696,6 @@ def test_get_leq_unset_constraint_raises(self):
699696
const.get_leq("X01")
700697

701698

702-
class TestScaledPartials:
703-
"""Tests for the scaled partial functions."""
704-
705-
def test_scaled_detuning(self):
706-
"""Test scaled_detuning creates Detuning with 1e3 scale."""
707-
det = scaled_detuning(X10=1.0, Y01=2.0)
708-
assert det["X10"] == pytest.approx(1e3)
709-
assert det["Y01"] == pytest.approx(2e3)
710-
711-
def test_scaled_contraints(self):
712-
"""Test scaled_contraints creates Constraints with 1e3 scale."""
713-
const = scaled_contraints(X10="<=1.0")
714-
sign, value = const.get_leq("X10")
715-
assert value == pytest.approx(1e3)
716-
717-
def test_scaled_detuningmeasurement(self):
718-
"""Test scaled_detuningmeasurement creates DetuningMeasurement with 1e3 scale."""
719-
meas = scaled_detuningmeasurement(X10=MeasureValue(1.0, 0.1))
720-
assert meas["X10"].value == pytest.approx(1e3)
721-
assert meas["X10"].error == pytest.approx(100)
722-
723-
724699
class TestIntegration:
725700
"""Integration tests combining multiple classes."""
726701

tests/unit/test_plotting_utils.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
"""Tests for the plotting.utils module."""
2+
from __future__ import annotations
3+
4+
import re
5+
from unittest.mock import Mock
6+
7+
import pytest
8+
from matplotlib.colors import to_rgb
9+
10+
from ir_amplitude_detuning.detuning.measurements import Constraints, Detuning, DetuningMeasurement, MeasureValue
11+
from ir_amplitude_detuning.detuning.targets import Target, TargetData
12+
from ir_amplitude_detuning.plotting.utils import (
13+
OtherColors,
14+
get_color_for_field,
15+
get_color_for_ip,
16+
get_default_scaling,
17+
get_full_target_labels,
18+
)
19+
from ir_amplitude_detuning.utilities.correctors import Corrector, FieldComponent
20+
21+
22+
# ============================================================================
23+
# Tests for get_default_scaling
24+
# ============================================================================
25+
26+
class TestGetDefaultScaling:
27+
"""Test cases for the get_default_scaling function."""
28+
29+
@pytest.mark.parametrize(
30+
"term,expected_exponent,expected_scaling",
31+
[
32+
("X02", 12, 1e-12),
33+
("Y01", 3, 1e-3),
34+
("Y11", 12, 1e-12),
35+
("X10", 3, 1e-3),
36+
],
37+
)
38+
def test_get_default_scaling(self, term: str, expected_exponent: int, expected_scaling: float):
39+
"""Test default scaling factors for various detuning terms."""
40+
exponent, scaling = get_default_scaling(term)
41+
assert exponent == expected_exponent
42+
assert scaling == pytest.approx(expected_scaling)
43+
44+
def test_get_default_scaling_invalid_sum(self):
45+
"""Test that invalid term sums raise KeyError."""
46+
with pytest.raises(KeyError):
47+
get_default_scaling("X00") # sum is 0, not in dict
48+
49+
with pytest.raises(KeyError):
50+
get_default_scaling("Y31") # sum is 4, not in dict
51+
52+
53+
# ============================================================================
54+
# Tests for get_color_for_field
55+
# ============================================================================
56+
57+
class TestGetColorForField:
58+
"""Test cases for the get_color_for_field function."""
59+
60+
@pytest.mark.parametrize("field", list(FieldComponent))
61+
def test_get_color_for_field_valid(self, field: FieldComponent):
62+
"""Test that valid fields return colors."""
63+
result = get_color_for_field(field) # asserts all fields are valid
64+
_, _, _ = to_rgb(result) # asserts that it is convertable to RGB
65+
66+
def test_all_colors_different(self):
67+
"""Test that all colors are different."""
68+
colors = [get_color_for_field(field) for field in list(FieldComponent)]
69+
assert len(set(colors)) == len(colors)
70+
71+
def test_get_color_for_field_invalid(self):
72+
"""Test that invalid fields raise NotImplementedError."""
73+
# Create a mock field that doesn't match any case
74+
mock_field = Mock(spec=FieldComponent)
75+
mock_field.__str__ = Mock(return_value="invalid_field")
76+
77+
with pytest.raises(NotImplementedError, match="Field must be one of"):
78+
get_color_for_field(mock_field)
79+
80+
81+
# ============================================================================
82+
# Tests for get_color_for_ip
83+
# ============================================================================
84+
85+
class TestGetColorForIp:
86+
"""Test cases for the get_color_for_ip function."""
87+
88+
@pytest.mark.parametrize("ip", ["15", "1", "5"])
89+
def test_get_color_for_ip_valid(self, ip: str):
90+
"""Test that valid IPs return colors."""
91+
result = get_color_for_ip(ip) # asserts all IPs are valid
92+
_, _, _ = to_rgb(result) # asserts that it is convertable to RGB
93+
94+
def test_all_colors_different(self):
95+
"""Test that all colors are different."""
96+
colors = [get_color_for_ip(ip) for ip in ["15", "1", "5"]]
97+
assert len(set(colors)) == len(colors)
98+
99+
@pytest.mark.parametrize("invalid_ip", ["2", "8", "invalid", "", "1 ", "15a", "0"])
100+
def test_get_color_for_ip_invalid(self, invalid_ip: str):
101+
"""Test that invalid IPs raise NotImplementedError."""
102+
with pytest.raises(
103+
NotImplementedError,
104+
match=f"IP must be one of \\['15', '1', '5'\\], got {invalid_ip}\\."
105+
):
106+
get_color_for_ip(invalid_ip)
107+
108+
109+
# ============================================================================
110+
# Tests for OtherColors
111+
# ============================================================================
112+
113+
class TestOtherColors:
114+
"""Test cases for the OtherColors class."""
115+
116+
def test_other_colors_estimated(self):
117+
"""Test that OtherColors.estimated has correct value."""
118+
_, _, _ = to_rgb(OtherColors.estimated)
119+
assert OtherColors.flat != OtherColors.estimated
120+
121+
def test_other_colors_flat(self):
122+
"""Test that OtherColors.flat has correct value."""
123+
_, _, _ = to_rgb(OtherColors.estimated)
124+
assert OtherColors.flat != OtherColors.estimated
125+
126+
# ============================================================================
127+
# Tests for get_full_target_labels
128+
# ============================================================================
129+
130+
class TestGetFullTargetLabels:
131+
"""Test cases for the get_full_target_labels function."""
132+
133+
def test_get_full_target_labels_single_target_no_suffixes(self, target_data):
134+
"""Test with single target and no suffixes."""
135+
result = get_full_target_labels([Target(name="target_name", data=[target_data])])
136+
137+
assert isinstance(result, dict)
138+
assert "target_name" in result
139+
target_label = result["target_name"]
140+
assert isinstance(target_label, str)
141+
assert re.search(r"\$Q_\{x,yy\}\$\s+=\s+1\.5\s+\|\s+3\.5", target_label) # values from the fixture
142+
assert re.search(r"\$Q_\{y,xy\}\$\s+=\s+--\s+\|\s+2\.5", target_label)
143+
144+
145+
def test_get_full_target_labels_multiple_targets_with_suffixes(self, target_data):
146+
"""Test with multiple targets and suffixes."""
147+
result = get_full_target_labels(
148+
[Target(name="target1", data=[target_data]),
149+
Target(name="target2", data=[target_data])],
150+
suffixes=["suffix_1", "suffix_2"]
151+
)
152+
153+
assert len(result) == 2
154+
assert "target1" in result
155+
assert "target2" in result
156+
assert "suffix_1" in result["target1"]
157+
assert "suffix_2" in result["target2"]
158+
assert result["target1"].replace("suffix_1","") == result["target2"].replace("suffix_2","")
159+
160+
def test_get_full_target_labels_empty_targets(self):
161+
"""Test with empty targets list."""
162+
result = get_full_target_labels([])
163+
164+
assert isinstance(result, dict)
165+
assert len(result) == 0
166+
167+
def test_get_full_target_labels_mismatched_suffixes_error(self, target_data):
168+
"""Test that mismatched suffixes count raises ValueError."""
169+
with pytest.raises(ValueError, match="Number of suffixes must match number of targets"):
170+
get_full_target_labels(
171+
[
172+
Target(name="target1", data=[target_data]),
173+
Target(name="target2", data=[target_data]),
174+
],
175+
suffixes=["only_one_suffix"],
176+
)
177+
178+
def test_get_full_target_labels_custom_scale_exponent(self, target_data):
179+
"""Test with custom scale exponent."""
180+
result = get_full_target_labels([Target(name="target_name", data=[target_data])], rescale=4)
181+
182+
assert isinstance(result, dict)
183+
assert "target_name" in result # replace "target_data.name" with the actual name of the target
184+
target_label = result["target_name"]
185+
assert isinstance(target_label, str)
186+
assert re.search(r"\$Q_\{x,yy\}\$\s+=\s+0\.2\s+\|\s+0\.4", target_label) # values /10 from the fixture
187+
assert re.search(r"\$Q_\{y,xy\}\$\s+=\s+--\s+\|\s+0\.3", target_label)
188+
189+
190+
@pytest.fixture
191+
def target_data():
192+
"""Fixture for TargetData. Used in the TestGetFullTargetLabels class."""
193+
correctors = [
194+
Corrector(field=FieldComponent.b4, circuit="k4", magnet="K4", length=0.5),
195+
Corrector(field=FieldComponent.b5, circuit="k5", magnet="K5", length=0.5),
196+
]
197+
optics = {1: Mock(), 2: Mock()}
198+
detuning = {
199+
1: DetuningMeasurement(X02=(1.51, 2.5), scale=1e3),
200+
2: DetuningMeasurement(Y11=(2.51, 2.5), X02=(3.52, 4.5), scale=1e3)
201+
}
202+
constraints = {1: Constraints(Y02="<=10"), 2: Constraints(Y02=">=11")}
203+
return TargetData(
204+
correctors=correctors, optics=optics, detuning=detuning, constraints=constraints
205+
)

0 commit comments

Comments
 (0)