Skip to content

Commit a060cf8

Browse files
Merge pull request #23 from austin-hoover/dev
Add pre-commit hooks
2 parents 0802bbf + e16467c commit a060cf8

File tree

33 files changed

+463
-308
lines changed

33 files changed

+463
-308
lines changed

.pre-commit-config.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v2.3.0
4+
hooks:
5+
- id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/psf/black
9+
rev: 22.10.0
10+
hooks:
11+
- id: black

examples/cov/fit_cov_2d.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@
7878

7979
# Fit covariance matrix
8080
# --------------------------------------------------------------------------------------
81-
81+
8282
fitter = ment.CholeskyCovFitter(
83-
ndim=ndim,
83+
ndim=ndim,
8484
transforms=transforms,
8585
projections=projections,
8686
nsamp=args.nsamp,
87-
bound=1.00e+06,
87+
bound=1.00e06,
8888
verbose=True,
8989
)
9090
cov_matrix, fit_results = fitter.fit(iters=args.iters, method=args.method)
@@ -100,11 +100,11 @@
100100
projections_meas = unravel(fitter.projections)
101101

102102
fig, axs = plt.subplots(
103-
ncols=args.nmeas,
104-
figsize=(11.0, 1.0),
103+
ncols=args.nmeas,
104+
figsize=(11.0, 1.0),
105105
sharey=True,
106106
sharex=True,
107-
constrained_layout=True
107+
constrained_layout=True,
108108
)
109109
for i, ax in enumerate(axs):
110110
values_pred = projections_pred[i].values
@@ -113,8 +113,3 @@
113113
ax.plot(values_meas / values_meas.max(), color="black", lw=0.0, marker=".", ms=2.0)
114114
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
115115
plt.close()
116-
117-
118-
119-
120-

examples/cov/fit_cov_4d.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@
8585

8686
# Run optimizer
8787
fitter = ment.CholeskyCovFitter(
88-
ndim=ndim,
88+
ndim=ndim,
8989
transforms=transforms,
9090
projections=projections,
9191
nsamp=args.nsamp,
92-
bound=1.00e+02,
92+
bound=1.00e02,
9393
verbose=True,
9494
)
9595
cov_matrix, fit_result = fitter.fit(method=args.method, iters=args.iters)
@@ -105,14 +105,14 @@ def rms_ellipse_params(cov_matrix: np.ndarray) -> tuple[float, float, float]:
105105
sii = cov_matrix[0, 0]
106106
sjj = cov_matrix[1, 1]
107107
sij = cov_matrix[0, 1]
108-
108+
109109
angle = -0.5 * np.arctan2(2 * sij, sii - sjj)
110-
110+
111111
_sin = np.sin(angle)
112112
_cos = np.cos(angle)
113113
_sin2 = _sin**2
114114
_cos2 = _cos**2
115-
115+
116116
c1 = np.sqrt(abs(sii * _cos2 + sjj * _sin2 - 2 * sij * _sin * _cos))
117117
c2 = np.sqrt(abs(sii * _sin2 + sjj * _cos2 + 2 * sij * _sin * _cos))
118118
return (c1, c2, angle)
@@ -125,9 +125,9 @@ def rms_ellipse_params(cov_matrix: np.ndarray) -> tuple[float, float, float]:
125125
ncols = min(args.nmeas, 7)
126126
nrows = int(np.ceil(args.nmeas / ncols))
127127
fig, axs = plt.subplots(
128-
ncols=ncols,
129-
nrows=nrows,
130-
figsize=(1.1 * ncols, 1.1 * nrows),
128+
ncols=ncols,
129+
nrows=nrows,
130+
figsize=(1.1 * ncols, 1.1 * nrows),
131131
constrained_layout=True,
132132
sharex=True,
133133
sharey=True,
@@ -140,14 +140,14 @@ def rms_ellipse_params(cov_matrix: np.ndarray) -> tuple[float, float, float]:
140140
for i, proj in enumerate([proj_true, proj_pred]):
141141
color = ["white", "red"][i]
142142
ls = ["-", "-"][i]
143-
143+
144144
cx, cy, angle = rms_ellipse_params(proj.cov())
145145
angle = -np.degrees(angle)
146146
center = (0.0, 0.0)
147147
cx *= 4.0
148148
cy *= 4.0
149-
ax.add_patch(Ellipse(center, cx, cy, angle=angle, color=color, fill=False, ls=ls))
149+
ax.add_patch(
150+
Ellipse(center, cx, cy, angle=angle, color=color, fill=False, ls=ls)
151+
)
150152
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
151153
plt.close()
152-
153-

examples/cov/fit_cov_nd.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
# Forward model
6060
# --------------------------------------------------------------------------------------
6161

62+
6263
class ProjectionTransform:
6364
def __init__(self, direction: np.ndarray) -> None:
6465
self.direction = direction
@@ -94,11 +95,11 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
9495

9596
# Run optimizer
9697
fitter = ment.CholeskyCovFitter(
97-
ndim=ndim,
98+
ndim=ndim,
9899
transforms=transforms,
99100
projections=projections,
100101
nsamp=args.nsamp,
101-
bound=1.00e+02,
102+
bound=1.00e02,
102103
verbose=args.verbose,
103104
)
104105
cov_matrix, fit_result = fitter.fit(method=args.method, iters=args.iters)
@@ -120,10 +121,10 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
120121
fig, axs = plt.subplots(
121122
ncols=ncols,
122123
nrows=nrows,
123-
figsize=(ncols * 1.1, nrows * 1.1),
124+
figsize=(ncols * 1.1, nrows * 1.1),
124125
sharey=True,
125126
sharex=True,
126-
constrained_layout=True
127+
constrained_layout=True,
127128
)
128129
for i, ax in enumerate(axs.flat):
129130
values_pred = projections_pred[i].values
@@ -132,8 +133,3 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
132133
ax.plot(values_meas / values_meas.max(), color="black", lw=0.0, marker=".", ms=2.0)
133134
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
134135
plt.close()
135-
136-
137-
138-
139-

examples/ct/train.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030

3131
parser = argparse.ArgumentParser()
3232
parser.add_argument(
33-
"--im",
34-
type=str,
35-
default="tree",
36-
choices=["shepp", "leaf", "tree", "brain"]
33+
"--im", type=str, default="tree", choices=["shepp", "leaf", "tree", "brain"]
3734
)
3835
parser.add_argument("--im-blur", type=float, default=0.0)
3936
parser.add_argument("--im-pad", type=int, default=0)
@@ -126,7 +123,7 @@
126123
# import scipy.interpolate
127124

128125
# interp = scipy.interpolate.RegularGridInterpolator(
129-
# grid_coords,
126+
# grid_coords,
130127
# grid_values_true,
131128
# method="linear",
132129
# fill_value=0.0,
@@ -150,7 +147,6 @@
150147
# projections.append([projection])
151148

152149

153-
154150
# Plot sinogram
155151
fig, ax = plt.subplots()
156152
ax.pcolormesh(sinogram)
@@ -228,7 +224,7 @@ def evaluate_model(model: ment.MENT) -> dict:
228224

229225
discrepancy = np.mean(np.abs(sinogram_pred - sinogram_true))
230226

231-
# Absolute entropy
227+
# Absolute entropy
232228
p = values_pred
233229
q = np.ones(p.shape)
234230
q = q / np.sum(q) / cell_volume
@@ -397,7 +393,7 @@ def evaluate_model(model: ment.MENT) -> dict:
397393
for j, name in enumerate(results):
398394
for i, key in enumerate(["image", "sinogram"]):
399395
ax = axs[i, j]
400-
image = results[name][key]
396+
image = results[name][key]
401397
ax.pcolormesh(image.T, vmin=vmin, vmax=vmax)
402398
for j, name in enumerate(results):
403399
axs[0, j].set_title(name.upper())

examples/ct/utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,20 @@ def gen_image(key: str, res: int = None, blur: float = 0.0, pad: int = 0) -> Non
3939
if key == "tree":
4040
pad = max(pad, 25)
4141

42-
if pad:
42+
if pad:
4343
shape = image.shape
4444
new_shape = tuple(np.add(shape, pad * 2))
4545
new_image = np.zeros(new_shape)
4646
new_image[pad:-pad, pad:-pad] = image.copy()
4747
image = new_image.copy()
48-
48+
4949
if res:
5050
shape = (res, res)
5151
image = skimage.transform.resize(image, shape, anti_aliasing=True)
5252

5353
if blur:
5454
image = skimage.filters.gaussian(image, blur)
55-
55+
5656
return image
5757

5858

@@ -64,7 +64,9 @@ def radon_transform(image: np.ndarray, angles: np.ndarray) -> np.ndarray:
6464
return sinogram
6565

6666

67-
def rec_sart(sinogram: np.ndarray, angles: np.ndarray, iterations: int = 1) -> np.ndarray:
67+
def rec_sart(
68+
sinogram: np.ndarray, angles: np.ndarray, iterations: int = 1
69+
) -> np.ndarray:
6870
theta = -np.copy(np.degrees(angles))
6971
image = skimage.transform.iradon_sart(sinogram, theta=theta)
7072
for _ in range(iterations - 1):
@@ -73,8 +75,10 @@ def rec_sart(sinogram: np.ndarray, angles: np.ndarray, iterations: int = 1) -> n
7375
return image
7476

7577

76-
def rec_fbp(sinogram: np.ndarray, angles: np.ndarray, iterations: int = 1) -> np.ndarray:
78+
def rec_fbp(
79+
sinogram: np.ndarray, angles: np.ndarray, iterations: int = 1
80+
) -> np.ndarray:
7781
theta = -np.copy(np.degrees(angles))
7882
image = skimage.transform.iradon(sinogram, theta=theta)
7983
image = image.T
80-
return image
84+
return image

examples/longitudinal/train.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Reconstruct longitudinal phase space distribution from turn-by-turn projections.
22
3-
This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice
3+
This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice
44
model consisting of a harmonic RF cavity surrounded by two drifts. Things are a bit slow
55
because we have to repeatedly convert between NumPy arrays and Bunch objects, but it works.
66
77
Note that one MENT iteration requires simulating all projectionos. If projectiono k
88
is measured after k turns, then we must first track the bunch 1 turn, then resample
99
and track 2 turns, then resample and track 3 turns, etc. In total, we must track
10-
n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better
10+
n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better
1111
option.
1212
"""
1313
import os
@@ -51,6 +51,7 @@
5151
# Forward model
5252
# --------------------------------------------------------------------------------------
5353

54+
5455
def get_part_coords(bunch: Bunch, index: int) -> list[float]:
5556
x = bunch.x(index)
5657
y = bunch.y(index)
@@ -81,7 +82,9 @@ def get_bunch_coords(bunch: Bunch, axis: tuple[int, ...] = None) -> np.ndarray:
8182
return x
8283

8384

84-
def set_bunch_coords(bunch: Bunch, x: np.ndarray, axis: tuple[int, ...] = None) -> Bunch:
85+
def set_bunch_coords(
86+
bunch: Bunch, x: np.ndarray, axis: tuple[int, ...] = None
87+
) -> Bunch:
8588
if axis is None:
8689
axis = tuple(range(6))
8790

@@ -128,7 +131,7 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
128131
bunch = self.track_bunch()
129132
x_out = get_bunch_coords(bunch, axis=self.axis)
130133
return x_out
131-
134+
132135

133136
# Create accelerator lattice (drift, rf, drift)
134137
drift_node_1 = DriftTEAPOT()
@@ -142,7 +145,9 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
142145
rf_synchronous_de = 0.0
143146
rf_voltage = 300.0e-06
144147
rf_phase = 0.0
145-
rf_node = Harmonic_RFNode(z_to_phi, rf_synchronous_de, rf_hnum, rf_voltage, rf_phase, rf_length)
148+
rf_node = Harmonic_RFNode(
149+
z_to_phi, rf_synchronous_de, rf_hnum, rf_voltage, rf_phase, rf_length
150+
)
146151

147152
lattice = TEAPOT_Ring()
148153
lattice.addNode(drift_node_1)
@@ -178,11 +183,8 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
178183
transform = ORBITTransform(lattice, bunch, nturns=nturns, axis=(4, 5))
179184
transforms.append(transform)
180185

181-
limits = [
182-
(-0.5 * lattice.getLength(), +0.5 * lattice.getLength()),
183-
(-0.030, 0.030)
184-
]
185-
186+
limits = [(-0.5 * lattice.getLength(), +0.5 * lattice.getLength()), (-0.030, 0.030)]
187+
186188
# Create a list of histogram diagnostics for each transform.
187189
bin_edges = np.linspace(limits[0][0], limits[0][1], 100)
188190
diagnostics = []
@@ -235,19 +237,25 @@ def plot_model(model):
235237
projections_pred = ment.utils.unravel(projections_pred)
236238

237239
fig, axs = plt.subplots(
238-
ncols=nmeas, figsize=(11.0, 1.0), sharey=True, sharex=True, constrained_layout=True
240+
ncols=nmeas,
241+
figsize=(11.0, 1.0),
242+
sharey=True,
243+
sharex=True,
244+
constrained_layout=True,
239245
)
240246
for i, ax in enumerate(axs):
241247
values_pred = projections_pred[i].values
242248
values_true = projections_true[i].values
243249
ax.plot(values_pred / values_true.max(), color="lightgray")
244-
ax.plot(values_true / values_true.max(), color="black", lw=0.0, marker=".", ms=2.0)
250+
ax.plot(
251+
values_true / values_true.max(), color="black", lw=0.0, marker=".", ms=2.0
252+
)
245253
return fig
246254

247255

248256
for epoch in range(4):
249257
print("epoch =", epoch)
250-
258+
251259
if epoch > 0:
252260
model.gauss_seidel_step(learning_rate=0.90)
253261

@@ -264,4 +272,3 @@ def plot_model(model):
264272
ax.hist2d(x[:, 0], x[:, 1], bins=100, range=limits)
265273
fig.savefig(os.path.join(output_dir, f"fig_dist_{epoch:02.0f}.png"))
266274
plt.close()
267-

examples/nonlinear_ring_4d/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@
55
In this example we measure the $x-p_x$ and $y-p_y$ distribution after every 20 turns in a linear periodic lattice + nonlinear rotationally symmetric kick [1]. MENT finds an initial 4D phase space distribution consistent with these projections.
66

77
[1] https://arxiv.org/abs/2405.05657
8-
9-

0 commit comments

Comments
 (0)