Skip to content

Commit a2464c9

Browse files
committed
black
1 parent abb9dcb commit a2464c9

File tree

11 files changed

+60
-138
lines changed

11 files changed

+60
-138
lines changed

orbit_tools/bunch.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def get_bunch_coords(bunch: Bunch, axis: tuple[int, ...] = None) -> np.ndarray:
5454
return X[:, axis]
5555

5656

57-
def set_bunch_coords(
58-
bunch: Bunch, X: np.ndarray, axis: tuple[int, ...] = None
59-
) -> Bunch:
57+
def set_bunch_coords(bunch: Bunch, X: np.ndarray, axis: tuple[int, ...] = None) -> Bunch:
6058
if axis is None:
6159
axis = tuple(range(6))
6260

@@ -115,9 +113,7 @@ def shift_bunch_centroid(bunch: Bunch, offset: np.ndarray) -> Bunch:
115113
return bunch
116114

117115

118-
def set_bunch_cov(
119-
bunch: Bunch, covariance_matrix: np.ndarray, block_diag: bool = True
120-
) -> Bunch:
116+
def set_bunch_cov(bunch: Bunch, covariance_matrix: np.ndarray, block_diag: bool = True) -> Bunch:
121117
X_old = get_bunch_coords(bunch)
122118
S_old = np.cov(X_old.T)
123119

@@ -136,9 +132,7 @@ def set_bunch_cov(
136132
return bunch
137133

138134

139-
def transform_bunch(
140-
bunch: Bunch, transform: Callable, axis: tuple[int, ...] = None
141-
) -> Bunch:
135+
def transform_bunch(bunch: Bunch, transform: Callable, axis: tuple[int, ...] = None) -> Bunch:
142136
if axis is None:
143137
axis = tuple(range(6))
144138

@@ -147,9 +141,7 @@ def transform_bunch(
147141
return set_bunch_coords(bunch, X)
148142

149143

150-
def transform_bunch_linear(
151-
bunch: Bunch, matrix: np.ndarray, axis: tuple[int, ...] = None
152-
) -> Bunch:
144+
def transform_bunch_linear(bunch: Bunch, matrix: np.ndarray, axis: tuple[int, ...] = None) -> Bunch:
153145
return transform_bunch(bunch, lambda x: np.matmul(x, matrix.T), axis=axis)
154146

155147

@@ -170,9 +162,7 @@ def set_bunch_current(bunch: Bunch, current: float, frequency: float) -> Bunch:
170162
171163
Assumes bunch charge is already set.
172164
"""
173-
intensity = current_to_intensity(
174-
current=current, frequency=frequency, charge=bunch.charge()
175-
)
165+
intensity = current_to_intensity(current=current, frequency=frequency, charge=bunch.charge())
176166
bunch_size_global = bunch.getSizeGlobal()
177167
if bunch_size_global > 0:
178168
macro_size = intensity / bunch_size_global
@@ -315,9 +305,7 @@ def get_bunch_cov(
315305
return S
316306

317307

318-
def generate_bunch(
319-
sample: Callable, size: int, bunch: Bunch = None, verbose: bool = True
320-
) -> Bunch:
308+
def generate_bunch(sample: Callable, size: int, bunch: Bunch = None, verbose: bool = True) -> Bunch:
321309
"""Generate bunch from particle sampler..
322310
323311
Parameters

orbit_tools/cov.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,7 @@ def normalization_matrix_from_twiss_2d(
9999
return np.linalg.inv(V)
100100

101101

102-
def normalization_matrix_from_twiss(
103-
twiss_params: list[tuple[float, float, float]]
104-
) -> np.ndarray:
102+
def normalization_matrix_from_twiss(twiss_params: list[tuple[float, float, float]]) -> np.ndarray:
105103
"""2N x 2N block-diagonal normalization matrix from Twiss parameters.
106104
107105
Parameters
@@ -117,9 +115,7 @@ def normalization_matrix_from_twiss(
117115
ndim = len(twiss_params) // 2
118116
V = np.zeros((ndim, ndim))
119117
for i in range(0, ndim, 2):
120-
V[i : i + 2, i : i + 2] = normalization_matrix_from_twiss_2d(
121-
*twiss_params[i : i + 2]
122-
)
118+
V[i : i + 2, i : i + 2] = normalization_matrix_from_twiss_2d(*twiss_params[i : i + 2])
123119
return np.linalg.inv(V)
124120

125121

@@ -175,9 +171,7 @@ def cov_to_corr(S: np.ndarray) -> np.ndarray:
175171
return np.linalg.multi_dot([Dinv, S, Dinv])
176172

177173

178-
def rms_ellipse_params(
179-
S: np.ndarray, axis: tuple[int, ...] = None
180-
) -> tuple[float, ...]:
174+
def rms_ellipse_params(S: np.ndarray, axis: tuple[int, ...] = None) -> tuple[float, ...]:
181175
"""Return projected rms ellipse dimensions and orientation.
182176
183177
Parameters

orbit_tools/diag.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,16 @@
1616
from orbit.lattice import AccNode
1717

1818

19-
def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
19+
def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
2020
return np.vstack([C.ravel() for C in np.meshgrid(*coords, indexing="ij")]).T
2121

2222

2323
def make_grid(
24-
shape: tuple[int, ...],
25-
limits: list[tuple[float, float]]
24+
shape: tuple[int, ...], limits: list[tuple[float, float]]
2625
) -> Union[Grid1D, Grid2D, Grid3D]:
27-
26+
2827
ndim = len(shape)
29-
28+
3029
grid = None
3130
if ndim == 1:
3231
grid = Grid1D(shape[0] + 1, limits[0][0], limits[0][1])
@@ -52,7 +51,7 @@ def make_grid(
5251
raise ValueError
5352

5453
return grid
55-
54+
5655

5756
class Diagnostic:
5857
def __init__(self, output_dir: str = None, verbose: bool = True) -> None:
@@ -96,7 +95,9 @@ def __init__(
9695
self.ndim = len(axis)
9796

9897
if self.ndim > 2:
99-
raise NotImplementedError("BunchHistogram does not yet support 3D grids. See https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/46 and https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/47.")
98+
raise NotImplementedError(
99+
"BunchHistogram does not yet support 3D grids. See https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/46 and https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/47."
100+
)
100101

101102
self.dims = ["x", "xp", "y", "yp", "z", "dE"]
102103
self.dims = [self.dims[i] for i in self.axis]
@@ -107,28 +108,28 @@ def __init__(
107108
np.linspace(self.limits[i][0], self.limits[i][1], self.shape[i] + 1)
108109
for i in range(self.ndim)
109110
]
110-
self.coords = [0.5 * (e[:-1] + e[1:]) for e in self.edges]
111+
self.coords = [0.5 * (e[:-1] + e[1:]) for e in self.edges]
111112
self.values = np.zeros(shape)
112-
113+
113114
self.points = get_grid_points(self.coords)
114115
self.cell_volume = np.prod([e[1] - e[0] for e in self.edges])
115116

116117
self.grid = make_grid(self.shape, self.limits)
117118
self.method = method
118119
self.transform = transform
119120
self.normalize = normalize
120-
121+
121122
def reset(self) -> None:
122123
self.grid.setZero()
123124

124125
def sync_mpi(self) -> None:
125126
self.grid.synchronizeMPI(self.mpi_comm)
126127

127-
def bin_bunch(self, bunch: Bunch) -> None:
128+
def bin_bunch(self, bunch: Bunch) -> None:
128129
macrosize = bunch.macroSize()
129130
if macrosize == 0:
130131
bunch.macroSize(1.0)
131-
132+
132133
if self.method == "bilinear":
133134
self.grid.binBunchBilinear(bunch, *self.axis)
134135
else:
@@ -139,7 +140,7 @@ def bin_bunch(self, bunch: Bunch) -> None:
139140
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
140141
self.bin_bunch(bunch)
141142
self.sync_mpi()
142-
143+
143144
values = np.zeros(self.points.shape[0])
144145
if self.method == "bilinear":
145146
for i, point in enumerate(self.points):
@@ -158,7 +159,7 @@ def compute_histogram(self, bunch: Bunch) -> np.ndarray:
158159
values /= values_sum
159160
values /= self.cell_volume
160161
return values
161-
162+
162163
def track(self, params_dict: dict) -> None:
163164
bunch_copy = Bunch()
164165
bunch = params_dict["bunch"]
@@ -186,7 +187,7 @@ def get_filename(self) -> str:
186187
class BunchHistogram1D(BunchHistogram):
187188
def __init__(self, **kwargs) -> None:
188189
super().__init__(**kwargs)
189-
190+
190191

191192
class BunchHistogram2D(BunchHistogram):
192193
def __init__(self, **kwargs) -> None:
@@ -195,4 +196,4 @@ def __init__(self, **kwargs) -> None:
195196

196197
class BunchHistogram3D(BunchHistogram):
197198
def __init__(self, **kwargs) -> None:
198-
super().__init__(**kwargs)
199+
super().__init__(**kwargs)

orbit_tools/hydra.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
1212
repo = git.Repo(search_parent_directories=True)
1313
sha = repo.head.object.hexsha
1414

15-
output_dir = os.path.join(
16-
config.hydra.runtime.output_dir, config.hydra.output_subdir
17-
)
15+
output_dir = os.path.join(config.hydra.runtime.output_dir, config.hydra.output_subdir)
1816
filename = os.path.join(output_dir, "git_sha.txt")
1917

2018
with open(filename, "w") as file:

orbit_tools/linac/core.py

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def unnormalize_beta_z(mass: float, kin_energy: float, beta_z: float) -> float:
6464
return beta_z
6565

6666

67-
def get_node_info(
68-
name: str = None, position: float = None, lattice: AccLattice = None
69-
) -> dict:
67+
def get_node_info(name: str = None, position: float = None, lattice: AccLattice = None) -> dict:
7068
"""Return node, node index, start and stop position from node name or center position.
7169
7270
Returns dict:
@@ -228,9 +226,7 @@ def make_phase_aperture_node(
228226
return aperture_node
229227

230228

231-
def make_energy_aperture_node(
232-
energy_min: float, energy_max: float
233-
) -> LinacEnergyApertureNode:
229+
def make_energy_aperture_node(energy_min: float, energy_max: float) -> LinacEnergyApertureNode:
234230
aperture_node = LinacEnergyApertureNode()
235231
aperture_node.setMinMaxEnergy(energy_min, energy_max)
236232
return aperture_node
@@ -256,9 +252,7 @@ def check_sync_time(
256252
sync_time_design = 0.0
257253

258254
if start_node_info["index"] > 0:
259-
design_bunch = lattice.trackDesignBunch(
260-
bunch, index_start=0, index_stop=start["index"]
261-
)
255+
design_bunch = lattice.trackDesignBunch(bunch, index_start=0, index_stop=start["index"])
262256
sync_time_design = design_bunch.getSyncParticle().time()
263257

264258
if _mpi_rank == 0 and verbose:
@@ -271,11 +265,7 @@ def check_sync_time(
271265
print(" Setting to design value.")
272266
bunch.getSyncParticle().time(sync_time_design)
273267
if _mpi_rank == 0 and verbose:
274-
print(
275-
"bunch.getSyncParticle().time() = {}".format(
276-
bunch.getSyncParticle().time()
277-
)
278-
)
268+
print("bunch.getSyncParticle().time() = {}".format(bunch.getSyncParticle().time()))
279269

280270

281271
def estimate_transfer_matrix(
@@ -309,9 +299,7 @@ def estimate_transfer_matrix(
309299
(-0.000, 0.000),
310300
]
311301
test_bunch_lb, test_bunch_ub = list(zip(*test_bunch_limits))
312-
test_bunch_coords = rng.uniform(
313-
test_bunch_lb, test_bunch_ub, size=(test_bunch_size, 6)
314-
)
302+
test_bunch_coords = rng.uniform(test_bunch_lb, test_bunch_ub, size=(test_bunch_size, 6))
315303

316304
test_bunch = Bunch()
317305
bunch.copyEmptyBunchTo(test_bunch)
@@ -327,21 +315,15 @@ def estimate_transfer_matrix(
327315
return matrix
328316

329317

330-
def save_node_positions(
331-
lattice: AccLattice, filename: str = "lattice_nodes.txt"
332-
) -> None:
318+
def save_node_positions(lattice: AccLattice, filename: str = "lattice_nodes.txt") -> None:
333319
file = open(filename, "w")
334320
file.write("node position length\n")
335321
for node in lattice.getNodes():
336-
file.write(
337-
"{} {} {}\n".format(node.getName(), node.getPosition(), node.getLength())
338-
)
322+
file.write("{} {} {}\n".format(node.getName(), node.getPosition(), node.getLength()))
339323
file.close()
340324

341325

342-
def save_lattice_structure(
343-
lattice: AccLattice, filename: str = "lattice_structure.txt"
344-
) -> None:
326+
def save_lattice_structure(lattice: AccLattice, filename: str = "lattice_structure.txt") -> None:
345327
file = open(filename, "w")
346328
file.write(lattice.structureToText())
347329
file.close()
@@ -583,9 +565,7 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
583565
# Measure mean and covariance.
584566
twiss_analysis = BunchTwissAnalysis()
585567
order = 2
586-
twiss_analysis.computeBunchMoments(
587-
bunch, order, self.dispersion_flag, self.emit_norm_flag
588-
)
568+
twiss_analysis.computeBunchMoments(bunch, order, self.dispersion_flag, self.emit_norm_flag)
589569

590570
mean = np.zeros(6)
591571
for i in range(6):
@@ -650,9 +630,7 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
650630

651631
# Measure maximum phase space coordinates.
652632
extrema_calculator = BunchExtremaCalculator()
653-
(x_min, x_max, y_min, y_max, z_min, z_max) = extrema_calculator.extremaXYZ(
654-
bunch
655-
)
633+
(x_min, x_max, y_min, y_max, z_min, z_max) = extrema_calculator.extremaXYZ(bunch)
656634
if _mpi_rank == 0:
657635
self.history["x_min"] = x_min
658636
self.history["x_max"] = x_max
@@ -691,17 +669,13 @@ def __call__(self, params_dict: dict, force_update: bool = False) -> None:
691669

692670
# Write phase space coordinates to file.
693671
if self.write is not None:
694-
if force_update or (
695-
self.stride_write <= (position - self.last_write_position)
696-
):
672+
if force_update or (self.stride_write <= (position - self.last_write_position)):
697673
self.write(bunch, tag=node.getName())
698674
self.last_write_position = position
699675

700676
# Call plotting routines.
701677
if self.plot is not None and _mpi_rank == 0:
702-
if force_update or (
703-
self.stride_plot <= (position - self.last_plot_position)
704-
):
678+
if force_update or (self.stride_plot <= (position - self.last_plot_position)):
705679
info = dict()
706680
for key in self.history:
707681
if self.history[key]:
@@ -749,9 +723,7 @@ def action(self, params_dict: dict) -> None:
749723
order = 2
750724
dispersion_flag = 0
751725
emitt_norm_flag = 0
752-
twiss_analysis.computeBunchMoments(
753-
bunch, order, dispersion_flag, emitt_norm_flag
754-
)
726+
twiss_analysis.computeBunchMoments(bunch, order, dispersion_flag, emitt_norm_flag)
755727
x_rms = np.sqrt(twiss_analysis.getCorrelation(0, 0))
756728
y_rms = np.sqrt(twiss_analysis.getCorrelation(2, 2))
757729

@@ -817,9 +789,7 @@ def forward(self, x: np.ndarray) -> np.ndarray:
817789

818790
bunch = self.get_new_bunch()
819791
bunch = set_bunch_coords(bunch, x_new, verbose=False)
820-
self.lattice.trackBunch(
821-
bunch, index_start=self.index_start, index_stop=self.index_stop
822-
)
792+
self.lattice.trackBunch(bunch, index_start=self.index_start, index_stop=self.index_stop)
823793
U = bunch.get_bunch_coords(bunch)
824794
U = U[:, self.axis]
825795
return U
@@ -835,9 +805,7 @@ def inverse(self, u: np.ndarray) -> np.ndarray:
835805
bunch = oset_bunch_coords(bunch, u_new, verbose=False)
836806
bunch = reverse_bunch(bunch)
837807
self.lattice.reverseOrder()
838-
self.lattice.trackBunch(
839-
bunch, index_start=self.index_stop, index_stop=self.index_start
840-
)
808+
self.lattice.trackBunch(bunch, index_start=self.index_stop, index_stop=self.index_start)
841809
self.lattice.reverseOrder()
842810
bunch = reverse_bunch(bunch)
843811
x = get_bunch_coords(bunch)

orbit_tools/linac/diag.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ def __init__(self, rf_frequency: float = 402.5e06, **kwargs) -> None:
122122
self.history[key] = None
123123

124124
if self.output_dir is not None:
125-
self.history_file = open(
126-
os.path.join(self.output_dir, "history.dat"), "w"
127-
)
125+
self.history_file = open(os.path.join(self.output_dir, "history.dat"), "w")
128126
line = ",".join(list(self.history))
129127
line = line[:-1] + "\n"
130128
self.history_file.write(line)
@@ -178,9 +176,7 @@ def measure_stats(self, params_dict: dict) -> None:
178176

179177
def measure_extrema(self) -> None:
180178
extrema_calculator = BunchExtremaCalculator()
181-
(x_min, x_max, y_min, y_max, z_min, z_max) = extrema_calculator.extremaXYZ(
182-
self.bunch
183-
)
179+
(x_min, x_max, y_min, y_max, z_min, z_max) = extrema_calculator.extremaXYZ(self.bunch)
184180
if self._mpi_rank == 0:
185181
self.history["x_min"] = x_min
186182
self.history["x_max"] = x_max

0 commit comments

Comments
 (0)