Skip to content

Commit 93becfe

Browse files
committed
Create orbit.sim module
1 parent acc5349 commit 93becfe

File tree

3 files changed

+364
-3
lines changed

3 files changed

+364
-3
lines changed

py/orbit/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ subdir('space_charge')
2424
subdir('errors')
2525
subdir('matrix_lattice')
2626
subdir('teapot')
27+
subdir('sim')
2728

2829

2930
py_sources = files([

py/orbit/sim/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import linac
2-
import ring
1+
from . import linac
2+
from . import ring
33

44
__all__ = []
55
__all__.append("linac")
6-
__all__.append("ring")
6+
__all__.append("ring")

py/orbit/sim/linac.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
import math
2+
import os
3+
import sys
4+
import time
5+
from typing import Callable
6+
from typing import Optional
7+
8+
import numpy as np
9+
10+
from orbit.core import orbit_mpi
11+
from orbit.core.bunch import Bunch
12+
from orbit.core.bunch import BunchTwissAnalysis
13+
from orbit.core.orbit_utils import BunchExtremaCalculator
14+
from orbit.lattice import AccActionsContainer
15+
from orbit.lattice import AccNode
16+
from orbit.lattice import AccLattice
17+
from orbit.utils.consts import speed_of_light
18+
19+
20+
def get_z_to_phase_coeff(bunch: Bunch, frequency: float) -> float:
21+
wavelength = speed_of_light / frequency
22+
return -360.0 / (bunch.getSyncParticle().beta() * wavelength)
23+
24+
25+
def reverse_bunch(bunch: Bunch) -> Bunch:
26+
size = bunch.getSize()
27+
for i in range(size):
28+
bunch.xp(i, -bunch.xp(i))
29+
bunch.yp(i, -bunch.yp(i))
30+
bunch.z(i, -bunch.z(i))
31+
return bunch
32+
33+
34+
def track_bunch(
35+
bunch: Bunch,
36+
lattice: AccLattice,
37+
index_start: int = None,
38+
index_stop: int = None,
39+
copy: bool = False,
40+
**kwargs
41+
) -> Bunch:
42+
"""Track bunch forward or backward through the lattice."""
43+
if index_start is None:
44+
index_start = 0
45+
46+
if index_stop is None:
47+
index_stop = len(lattice.getNodes()) - 1
48+
49+
reverse = index_start > index_stop
50+
node_start = lattice.getNodes()[index_start]
51+
node_stop = lattice.getNodes()[index_stop]
52+
53+
bunch_out = None
54+
if copy:
55+
bunch_out = Bunch()
56+
bunch.copyBunchTo(bunch_out)
57+
else:
58+
bunch_out = bunch
59+
60+
if reverse:
61+
bunch_out = reverse_bunch(bunch_out)
62+
lattice.reverseOrder()
63+
64+
lattice.trackBunch(
65+
bunch_out,
66+
index_start=lattice.getNodeIndex(node_start),
67+
index_stop=lattice.getNodeIndex(node_stop),
68+
**kwargs
69+
)
70+
71+
if reverse:
72+
bunch_out = reverse_bunch(bunch_out)
73+
lattice.reverseOrder()
74+
75+
return bunch_out
76+
77+
78+
79+
class BunchWriter:
80+
"""Writes bunch to file.
81+
82+
File name is "{output_dir}/bunch_{index}_{node_name}.dat".
83+
Example:
84+
- bunch_0001_QH05.dat
85+
- bunch_0002_QV06.dat
86+
"""
87+
88+
def __init__(self, output_dir: str = None, index: int = 0, verbose: int = 1) -> None:
89+
self.output_dir = output_dir
90+
self.index = index
91+
self.verbose = verbose
92+
self.position = 0.0
93+
94+
self.mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
95+
self.mpi_rank = orbit_mpi.MPI_Comm_rank(self.mpi_comm)
96+
97+
def __call__(
98+
self, bunch: Bunch, node_name: str = None, position: float = None, filename: str = None
99+
) -> None:
100+
if filename is None:
101+
filename = "bunch"
102+
if self.index is not None:
103+
filename = "{}_{:04.0f}".format(filename, self.index)
104+
if node_name is not None:
105+
node_name = node_name.replace(" ", "_")
106+
filename = "{}_{}".format(filename, node_name)
107+
filename = "{}.dat".format(filename)
108+
109+
filename = os.path.join(self.output_dir, filename)
110+
111+
if self.mpi_rank == 0 and self.verbose:
112+
print("Writing bunch to file {}".format(filename))
113+
114+
bunch.dumpBunch(filename)
115+
116+
if self.index is not None:
117+
self.index += 1
118+
119+
if position is not None:
120+
self.position = position
121+
122+
123+
class BunchMonitor:
124+
"""Monitors bunch within linac."""
125+
126+
def __init__(
127+
self,
128+
output_dir: str = None,
129+
stride: float = 0.1,
130+
stride_write: float = math.inf,
131+
position_offset: float = 0.0,
132+
rf_frequency: float = None,
133+
stop_node: Optional[str] = None,
134+
bunch_writer: BunchWriter = None,
135+
verbose: bool = True,
136+
) -> None:
137+
"""Constructor.
138+
139+
Args:
140+
output_dir: Path to output directory.
141+
stride: Distance between scalar bunch measurements.
142+
stride_write: Distance between saving bunch to file.
143+
position_offset: Starting position in lattice [m].
144+
rf_frequency: For converting longitudinal position to phase.
145+
stop_node: Stop at this node if provided.
146+
bunch_writer: Writes bunch to file.
147+
verbose: Whether to print updates.
148+
"""
149+
150+
# Save MPI rank
151+
self.mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
152+
self.mpi_rank = orbit_mpi.MPI_Comm_rank(self.mpi_comm)
153+
154+
# Settings
155+
self.output_dir = output_dir
156+
self.stride = stride
157+
self.stride_write = stride_write
158+
self.rf_frequency = rf_frequency
159+
self.verbose = verbose
160+
self.stop_node = stop_node
161+
162+
# State
163+
self.position = self.position_offset = position_offset
164+
self.index = 0
165+
self.start_time = None
166+
self.time_ellapsed = 0.0
167+
self.reached_stop_node = False
168+
169+
# Helpers
170+
self.bunch_writer = bunch_writer
171+
172+
# Store scalar history in `history` dictionary.
173+
if self.mpi_rank == 0:
174+
keys = [
175+
"position",
176+
"n_parts",
177+
"gamma",
178+
"beta",
179+
"energy",
180+
"x_rms",
181+
"y_rms",
182+
"z_rms",
183+
"z_rms_deg",
184+
"z_to_phase_coeff",
185+
"x_min",
186+
"x_max",
187+
"y_min",
188+
"y_max",
189+
"z_min",
190+
"z_max",
191+
"eps_x",
192+
"eps_y",
193+
"eps_z",
194+
"eps_xy",
195+
"eps_xyz",
196+
]
197+
for i in range(6):
198+
keys.append("mean_{}".format(i))
199+
for i in range(6):
200+
for j in range(i + 1):
201+
keys.append("cov_{}-{}".format(j, i))
202+
203+
self.history = {}
204+
for key in keys:
205+
self.history[key] = []
206+
207+
if self.rf_frequency is None:
208+
self.history.pop("z_rms_deg")
209+
self.history.pop("z_to_phase_coeff")
210+
211+
if self.output_dir is not None:
212+
filename = os.path.join(self.output_dir, "history.dat")
213+
self.history_file = open(filename, "w")
214+
215+
# Write header line
216+
header = "#"
217+
header = header + ",".join(keys)
218+
header = header[:-1] + "\n"
219+
self.history_file.write(header)
220+
221+
def __call__(self, params_dict: dict, force_update: bool = False) -> None:
222+
"""Measure the bunch.
223+
224+
Args:
225+
params_dict: Dictionary with the following keys:
226+
"bunch": Reference to tracked Bunch object.
227+
"path_length": Total tracking distance.
228+
"node": Reference to current AccNode object.
229+
force_update: Forces measurement update.
230+
"""
231+
# Update position; decide whether to proceed.
232+
position = params_dict["path_length"] + self.position_offset
233+
is_stop_node = (self.stop_node is not None) and (
234+
params_dict["node"].getName() == self.stop_node
235+
)
236+
237+
if force_update:
238+
pass
239+
elif is_stop_node:
240+
if self.reached_stop_node:
241+
return
242+
self.reached_stop_node = True
243+
elif self.index > 0:
244+
if (position - self.position) < self.stride:
245+
return
246+
self.position = position
247+
248+
# Update ellapsed time.
249+
if self.start_time is None:
250+
self.start_time = time.time()
251+
self.time_ellapsed = time.time() - self.start_time
252+
253+
# Collect bunch and node from parameter dictionary.
254+
bunch = params_dict["bunch"]
255+
node = params_dict["node"]
256+
257+
# Measure scalars.
258+
beta = bunch.getSyncParticle().beta()
259+
gamma = bunch.getSyncParticle().gamma()
260+
bunch_size_global = bunch.getSizeGlobal()
261+
if self.mpi_rank == 0:
262+
self.history["position"].append(position)
263+
self.history["n_parts"].append(bunch_size_global)
264+
self.history["gamma"].append(gamma)
265+
self.history["beta"].append(beta)
266+
self.history["energy"].append(bunch.getSyncParticle().kinEnergy())
267+
268+
# Measure bunch centroid and 6 x 6 covariance matrix.
269+
twiss_analysis = BunchTwissAnalysis()
270+
twiss_analysis.computeBunchMoments(bunch, 2, 0, 0)
271+
272+
centroid = np.zeros(6)
273+
for i in range(6):
274+
centroid[i] = twiss_analysis.getAverage(i)
275+
276+
cov_matrix = np.zeros((6, 6))
277+
for i in range(6):
278+
for j in range(i + 1):
279+
cov_matrix[i, j] = cov_matrix[j, i] = twiss_analysis.getCorrelation(j, i)
280+
281+
if self.mpi_rank == 0:
282+
for i in range(6):
283+
key = "mean_{}".format(i)
284+
value = centroid[i]
285+
self.history[key].append(value)
286+
287+
if self.mpi_rank == 0:
288+
for i in range(6):
289+
for j in range(i + 1):
290+
key = "cov_{}-{}".format(j, i)
291+
value = cov_matrix[j, i]
292+
self.history[key].append(value)
293+
294+
# Record other parameters derived from covariance matrix.
295+
if self.mpi_rank == 0:
296+
x_rms = math.sqrt(cov_matrix[0, 0])
297+
y_rms = math.sqrt(cov_matrix[2, 2])
298+
z_rms = math.sqrt(cov_matrix[4, 4])
299+
self.history["x_rms"].append(x_rms)
300+
self.history["y_rms"].append(y_rms)
301+
self.history["z_rms"].append(z_rms)
302+
303+
if self.rf_frequency is not None:
304+
z_to_phase_coeff = get_z_to_phase_coeff(bunch, self.rf_frequency)
305+
z_rms_deg = -z_to_phase_coeff * z_rms
306+
self.history["z_rms_deg"].append(z_rms_deg)
307+
self.history["z_to_phase_coeff"].append(z_to_phase_coeff)
308+
309+
eps_x = np.sqrt(np.linalg.det(cov_matrix[0:2, 0:2]))
310+
eps_y = np.sqrt(np.linalg.det(cov_matrix[2:4, 2:4]))
311+
eps_z = np.sqrt(np.linalg.det(cov_matrix[4:6, 4:6]))
312+
eps_xy = np.sqrt(np.linalg.det(cov_matrix[0:4, 0:4]))
313+
eps_xyz = np.sqrt(np.linalg.det(cov_matrix))
314+
self.history["eps_x"].append(eps_x)
315+
self.history["eps_y"].append(eps_y)
316+
self.history["eps_z"].append(eps_z)
317+
self.history["eps_xy"].append(eps_xy)
318+
self.history["eps_xyz"].append(eps_xyz)
319+
320+
# Measure min/max particle coordinates.
321+
extrema_calculator = BunchExtremaCalculator()
322+
(x_min, x_max, y_min, y_max, z_min, z_max) = extrema_calculator.extremaXYZ(bunch)
323+
if self.mpi_rank == 0:
324+
self.history["x_min"].append(x_min)
325+
self.history["x_max"].append(x_max)
326+
self.history["y_min"].append(y_min)
327+
self.history["y_max"].append(y_max)
328+
self.history["z_min"].append(z_min)
329+
self.history["z_max"].append(z_max)
330+
331+
# Print update
332+
if self.verbose and (self.mpi_rank == 0):
333+
message = ""
334+
message += " index={:05.0f}".format(self.index)
335+
message += " t={:0.2f}".format(self.time_ellapsed)
336+
message += " s={:0.3f}".format(self.position)
337+
message += " xrms={:0.2f}".format(x_rms * 1000.0)
338+
message += " yrms={:0.2f}".format(y_rms * 1000.0)
339+
message += " zrms={:0.2f}".format(z_rms * 1000.0)
340+
message += " size={}".format(bunch_size_global)
341+
message += " node={}".format(node.getName())
342+
print(message)
343+
sys.stdout.flush() # for MPI (bug?)
344+
345+
# Increase index
346+
self.index += 1
347+
348+
# Write phase space coordinates to file
349+
if self.bunch_writer is not None:
350+
if (position - self.bunch_writer.position) >= self.stride_write:
351+
self.bunch_writer(bunch, node_name=node.getName(), position=position)
352+
353+
# Write new line to history file
354+
if (self.mpi_rank == 0) and (self.output_dir is not None):
355+
data = [self.history[key][-1] for key in self.history]
356+
line = ""
357+
for x in data:
358+
line = line + "{},".format(x)
359+
line = line[:-1] + "\n"
360+
self.history_file.write(line)

0 commit comments

Comments
 (0)