Skip to content

Commit 41d48e4

Browse files
Merge pull request #11 from austin-hoover/dev
Organize examples
2 parents 374987a + c3c5830 commit 41d48e4

24 files changed

+5224
-5083
lines changed

examples/fit_cov_2d.py renamed to examples/cov/fit_cov_2d.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Fit 2D covariance matrix to 1D measurements."""
22
import argparse
3+
import os
4+
import pathlib
35
from typing import Callable
46
from typing import Optional
7+
58
import matplotlib.pyplot as plt
69
import numpy as np
710

@@ -13,7 +16,7 @@
1316
from ment.utils import rotation_matrix
1417

1518

16-
# Setup
19+
# Arguments
1720
# --------------------------------------------------------------------------------------
1821

1922
parser = argparse.ArgumentParser()
@@ -25,6 +28,14 @@
2528
parser.add_argument("--method", type=str, default="differential_evolution")
2629
args = parser.parse_args()
2730

31+
32+
# Setup
33+
# --------------------------------------------------------------------------------------
34+
35+
path = pathlib.Path(__file__)
36+
output_dir = os.path.join("outputs", path.stem)
37+
os.makedirs(output_dir, exist_ok=True)
38+
2839
ndim = 2
2940

3041

@@ -84,7 +95,7 @@
8495
print(fit_results)
8596

8697
# Plot results
87-
x = fitter.sample(10_000)
98+
x = fitter.sample(100_000)
8899
projections_pred = unravel(simulate(x, fitter.transforms, fitter.diagnostics))
89100
projections_meas = unravel(fitter.projections)
90101

@@ -100,7 +111,9 @@
100111
values_meas = projections_meas[i].values
101112
ax.plot(values_pred / values_meas.max(), color="lightgray")
102113
ax.plot(values_meas / values_meas.max(), color="black", lw=0.0, marker=".", ms=2.0)
103-
plt.show()
114+
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
115+
plt.close()
116+
104117

105118

106119

examples/fit_cov_4d.py renamed to examples/cov/fit_cov_4d.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Fit 4D covariance matrix to 2D measurements."""
22
import argparse
3+
import os
4+
import pathlib
35
from typing import Callable
46
from typing import Optional
5-
import matplotlib.pyplot as plt
7+
68
import numpy as np
7-
import psdist as ps
8-
import scipy.optimize
9+
from matplotlib import pyplot as plt
910
from matplotlib.patches import Ellipse
1011

1112
import ment
@@ -16,7 +17,7 @@
1617
from ment.utils import rotation_matrix
1718

1819

19-
# Setup
20+
# Arguments
2021
# --------------------------------------------------------------------------------------
2122

2223
parser = argparse.ArgumentParser()
@@ -31,6 +32,14 @@
3132
args = parser.parse_args()
3233

3334

35+
# Setup
36+
# --------------------------------------------------------------------------------------
37+
38+
path = pathlib.Path(__file__)
39+
output_dir = os.path.join("outputs", path.stem)
40+
os.makedirs(output_dir, exist_ok=True)
41+
42+
3443
# Source distribution
3544
# --------------------------------------------------------------------------------------
3645

@@ -83,7 +92,7 @@
8392
bound=1.00e+02,
8493
verbose=True,
8594
)
86-
cov_matrix, fit_result = fitter.fit(method=args.method)
95+
cov_matrix, fit_result = fitter.fit(method=args.method, iters=args.iters)
8796

8897

8998
# Print results
@@ -92,6 +101,23 @@
92101

93102

94103
# Plot results
104+
def rms_ellipse_params(cov_matrix: np.ndarray) -> tuple[float, float, float]:
105+
sii = cov_matrix[0, 0]
106+
sjj = cov_matrix[1, 1]
107+
sij = cov_matrix[0, 1]
108+
109+
angle = -0.5 * np.arctan2(2 * sij, sii - sjj)
110+
111+
_sin = np.sin(angle)
112+
_cos = np.cos(angle)
113+
_sin2 = _sin**2
114+
_cos2 = _cos**2
115+
116+
c1 = np.sqrt(abs(sii * _cos2 + sjj * _sin2 - 2 * sij * _sin * _cos))
117+
c2 = np.sqrt(abs(sii * _sin2 + sjj * _cos2 + 2 * sij * _sin * _cos))
118+
return (c1, c2, angle)
119+
120+
95121
x = fitter.sample(100_000)
96122
projections_pred = unravel(simulate(x, fitter.transforms, fitter.diagnostics))
97123
projections_true = unravel(fitter.projections)
@@ -101,7 +127,7 @@
101127
fig, axs = plt.subplots(
102128
ncols=ncols,
103129
nrows=nrows,
104-
figsize=(1.5 * ncols, 1.1 * nrows),
130+
figsize=(1.1 * ncols, 1.1 * nrows),
105131
constrained_layout=True,
106132
sharex=True,
107133
sharey=True,
@@ -115,14 +141,13 @@
115141
color = ["white", "red"][i]
116142
ls = ["-", "-"][i]
117143

118-
cx, cy, angle = ps.cov.rms_ellipse_params(proj.cov(), axis=(0, 1))
144+
cx, cy, angle = rms_ellipse_params(proj.cov())
119145
angle = -np.degrees(angle)
120146
center = (0.0, 0.0)
121147
cx *= 4.0
122148
cy *= 4.0
123149
ax.add_patch(Ellipse(center, cx, cy, angle=angle, color=color, fill=False, ls=ls))
124-
plt.show()
125-
126-
150+
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
151+
plt.close()
127152

128153

examples/fit_cov_nd.py renamed to examples/cov/fit_cov_nd.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Fit ND covariance matrix to random 1D projections."""
22
import argparse
3+
import os
4+
import pathlib
35
from typing import Callable
46
from typing import Optional
5-
import matplotlib.pyplot as plt
7+
68
import numpy as np
7-
import psdist as ps
8-
import scipy.optimize
9+
from matplotlib import pyplot as plt
910
from matplotlib.patches import Ellipse
1011

1112
import ment
@@ -16,7 +17,7 @@
1617
from ment.utils import rotation_matrix
1718

1819

19-
# Setup
20+
# Arguments
2021
# --------------------------------------------------------------------------------------
2122

2223
parser = argparse.ArgumentParser()
@@ -34,6 +35,14 @@
3435
args = parser.parse_args()
3536

3637

38+
# Setup
39+
# --------------------------------------------------------------------------------------
40+
41+
path = pathlib.Path(__file__)
42+
output_dir = os.path.join("outputs", path.stem)
43+
os.makedirs(output_dir, exist_ok=True)
44+
45+
3746
# Source distribution
3847
# --------------------------------------------------------------------------------------
3948

@@ -111,7 +120,7 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
111120
fig, axs = plt.subplots(
112121
ncols=ncols,
113122
nrows=nrows,
114-
figsize=(ncols * 11.0, nrows * 1.0),
123+
figsize=(ncols * 1.1, nrows * 1.1),
115124
sharey=True,
116125
sharex=True,
117126
constrained_layout=True
@@ -121,7 +130,8 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
121130
values_meas = projections_meas[i].values
122131
ax.plot(values_pred / values_meas.max(), color="lightgray")
123132
ax.plot(values_meas / values_meas.max(), color="black", lw=0.0, marker=".", ms=2.0)
124-
plt.show()
133+
plt.savefig(os.path.join(output_dir, "fig_results.png"), dpi=300)
134+
plt.close()
125135

126136

127137

examples/ct/plot_image.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
import sys
1+
import os
22
import matplotlib.pyplot as plt
3-
import numpy as np
4-
import skimage as ski
53

64
from utils import gen_image
75

8-
name = None
6+
7+
os.makedirs("outputs/plot_image", exist_ok=True)
8+
9+
name = "tree"
910
if len(sys.argv) > 1:
1011
name = sys.argv[1]
1112

1213
im = gen_image(name)
1314

1415
fig, ax = plt.subplots()
1516
ax.pcolormesh(im.T)
16-
plt.show()
17+
plt.savefig(f"outputs/plot_image/fig_{name}.png", dpi=300)

examples/rec_long.py renamed to examples/longitudinal/train.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def set_bunch_coords(bunch: Bunch, x: np.ndarray, axis: tuple[int, ...] = None)
8585
if axis is None:
8686
axis = tuple(range(6))
8787

88-
# Resize
8988
size = x.shape[0]
9089
size_error = size - bunch.getSize()
9190
if size_error > 0:
@@ -137,7 +136,6 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
137136
drift_node_1.setLength(124.0)
138137
drift_node_2.setLength(124.0)
139138

140-
141139
z_to_phi = 2.0 * np.pi / 248.0
142140
rf_hnum = 1.0
143141
rf_length = 0.0
@@ -162,6 +160,13 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
162160
bunch.addParticle(0.0, 0.0, 0.0, 0.0, x_true[i, 0], x_true[i, 1])
163161

164162

163+
# Evolve forward a few turns; this will be our ground-truth distribution.
164+
for _ in range(250):
165+
lattice.trackBunch(bunch)
166+
167+
x_true = get_bunch_coords(bunch, axis=(4, 5))
168+
169+
165170
# Create transform functions
166171
turn_min = 0
167172
turn_max = 500
@@ -189,23 +194,19 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
189194
# Training data
190195
# --------------------------------------------------------------------------------------
191196

192-
# Here we simulate the projections; in real life the projections would be measured.
193197
projections = ment.sim.simulate(x_true, transforms, diagnostics)
194198

195199

196200
# Reconstruction model
197201
# --------------------------------------------------------------------------------------
198202

199-
# Define prior distribution for relative entropy calculation
200203
prior = ment.prior.GaussianPrior(ndim=2, scale=[200.0, 0.020])
201204

202-
# Define particle sampler (if mode="sample")
203205
sampler = ment.samp.GridSampler(
204206
grid_limits=limits,
205207
grid_shape=(128, 128),
206208
)
207209

208-
# Set up MENT model
209210
model = ment.MENT(
210211
ndim=ndim,
211212
transforms=transforms,
@@ -254,10 +255,11 @@ def plot_model(model):
254255
fig.savefig(os.path.join(output_dir, f"fig_proj_{epoch:02.0f}.png"))
255256
plt.close()
256257

258+
257259
# Plot final distribution
258260
x_pred = model.sample(x_true.shape[0])
259261

260-
fig, axs = plt.subplots(ncols=2, constrained_layout=True)
262+
fig, axs = plt.subplots(ncols=2, figsize=(6, 3), constrained_layout=True)
261263
for ax, x in zip(axs, [x_pred, x_true]):
262264
ax.hist2d(x[:, 0], x[:, 1], bins=100, range=limits)
263265
fig.savefig(os.path.join(output_dir, f"fig_dist_{epoch:02.0f}.png"))

0 commit comments

Comments
 (0)