11"""Reconstruct longitudinal phase space distribution from turn-by-turn projections.
22
3- This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice
3+ This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice
44model consisting of a harmonic RF cavity surrounded by two drifts. Things are a bit slow
55because we have to repeatedly convert between NumPy arrays and Bunch objects, but it works.
66
77Note that one MENT iteration requires simulating all projectionos. If projectiono k
88is measured after k turns, then we must first track the bunch 1 turn, then resample
99and track 2 turns, then resample and track 3 turns, etc. In total, we must track
10- n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better
10+ n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better
1111option.
1212"""
1313import os
5151# Forward model
5252# --------------------------------------------------------------------------------------
5353
54+
5455def get_part_coords (bunch : Bunch , index : int ) -> list [float ]:
5556 x = bunch .x (index )
5657 y = bunch .y (index )
@@ -81,7 +82,9 @@ def get_bunch_coords(bunch: Bunch, axis: tuple[int, ...] = None) -> np.ndarray:
8182 return x
8283
8384
84- def set_bunch_coords (bunch : Bunch , x : np .ndarray , axis : tuple [int , ...] = None ) -> Bunch :
85+ def set_bunch_coords (
86+ bunch : Bunch , x : np .ndarray , axis : tuple [int , ...] = None
87+ ) -> Bunch :
8588 if axis is None :
8689 axis = tuple (range (6 ))
8790
@@ -128,7 +131,7 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
128131 bunch = self .track_bunch ()
129132 x_out = get_bunch_coords (bunch , axis = self .axis )
130133 return x_out
131-
134+
132135
133136# Create accelerator lattice (drift, rf, drift)
134137drift_node_1 = DriftTEAPOT ()
@@ -142,7 +145,9 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
142145rf_synchronous_de = 0.0
143146rf_voltage = 300.0e-06
144147rf_phase = 0.0
145- rf_node = Harmonic_RFNode (z_to_phi , rf_synchronous_de , rf_hnum , rf_voltage , rf_phase , rf_length )
148+ rf_node = Harmonic_RFNode (
149+ z_to_phi , rf_synchronous_de , rf_hnum , rf_voltage , rf_phase , rf_length
150+ )
146151
147152lattice = TEAPOT_Ring ()
148153lattice .addNode (drift_node_1 )
@@ -178,11 +183,8 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
178183 transform = ORBITTransform (lattice , bunch , nturns = nturns , axis = (4 , 5 ))
179184 transforms .append (transform )
180185
181- limits = [
182- (- 0.5 * lattice .getLength (), + 0.5 * lattice .getLength ()),
183- (- 0.030 , 0.030 )
184- ]
185-
186+ limits = [(- 0.5 * lattice .getLength (), + 0.5 * lattice .getLength ()), (- 0.030 , 0.030 )]
187+
186188# Create a list of histogram diagnostics for each transform.
187189bin_edges = np .linspace (limits [0 ][0 ], limits [0 ][1 ], 100 )
188190diagnostics = []
@@ -235,19 +237,25 @@ def plot_model(model):
235237 projections_pred = ment .utils .unravel (projections_pred )
236238
237239 fig , axs = plt .subplots (
238- ncols = nmeas , figsize = (11.0 , 1.0 ), sharey = True , sharex = True , constrained_layout = True
240+ ncols = nmeas ,
241+ figsize = (11.0 , 1.0 ),
242+ sharey = True ,
243+ sharex = True ,
244+ constrained_layout = True ,
239245 )
240246 for i , ax in enumerate (axs ):
241247 values_pred = projections_pred [i ].values
242248 values_true = projections_true [i ].values
243249 ax .plot (values_pred / values_true .max (), color = "lightgray" )
244- ax .plot (values_true / values_true .max (), color = "black" , lw = 0.0 , marker = "." , ms = 2.0 )
250+ ax .plot (
251+ values_true / values_true .max (), color = "black" , lw = 0.0 , marker = "." , ms = 2.0
252+ )
245253 return fig
246254
247255
248256for epoch in range (4 ):
249257 print ("epoch =" , epoch )
250-
258+
251259 if epoch > 0 :
252260 model .gauss_seidel_step (learning_rate = 0.90 )
253261
@@ -264,4 +272,3 @@ def plot_model(model):
264272 ax .hist2d (x [:, 0 ], x [:, 1 ], bins = 100 , range = limits )
265273fig .savefig (os .path .join (output_dir , f"fig_dist_{ epoch :02.0f} .png" ))
266274plt .close ()
267-
0 commit comments