@@ -64,9 +64,7 @@ def unnormalize_beta_z(mass: float, kin_energy: float, beta_z: float) -> float:
6464 return beta_z
6565
6666
67- def get_node_info (
68- name : str = None , position : float = None , lattice : AccLattice = None
69- ) -> dict :
67+ def get_node_info (name : str = None , position : float = None , lattice : AccLattice = None ) -> dict :
7068 """Return node, node index, start and stop position from node name or center position.
7169
7270 Returns dict:
@@ -228,9 +226,7 @@ def make_phase_aperture_node(
228226 return aperture_node
229227
230228
231- def make_energy_aperture_node (
232- energy_min : float , energy_max : float
233- ) -> LinacEnergyApertureNode :
229+ def make_energy_aperture_node (energy_min : float , energy_max : float ) -> LinacEnergyApertureNode :
234230 aperture_node = LinacEnergyApertureNode ()
235231 aperture_node .setMinMaxEnergy (energy_min , energy_max )
236232 return aperture_node
@@ -256,9 +252,7 @@ def check_sync_time(
256252 sync_time_design = 0.0
257253
258254 if start_node_info ["index" ] > 0 :
259- design_bunch = lattice .trackDesignBunch (
260- bunch , index_start = 0 , index_stop = start ["index" ]
261- )
255+ design_bunch = lattice .trackDesignBunch (bunch , index_start = 0 , index_stop = start ["index" ])
262256 sync_time_design = design_bunch .getSyncParticle ().time ()
263257
264258 if _mpi_rank == 0 and verbose :
@@ -271,11 +265,7 @@ def check_sync_time(
271265 print (" Setting to design value." )
272266 bunch .getSyncParticle ().time (sync_time_design )
273267 if _mpi_rank == 0 and verbose :
274- print (
275- "bunch.getSyncParticle().time() = {}" .format (
276- bunch .getSyncParticle ().time ()
277- )
278- )
268+ print ("bunch.getSyncParticle().time() = {}" .format (bunch .getSyncParticle ().time ()))
279269
280270
281271def estimate_transfer_matrix (
@@ -309,9 +299,7 @@ def estimate_transfer_matrix(
309299 (- 0.000 , 0.000 ),
310300 ]
311301 test_bunch_lb , test_bunch_ub = list (zip (* test_bunch_limits ))
312- test_bunch_coords = rng .uniform (
313- test_bunch_lb , test_bunch_ub , size = (test_bunch_size , 6 )
314- )
302+ test_bunch_coords = rng .uniform (test_bunch_lb , test_bunch_ub , size = (test_bunch_size , 6 ))
315303
316304 test_bunch = Bunch ()
317305 bunch .copyEmptyBunchTo (test_bunch )
@@ -327,21 +315,15 @@ def estimate_transfer_matrix(
327315 return matrix
328316
329317
330- def save_node_positions (
331- lattice : AccLattice , filename : str = "lattice_nodes.txt"
332- ) -> None :
318+ def save_node_positions (lattice : AccLattice , filename : str = "lattice_nodes.txt" ) -> None :
333319 file = open (filename , "w" )
334320 file .write ("node position length\n " )
335321 for node in lattice .getNodes ():
336- file .write (
337- "{} {} {}\n " .format (node .getName (), node .getPosition (), node .getLength ())
338- )
322+ file .write ("{} {} {}\n " .format (node .getName (), node .getPosition (), node .getLength ()))
339323 file .close ()
340324
341325
342- def save_lattice_structure (
343- lattice : AccLattice , filename : str = "lattice_structure.txt"
344- ) -> None :
326+ def save_lattice_structure (lattice : AccLattice , filename : str = "lattice_structure.txt" ) -> None :
345327 file = open (filename , "w" )
346328 file .write (lattice .structureToText ())
347329 file .close ()
@@ -583,9 +565,7 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
583565 # Measure mean and covariance.
584566 twiss_analysis = BunchTwissAnalysis ()
585567 order = 2
586- twiss_analysis .computeBunchMoments (
587- bunch , order , self .dispersion_flag , self .emit_norm_flag
588- )
568+ twiss_analysis .computeBunchMoments (bunch , order , self .dispersion_flag , self .emit_norm_flag )
589569
590570 mean = np .zeros (6 )
591571 for i in range (6 ):
@@ -650,9 +630,7 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
650630
651631 # Measure maximum phase space coordinates.
652632 extrema_calculator = BunchExtremaCalculator ()
653- (x_min , x_max , y_min , y_max , z_min , z_max ) = extrema_calculator .extremaXYZ (
654- bunch
655- )
633+ (x_min , x_max , y_min , y_max , z_min , z_max ) = extrema_calculator .extremaXYZ (bunch )
656634 if _mpi_rank == 0 :
657635 self .history ["x_min" ] = x_min
658636 self .history ["x_max" ] = x_max
@@ -691,17 +669,13 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
691669
692670 # Write phase space coordinates to file.
693671 if self .write is not None :
694- if force_update or (
695- self .stride_write <= (position - self .last_write_position )
696- ):
672+ if force_update or (self .stride_write <= (position - self .last_write_position )):
697673 self .write (bunch , tag = node .getName ())
698674 self .last_write_position = position
699675
700676 # Call plotting routines.
701677 if self .plot is not None and _mpi_rank == 0 :
702- if force_update or (
703- self .stride_plot <= (position - self .last_plot_position )
704- ):
678+ if force_update or (self .stride_plot <= (position - self .last_plot_position )):
705679 info = dict ()
706680 for key in self .history :
707681 if self .history [key ]:
@@ -749,9 +723,7 @@ def action(self, params_dict: dict) -> None:
749723 order = 2
750724 dispersion_flag = 0
751725 emitt_norm_flag = 0
752- twiss_analysis .computeBunchMoments (
753- bunch , order , dispersion_flag , emitt_norm_flag
754- )
726+ twiss_analysis .computeBunchMoments (bunch , order , dispersion_flag , emitt_norm_flag )
755727 x_rms = np .sqrt (twiss_analysis .getCorrelation (0 , 0 ))
756728 y_rms = np .sqrt (twiss_analysis .getCorrelation (2 , 2 ))
757729
@@ -817,9 +789,7 @@ def forward(self, x: np.ndarray) -> np.ndarray:
817789
818790 bunch = self .get_new_bunch ()
819791 bunch = set_bunch_coords (bunch , x_new , verbose = False )
820- self .lattice .trackBunch (
821- bunch , index_start = self .index_start , index_stop = self .index_stop
822- )
792+ self .lattice .trackBunch (bunch , index_start = self .index_start , index_stop = self .index_stop )
823793 U = bunch .get_bunch_coords (bunch )
824794 U = U [:, self .axis ]
825795 return U
@@ -835,9 +805,7 @@ def inverse(self, u: np.ndarray) -> np.ndarray:
835805 bunch = oset_bunch_coords (bunch , u_new , verbose = False )
836806 bunch = reverse_bunch (bunch )
837807 self .lattice .reverseOrder ()
838- self .lattice .trackBunch (
839- bunch , index_start = self .index_stop , index_stop = self .index_start
840- )
808+ self .lattice .trackBunch (bunch , index_start = self .index_stop , index_stop = self .index_start )
841809 self .lattice .reverseOrder ()
842810 bunch = reverse_bunch (bunch )
843811 x = get_bunch_coords (bunch )
0 commit comments