Skip to content

Commit c00e6fb

Browse files
committed
tests
1 parent bef85a9 commit c00e6fb

File tree

3 files changed

+306
-5
lines changed

3 files changed

+306
-5
lines changed

ir_amplitude_detuning/detuning/calculations.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import cvxpy as cvx
1414
import numpy as np
1515
import pandas as pd
16+
from cvxpy.settings import ERROR, INF_OR_UNB
1617

1718
from ir_amplitude_detuning.detuning.equation_system import (
1819
build_detuning_correction_matrix,
@@ -95,9 +96,10 @@ def calculate_correction(
9596
pd.Series[float]: A Series of circuit names and their settings in KNL values.
9697
"""
9798
# Check input ---
98-
99-
if method not in Method:
100-
raise ValueError(f"Unknown method: {method}. Use one of: {list(Method)}")
99+
try:
100+
method = Method(method)
101+
except ValueError as e:
102+
raise ValueError(f"Unknown method: {method}. Use one of: {list(Method)}") from e
101103

102104
# Build equation system ---
103105

@@ -116,7 +118,7 @@ def calculate_correction(
116118

117119
prob = cvx.Problem(cvx.Minimize(cost), constraints)
118120
prob.solve()
119-
if prob.status in ["infeasible", "unbounded"]:
121+
if prob.status in INF_OR_UNB + ERROR:
120122
raise ValueError(f"Optimization failed! Reason: {prob.status}.")
121123

122124
x_cvxpy = pd.Series(x.value, index=eqsys.m.columns)

ir_amplitude_detuning/utilities/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
class StrEnum(str, Enum):
1818
"""Enum with string representation.
1919
20-
Note: Can be removed in Python 3.11 as it is implemented there as `enum.StrEnum`.
20+
Note:
21+
Can possibly be removed in Python 3.11 as it is implemented there as `enum.StrEnum`.
22+
But beware, that `"value" in StrEnum` raises `TypeError` until Python 3.12,
23+
workaround is `"value" in list(StrEnum)` or to try `StrEnum(value)`.
2124
"""
2225
def __repr__(self) -> str:
2326
return self.value
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import Mock, patch
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pytest
8+
from cvxpy.settings import INFEASIBLE, SOLVER_ERROR, UNBOUNDED
9+
from pandas.testing import assert_frame_equal
10+
11+
from ir_amplitude_detuning.detuning.calculations import (
12+
FIELDS,
13+
IP,
14+
Method,
15+
calc_effective_detuning,
16+
calculate_correction,
17+
)
18+
from ir_amplitude_detuning.detuning.equation_system import DetuningCorrectionEquationSystem
19+
from ir_amplitude_detuning.detuning.measurements import (
20+
FirstOrderTerm,
21+
MeasureValue,
22+
SecondOrderTerm,
23+
)
24+
from ir_amplitude_detuning.utilities.correctors import Corrector, FieldComponent
25+
26+
# ============================================================================
27+
# Tests for Method Enum
28+
# ============================================================================
29+
30+
class TestMethodEnum:
31+
"""Test cases for the Method enum."""
32+
33+
def test_method_enum_values(self):
34+
"""Test that all method values are correct."""
35+
assert Method.auto == "auto"
36+
assert Method.cvxpy == "cvxpy"
37+
assert Method.numpy == "numpy"
38+
39+
def test_method_in_enum(self):
40+
"""Test that all methods are recognized by the enum."""
41+
# check if all values are in the enum (`list` only necessary until py3.12)
42+
assert "auto" in list(Method)
43+
assert "cvxpy" in list(Method)
44+
assert "numpy" in list(Method)
45+
46+
# similarly, creating instances should also work:
47+
assert Method("auto") == Method.auto
48+
assert Method("cvxpy") == Method.cvxpy
49+
assert Method("numpy") == Method.numpy
50+
51+
def test_method_not_in_enum(self):
52+
"""Test that unknown methods raise ValueError."""
53+
with pytest.raises(ValueError):
54+
Method("invalid")
55+
56+
57+
# ============================================================================
58+
# Tests for calculate_correction
59+
# ============================================================================
60+
61+
class TestCalculateCorrection:
62+
"""Test cases for the calculate_correction function."""
63+
64+
def test_calculate_correction_invalid_method(self):
65+
"""Test that invalid method raises ValueError."""
66+
mock_target = Mock()
67+
68+
with pytest.raises(ValueError):
69+
calculate_correction(mock_target, method="invalid")
70+
71+
@pytest.mark.parametrize("method", [Method.auto, Method.numpy, Method.cvxpy])
72+
@patch('ir_amplitude_detuning.detuning.calculations.build_detuning_correction_matrix')
73+
def test_calculate_correction_exact_no_constraints(self, mock_build_matrix, method):
74+
"""Test calculate_correction with auto method and no constraints."""
75+
# Very simple eqation system:
76+
matrix = [[1, 1], [1, -1]] # inverse is matrix/2
77+
values = [MeasureValue(3, 0.2), MeasureValue(1, 0.1)]
78+
expected_values = {"a": 2, "b": 1}
79+
expected_error = np.sqrt(np.mean([v.error**2 for v in values])/2)
80+
81+
# Mock equation system building ---
82+
mock_eqsys = DetuningCorrectionEquationSystem(
83+
m = pd.DataFrame(matrix, columns=expected_values.keys()),
84+
v = pd.Series([v.value for v in values]),
85+
m_constr = pd.DataFrame(),
86+
v_constr = pd.Series(dtype=float),
87+
v_meas = pd.Series(values),
88+
)
89+
90+
mock_build_matrix.return_value = mock_eqsys
91+
mock_target = Mock()
92+
93+
# Run the calculation ---
94+
result = calculate_correction(mock_target, method=method)
95+
96+
# Check the results ---
97+
assert mock_build_matrix.called_with(mock_target)
98+
assert isinstance(result, pd.Series)
99+
assert len(result) == 2
100+
101+
if method in (Method.numpy, Method.auto):
102+
assert result["a"].value == pytest.approx(expected_values["a"])
103+
assert result["b"].value == pytest.approx(expected_values["b"])
104+
assert result["a"].error == pytest.approx(result["b"].error) == pytest.approx(expected_error)
105+
else:
106+
assert result["a"] == pytest.approx(expected_values["a"])
107+
assert result["b"] == pytest.approx(expected_values["b"])
108+
109+
with pytest.raises(AttributeError):
110+
result["a"].value
111+
112+
with pytest.raises(AttributeError):
113+
result["b"].value
114+
115+
def _get_equation_system_to_optimize(self) -> tuple[list[MeasureValue], list[list[int]], float]:
116+
"""Simple, not exact solvable Eqs for the next tests."""
117+
values = [MeasureValue(3, 0.2), MeasureValue(5, 0.1)]
118+
matrix = [[1, 1], [2, 2]]
119+
expected = 1.3 # optimal value without constraints
120+
return values, matrix, expected
121+
122+
@patch('ir_amplitude_detuning.detuning.calculations.build_detuning_correction_matrix')
123+
def test_calculate_correction_optimize_no_constraints(self, mock_build_matrix):
124+
"""Test calculate_correction with auto method and no constraints."""
125+
# Prepare ---
126+
values, matrix, expected = self._get_equation_system_to_optimize()
127+
mock_eqsys = DetuningCorrectionEquationSystem(
128+
m = pd.DataFrame(matrix, columns=["a", "b"]),
129+
v = pd.Series([v.value for v in values]),
130+
m_constr = pd.DataFrame(),
131+
v_constr = pd.Series(dtype=float),
132+
v_meas = pd.Series(values),
133+
)
134+
mock_build_matrix.return_value = mock_eqsys
135+
136+
# Run ---
137+
result_numpy = calculate_correction(Mock(), method=Method.auto)
138+
139+
# Check ---
140+
assert result_numpy["a"].value == pytest.approx(expected)
141+
assert result_numpy["b"].value == pytest.approx(expected)
142+
assert result_numpy["a"].error > 0
143+
assert result_numpy["b"].error > 0
144+
145+
@patch('ir_amplitude_detuning.detuning.calculations.build_detuning_correction_matrix')
146+
def test_calculate_correction_optimize_with_wide_constraints(self, mock_build_matrix):
147+
"""Test calculate_correction with auto method and constraints that don't really matter."""
148+
# Prepare ---
149+
values, matrix, expected = self._get_equation_system_to_optimize()
150+
mock_eqsys = DetuningCorrectionEquationSystem(
151+
m = pd.DataFrame(matrix, columns=["a", "b"]),
152+
v = pd.Series([v.value for v in values]),
153+
m_constr = pd.DataFrame([[1, 1]]), # sum of variables
154+
v_constr = pd.Series([3]), # to be smaller than 3
155+
v_meas = pd.Series(values),
156+
)
157+
mock_build_matrix.return_value = mock_eqsys
158+
159+
# Run ---
160+
result_cvxpy = calculate_correction(Mock(), method=Method.auto)
161+
162+
# Check ---
163+
assert result_cvxpy["a"] == pytest.approx(expected)
164+
assert result_cvxpy["b"] == pytest.approx(expected)
165+
166+
@patch('ir_amplitude_detuning.detuning.calculations.build_detuning_correction_matrix')
167+
def test_calculate_correction_optimize_with_constraints(self, mock_build_matrix):
168+
"""Test calculate_correction with auto method and constraints."""
169+
# Prepare ---
170+
values, matrix, expected = self._get_equation_system_to_optimize()
171+
mock_eqsys = DetuningCorrectionEquationSystem(
172+
m = pd.DataFrame(matrix, columns=["a", "b"]),
173+
v = pd.Series([v.value for v in values]),
174+
m_constr = pd.DataFrame([[-1, -1]]), # sum of variables
175+
v_constr = pd.Series([-3]), # to be larger than 3
176+
v_meas = pd.Series(values),
177+
)
178+
mock_build_matrix.return_value = mock_eqsys
179+
180+
# Run ---
181+
result_cvxpy = calculate_correction(Mock(), method=Method.auto)
182+
183+
# Check ---
184+
assert np.sum(result_cvxpy) == pytest.approx(3) # should optimize to 1.5, 1.5
185+
186+
187+
@patch('ir_amplitude_detuning.detuning.calculations.build_detuning_correction_matrix')
188+
@patch('ir_amplitude_detuning.detuning.calculations.cvx.Problem')
189+
def test_cvxpy_fails(self, mock_problem_class, mock_build_matrix):
190+
"""Test calculate_correction with cvxpy method."""
191+
# Setup mocks
192+
mock_eqsys = Mock()
193+
mock_eqsys.m = pd.DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
194+
mock_eqsys.v = pd.Series([5, 6])
195+
mock_eqsys.m_constr = pd.DataFrame()
196+
mock_eqsys.v_constr = pd.Series(dtype=float)
197+
mock_eqsys.v_meas = pd.Series([5, 6])
198+
199+
mock_build_matrix.return_value = mock_eqsys
200+
201+
# Mock cvxpy solver
202+
for error_status in (INFEASIBLE, UNBOUNDED, SOLVER_ERROR):
203+
mock_prob = Mock()
204+
mock_problem_class.return_value = mock_prob
205+
mock_prob.status = error_status
206+
mock_prob.solve.return_value = None
207+
208+
with pytest.raises(ValueError) as e:
209+
calculate_correction(Mock(), method=Method.cvxpy)
210+
211+
assert "failed" in str(e)
212+
assert error_status in str(e)
213+
214+
# Check that cvxpy solver was used
215+
mock_prob.solve.assert_called_once()
216+
217+
218+
# ============================================================================
219+
# Tests for calc_effective_detuning
220+
# ============================================================================
221+
222+
class TestCalcEffectiveDetuning:
223+
"""Test cases for the calc_effective_detuning function."""
224+
225+
def test_calc_effective_detuning_empty_optics(self):
226+
"""Test with empty optics dictionary."""
227+
result = calc_effective_detuning({}, pd.Series(dtype=float))
228+
229+
# Returns empty dict with no beams
230+
assert isinstance(result, dict)
231+
assert len(result) == 0
232+
233+
@patch("ir_amplitude_detuning.detuning.calculations.calculate_matrix_row")
234+
def test_calc_effective_detuning(self, mock_calculate_matrix_row):
235+
"""Test with single beam."""
236+
# Prepare fake data ---
237+
# Create correctors with different IPs and fields
238+
correctors = [
239+
Corrector(
240+
field=field,
241+
length=0.5,
242+
magnet=f"{type_}ip{ip or 0}{field}",
243+
circuit=f"k{type_}ip{ip or 0}{field}",
244+
ip=ip,
245+
)
246+
for type_, field, ip in (
247+
("c1", FieldComponent.b4, 1),
248+
("c2", FieldComponent.b4, 1),
249+
("c1", FieldComponent.b5, 1),
250+
("c1", FieldComponent.b6, 1),
251+
("c2", FieldComponent.b6, 1),
252+
("c1", FieldComponent.b6, 2),
253+
("c2", FieldComponent.b6, 2),
254+
("c2", FieldComponent.b4, None),
255+
)
256+
]
257+
all_terms = list(FirstOrderTerm) + list(SecondOrderTerm)
258+
values = pd.Series(np.arange(len(correctors)), index=correctors)
259+
260+
# Create mocks
261+
mock_optics = {1: Mock(), 2: Mock()}
262+
def mocked_calulation(beam, optics, correctors, term):
263+
assert optics == mock_optics[beam] # already some checks
264+
assert term in all_terms
265+
return np.ones([1, len(correctors)]) * (all_terms.index(term) + 1) * beam
266+
267+
mock_calculate_matrix_row.side_effect = mocked_calulation
268+
269+
270+
# Test the function ---
271+
result = calc_effective_detuning(mock_optics, values)
272+
273+
# Check the result ---
274+
assert isinstance(result, dict)
275+
276+
# n calls = n_terms * (n_ips + 1) * (n_fields + 1) * n_beams
277+
assert mock_calculate_matrix_row.call_count == len(all_terms) * 3 * 4 * 2
278+
279+
# Check that result 2 is the same as 1 but multiplied by 2 (beam in mock calculation)
280+
df_mul = result[1].copy()
281+
df_mul.loc[:, all_terms] = df_mul.loc[:, all_terms] * 2
282+
assert_frame_equal(df_mul, result[2])
283+
284+
# Test grouping by fields and ips
285+
def filter_correctors(field, ip):
286+
return list(filter(lambda c: (c.field in field) and (c.ip is None or str(c.ip) in ip), correctors))
287+
288+
for field in ("b4", "b5", "b6", "b4b5b6"):
289+
df_field = result[1].loc[result[1][FIELDS] == field, :]
290+
for ip in ("1", "2", "12"):
291+
df_field_ip = df_field.loc[df_field[IP] == ip, :]
292+
value = df_field_ip[all_terms[0]].iloc[0]
293+
assert len(df_field_ip) == 1
294+
assert all(df_field_ip[all_terms] == value * np.arange(1, len(all_terms) + 1)) # because of mock return
295+
contributing_correctors = filter_correctors(field=field, ip=ip)
296+
assert value == sum(values.loc[contributing_correctors]) # because mock return for first term is (1,1,1..)

0 commit comments

Comments
 (0)