Skip to content

Commit 078d354

Browse files
committed
Refactor sim params into a dataclass
There will be a corresponding update to Virtac
1 parent 64954ac commit 078d354

File tree

5 files changed

+118
-157
lines changed

5 files changed

+118
-157
lines changed

src/atip/load_sim.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
def load_from_filepath(
1616
pytac_lattice,
1717
at_lattice_filepath,
18-
linopt_function="linopt6",
19-
disable_emittance=False,
20-
disable_chromaticity=False,
21-
disable_radiation=False,
18+
sim_params=None,
2219
callback=None,
2320
):
2421
"""Load simulator data sources onto the lattice and its elements.
@@ -27,13 +24,8 @@ def load_from_filepath(
2724
pytac_lattice (pytac.lattice.Lattice): An instance of a Pytac lattice.
2825
at_lattice_filepath (str): The path to a .mat file from which the
2926
Accelerator Toolbox lattice can be loaded.
30-
linopt_function (str): Which pyAT linear optics function to use: linopt2,
31-
linopt4, linopt6.
32-
disable_emittance (bool): Whether the emittance calculations should be
33-
disabled.
34-
disable_chromaticity (bool): Whether the chromaticity calculations should be
35-
disabled.
36-
disable_radiation (bool): Whether radiation calculations should be disabled.
27+
sim_params (SimParams | None): An optional dataclass containing the pyAT
28+
simulation parameters to use.
3729
callback (typing.Callable): To be called after completion of each round of
3830
physics calculations.
3931
@@ -49,35 +41,24 @@ def load_from_filepath(
4941
return load(
5042
pytac_lattice,
5143
at_lattice,
52-
linopt_function,
53-
disable_emittance,
54-
disable_chromaticity,
55-
disable_radiation,
44+
sim_params,
5645
callback,
5746
)
5847

5948

6049
def load(
6150
pytac_lattice,
6251
at_lattice,
63-
linopt_function="linopt6",
64-
disable_emittance=False,
65-
disable_chromaticity=False,
66-
disable_radiation=False,
52+
sim_params=None,
6753
callback=None,
6854
):
6955
"""Load simulator data sources onto the lattice and its elements.
7056
7157
Args:
7258
pytac_lattice (pytac.lattice.Lattice): An instance of a Pytac lattice.
7359
at_lattice (at.lattice_object.Lattice): An instance of an AT lattice object.
74-
linopt_function (str): Which pyAT linear optics function to use: linopt2,
75-
linopt4, linopt6.
76-
disable_emittance (bool): Whether the emittance calculations should be
77-
disabled.
78-
disable_chromaticity (bool): Whether the chromaticity calculations should be
79-
disabled.
80-
disable_radiation (bool): Whether radiation calculations should be disabled.
60+
sim_params (SimParams | None): An optional dataclass containing the pyAT
61+
simulation parameters to use.
8162
callback (typing.Callable): To be called after completion of each round of
8263
physics calculations.
8364
@@ -93,10 +74,7 @@ def load(
9374
# Initialise an instance of the ATSimulator Object.
9475
atsim = ATSimulator(
9576
at_lattice,
96-
linopt_function,
97-
disable_emittance,
98-
disable_chromaticity,
99-
disable_radiation,
77+
sim_params,
10078
callback,
10179
)
10280
# Set the simulator data source on the Pytac lattice.

src/atip/simulator.py

Lines changed: 79 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
from dataclasses import dataclass
5+
from enum import StrEnum, auto
56
from warnings import warn
67

78
import at
@@ -12,6 +13,38 @@
1213
from scipy.constants import speed_of_light
1314

1415

16+
class LinoptType(StrEnum):
17+
LINOPT2 = auto()
18+
LINOPT4 = auto()
19+
LINOPT6 = auto()
20+
21+
22+
@dataclass
23+
class SimParams:
24+
linopt: LinoptType = LinoptType.LINOPT6
25+
emittance: bool = True
26+
chromaticity: bool = True
27+
radiation: bool = True
28+
29+
def __post_init__(self):
30+
"""Check that we have a valid combination of simulation parameters."""
31+
if self.radiation:
32+
if self.linopt == LinoptType.LINOPT2 or self.linopt == LinoptType.LINOPT4:
33+
raise ValueError(
34+
f"You must disable radiation to use linopt function: {self.linopt}",
35+
)
36+
else:
37+
if self.linopt == LinoptType.LINOPT6:
38+
raise ValueError(
39+
f"You cannot use linopt function: {self.linopt} with radiation "
40+
f"disabled.",
41+
)
42+
elif self.emittance:
43+
raise ValueError(
44+
"You cannot calculate emittance with radiation disabled",
45+
)
46+
47+
1548
@dataclass
1649
class LatticeData:
1750
twiss: ArrayLike
@@ -22,12 +55,7 @@ class LatticeData:
2255

2356

2457
def calculate_optics(
25-
at_lattice: at.lattice_object.Lattice,
26-
refpts: ArrayLike,
27-
linopt_function: str = "linopt6",
28-
disable_emittance: bool = False,
29-
disable_chromaticity: bool = False,
30-
disable_radiation: bool = False,
58+
at_lattice: at.lattice_object.Lattice, refpts: ArrayLike, sp: SimParams
3159
) -> LatticeData:
3260
"""Perform the physics calculations on the lattice.
3361
@@ -39,63 +67,55 @@ def calculate_optics(
3967
Args:
4068
at_lattice (at.lattice_object.Lattice): AT lattice definition.
4169
refpts (numpy.typing.NDArray): A boolean array specifying the points at which
42-
to calculate physics data.
43-
disable_emittance (bool): whether to calculate emittance.
70+
to calculate physics data.
71+
sp (SimParams): An optional dataclass containing the pyAT simulation
72+
parameters to use.
4473
4574
Returns:
4675
LatticeData: The calculated lattice data.
4776
"""
4877
logging.debug("Starting physics calculations.")
4978
logging.debug(
50-
f"Using simulation params: {linopt_function}, disable_emittance="
51-
f"{disable_emittance}, disable_chromaticity={disable_chromaticity}, "
52-
f"disable_radiation={disable_radiation}"
79+
f"Using simulation params: {sp.linopt}, emittance={sp.emittance}, chromaticity="
80+
f"{sp.chromaticity}, radiation={sp.radiation}"
5381
)
54-
if linopt_function == "linopt6":
55-
orbit0, _ = at_lattice.find_orbit6()
56-
logging.debug("Completed orbit calculation.")
57-
58-
_, beamdata, twiss = at_lattice.linopt6(
59-
refpts=refpts,
60-
get_chrom=not disable_chromaticity,
61-
orbit=orbit0,
62-
keep_lattice=True,
63-
)
64-
elif linopt_function == "linopt4":
65-
orbit0, _ = at_lattice.find_orbit4()
66-
logging.debug("Completed orbit calculation.")
67-
68-
_, beamdata, twiss = at_lattice.linopt6(
69-
refpts=refpts,
70-
get_chrom=not disable_chromaticity,
71-
orbit=orbit0,
72-
keep_lattice=True,
73-
)
74-
elif linopt_function == "linopt2":
75-
orbit0, _ = at_lattice.find_orbit()
76-
logging.debug("Completed orbit calculation.")
77-
78-
_, beamdata, twiss = at_lattice.linopt2(
79-
refpts=refpts,
80-
get_chrom=not disable_chromaticity,
81-
orbit=orbit0,
82-
keep_lattice=True,
83-
)
84-
else:
85-
raise ValueError(
86-
f"Error. Invalid linopt function selected: {linopt_function}. Simulation "
87-
"data not calculated."
88-
)
8982

83+
match sp.linopt:
84+
case LinoptType.LINOPT2:
85+
orbit_func = at_lattice.find_orbit
86+
linopt_func = at_lattice.linopt2
87+
case LinoptType.LINOPT4:
88+
orbit_func = at_lattice.find_orbit4
89+
linopt_func = at_lattice.linopt4
90+
case LinoptType.LINOPT6:
91+
orbit_func = at_lattice.find_orbit6
92+
linopt_func = at_lattice.linopt6
93+
case _:
94+
raise ValueError(
95+
f"Error. Invalid linopt function selected: {sp.linopt}. Simulation "
96+
"data not calculated."
97+
)
98+
99+
# Perform pyAT orbit calculation
100+
orbit0, _ = orbit_func()
101+
logging.debug("Completed orbit calculation.")
102+
103+
# Perform pyAT linear optics calculation
104+
_, beamdata, twiss = linopt_func(
105+
refpts=refpts,
106+
get_chrom=sp.chromaticity,
107+
orbit=orbit0,
108+
keep_lattice=True,
109+
)
90110
logging.debug("Completed linear optics calculation.")
91111

92-
if not disable_emittance:
112+
if sp.emittance:
93113
emitdata = at_lattice.ohmi_envelope(orbit=orbit0, keep_lattice=True)
94114
logging.debug("Completed emittance calculation")
95115
else:
96116
emitdata = ()
97117

98-
if not disable_radiation:
118+
if sp.radiation:
99119
radint = at_lattice.get_radiation_integrals(twiss=twiss)
100120
logging.debug("Completed radiation calculation")
101121
else:
@@ -126,8 +146,6 @@ class ATSimulator:
126146
physics data is calculated.
127147
_rp (numpy.typing.NDArray): A boolean array to be used as refpts for the
128148
physics calculations.
129-
_disable_emittance (bool): Whether or not to perform the beam
130-
envelope based emittance calculations.
131149
_lattice_data (LatticeData): calculated physics data
132150
function linopt (see at.lattice.linear.py).
133151
_queue (cothread.EventQueue): A queue of changes to be applied to
@@ -144,10 +162,7 @@ class ATSimulator:
144162
def __init__(
145163
self,
146164
at_lattice,
147-
linopt_function="linopt6",
148-
disable_emittance=False,
149-
disable_chromaticity=False,
150-
disable_radiation=False,
165+
sim_params=None,
151166
callback=None,
152167
):
153168
"""
@@ -159,13 +174,8 @@ def __init__(
159174
160175
Args:
161176
at_lattice (at.lattice_object.Lattice): An instance of an AT lattice object.
162-
linopt_function (str): Which pyAT linear optics function to use: linopt2,
163-
linopt4, linopt6.
164-
disable_emittance (bool): Whether the emittance calculations should be
165-
disabled.
166-
disable_chromaticity (bool): Whether the chromaticity calculations should be
167-
disabled.
168-
disable_radiation (bool): Whether radiation calculations should be disabled.
177+
sim_params (SimParams | None): An optional dataclass containing the pyAT
178+
simulation parameters to use.
169179
callback (typing.Callable): To be called after completion of each round of
170180
physics calculations.
171181
@@ -177,23 +187,16 @@ def __init__(
177187
)
178188
self._at_lat = at_lattice
179189
self._rp = numpy.ones(len(at_lattice) + 1, dtype=bool)
180-
self._linopt_function = linopt_function
181-
self._disable_emittance = disable_emittance
182-
self._disable_chromaticity = disable_chromaticity
183-
self._disable_radiation = disable_radiation
184190

185-
if not self._disable_radiation:
191+
if sim_params is None:
192+
sim_params = SimParams()
193+
self._sim_params = sim_params
194+
195+
if self._sim_params.radiation:
186196
self._at_lat.radiation_on()
187197

188198
# Initial phys data calculation.
189-
self._lattice_data = calculate_optics(
190-
self._at_lat,
191-
self._rp,
192-
self._linopt_function,
193-
self._disable_emittance,
194-
self._disable_chromaticity,
195-
self._disable_radiation,
196-
)
199+
self._lattice_data = calculate_optics(self._at_lat, self._rp, self._sim_params)
197200

198201
# Threading stuff initialisation.
199202
self._queue = cothread.EventQueue()
@@ -261,12 +264,7 @@ def _recalculate_phys_data(self, callback):
261264
if bool(self._paused) is False:
262265
try:
263266
self._lattice_data = calculate_optics(
264-
self._at_lat,
265-
self._rp,
266-
self._linopt_function,
267-
self._disable_emittance,
268-
self._disable_chromaticity,
269-
self._disable_radiation,
267+
self._at_lat, self._rp, self._sim_params
270268
)
271269
except Exception as e:
272270
warn(at.AtWarning(e), stacklevel=1)
@@ -563,7 +561,7 @@ def get_emittance(self, field=None):
563561
Raises:
564562
pytac.FieldException: if the specified field is not valid for emittance.
565563
"""
566-
if not self._disable_emittance:
564+
if self._sim_params.emittance:
567565
if field is None:
568566
return self._lattice_data.emittance[0]["emitXY"]
569567
elif field == "x":

src/atip/utils.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ def load_at_lattice(mode="I04", **kwargs):
3131

3232
def loader(
3333
mode="I04",
34-
linopt_function="linopt6",
35-
disable_emittance=False,
36-
disable_chromaticity=False,
37-
disable_radiation=False,
34+
sim_params=None,
3835
callback=None,
3936
):
4037
"""Load a unified lattice of the specifed mode.
@@ -45,9 +42,10 @@ def loader(
4542
4643
Args:
4744
mode (str): The lattice operation mode.
45+
sim_params (SimParams | None): An optional dataclass containing the pyAT
46+
simulation parameters to use.
4847
callback (typing.Callable): Callable to be called after completion of each
4948
round of physics calculations in ATSimulator.
50-
disable_emittance (bool): Whether the emittance should be calculated.
5149
5250
Returns:
5351
pytac.lattice.Lattice: A Pytac lattice object with the simulator data
@@ -62,10 +60,7 @@ def loader(
6260
lattice = atip.load_sim.load(
6361
pytac_lattice,
6462
at_lattice,
65-
linopt_function,
66-
disable_emittance,
67-
disable_chromaticity,
68-
disable_radiation,
63+
sim_params,
6964
callback,
7065
)
7166
return lattice

0 commit comments

Comments
 (0)