Skip to content

Commit 9564ca1

Browse files
Merge pull request #9 from austin-hoover/examples
Add 4D SNS example
2 parents 18ee81e + 02651f2 commit 9564ca1

File tree

11 files changed

+1064
-0
lines changed

11 files changed

+1064
-0
lines changed

examples/sns_ring_4d_1d/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# 4D phase space tomography from 1D measurements in the SNS accumulator ring
2+
3+
<img src="saved/fig_sns_diagram.png">
4+
5+
<img src="saved/fig_diagram.png">
6+
7+
https://doi.org/10.1103/PhysRevAccelBeams.27.122802
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Plot 4D reconstruction in action space."""
2+
3+
import argparse
4+
import os
5+
import pathlib
6+
7+
import numpy as np
8+
import ment
9+
import matplotlib.pyplot as plt
10+
import scipy.ndimage
11+
12+
from tools.cov import normalization_matrix
13+
from tools.utils import list_paths
14+
15+
16+
plt.style.use("tools/style.mplstyle")
17+
18+
19+
# Arguments
20+
# --------------------------------------------------------------------------------------
21+
22+
parser = argparse.ArgumentParser()
23+
parser.add_argument("--nsamp", type=int, default=1_000_000)
24+
parser.add_argument("--nbins", type=int, default=64)
25+
parser.add_argument("--jmax", type=float, default=85.0)
26+
parser.add_argument("--contours", type=int, default=8)
27+
parser.add_argument("--show", action="store_true")
28+
args = parser.parse_args()
29+
30+
31+
# Setup
32+
# --------------------------------------------------------------------------------------
33+
34+
input_dir = "./outputs/train/"
35+
36+
path = pathlib.Path(__file__)
37+
output_dir = os.path.join("outputs", path.stem)
38+
if not os.path.exists(output_dir):
39+
os.makedirs(output_dir)
40+
41+
42+
# Load model
43+
# --------------------------------------------------------------------------------------
44+
45+
checkpoints_folder = os.path.join(input_dir, "checkpoints")
46+
checkpoint_filenames = list_paths(os.path.join(input_dir, "checkpoints"), sort=True)
47+
checkpoint_filename = checkpoint_filenames[-1]
48+
49+
model = ment.MENT(
50+
ndim=4,
51+
transforms=None,
52+
projections=None,
53+
prior=None,
54+
sampler=None,
55+
)
56+
model.load(checkpoint_filenames[-1])
57+
model.sampler.noise = 1.0
58+
59+
60+
# Sample particles from distribution
61+
# --------------------------------------------------------------------------------------
62+
63+
x = model.unnormalize(model.sample(args.nsamp))
64+
65+
cov_matrix = np.cov(x.T)
66+
norm_matrix = normalization_matrix(cov_matrix, scale=False)
67+
68+
z = np.matmul(x, norm_matrix.T)
69+
cov_matrix_n = np.cov(z.T)
70+
71+
print("covariance matrix:")
72+
print(np.round(cov_matrix, 5))
73+
74+
print("normalized covariance matrix:")
75+
print(np.round(cov_matrix_n, 5))
76+
77+
eps1 = np.sqrt(cov_matrix_n[0, 0] * cov_matrix_n[1, 1])
78+
eps2 = np.sqrt(cov_matrix_n[2, 2] * cov_matrix_n[3, 3])
79+
eps_avg = np.sqrt(eps1 * eps2) # maintains 4D emittance but equal mode amplitudes
80+
81+
z_pred = np.copy(z)
82+
83+
84+
# Plot actions
85+
# --------------------------------------------------------------------------------------
86+
87+
def make_joint_grid(figwidth=5.0, panel_width=0.33) -> tuple:
88+
fig, axs = plt.subplots(
89+
ncols=2,
90+
nrows=2,
91+
sharex="col",
92+
sharey="row",
93+
figsize=(figwidth, figwidth),
94+
gridspec_kw=dict(
95+
width_ratios=[1.0, panel_width],
96+
height_ratios=[panel_width, 1.0],
97+
),
98+
)
99+
axs[0, 1].axis("off")
100+
return fig, axs
101+
102+
103+
def plot_hist(values: np.ndarray, edges: list[np.ndarray], contours: int = 0, **plot_kws) -> tuple:
104+
fig, axs = make_joint_grid()
105+
106+
ax = axs[1, 0]
107+
ax.pcolormesh(
108+
edges[0],
109+
edges[1],
110+
values.T,
111+
linewidth=0.0,
112+
rasterized=True,
113+
shading="auto",
114+
)
115+
116+
if contours:
117+
coords = [0.5 * (e[:-1] + e[1:]) for e in edges]
118+
axs[1, 0].contour(
119+
coords[0],
120+
coords[1],
121+
scipy.ndimage.gaussian_filter(values, 1.0).T,
122+
levels=np.linspace(0.01, 1.0, contours),
123+
colors="white",
124+
linewidths=0.80,
125+
alpha=0.15,
126+
)
127+
128+
ax = axs[0, 0]
129+
proj_edges = edges[0]
130+
proj_values = np.sum(values, axis=1)
131+
proj_values = proj_values / np.max(proj_values)
132+
ax.stairs(proj_values, proj_edges, lw=1.5, color="black")
133+
ax.set_ylim(0.0, 1.25)
134+
135+
ax = axs[1, 1]
136+
proj_edges = edges[1]
137+
proj_values = np.sum(values, axis=0)
138+
proj_values = proj_values / np.max(proj_values)
139+
ax.stairs(proj_values, proj_edges, orientation="horizontal", lw=1.5, color="black")
140+
ax.set_xlim(0.0, 1.25)
141+
142+
return fig, axs
143+
144+
145+
z = z_pred
146+
j1 = np.sum(np.square(z[:, (0, 1)]), axis=1)
147+
j2 = np.sum(np.square(z[:, (2, 3)]), axis=1)
148+
149+
edges = np.linspace(0.0, args.jmax, args.nbins + 1)
150+
edges = [edges, edges]
151+
values, _, _ = np.histogram2d(j1, j2, bins=edges)
152+
values = values / np.max(values)
153+
154+
fig, axs = plot_hist(values, edges, contours=args.contours)
155+
axs[1, 0].set_xlabel(r"$J_1$")
156+
axs[1, 0].set_ylabel(r"$J_2$")
157+
158+
filename = f"fig_action.pdf"
159+
filename = os.path.join(output_dir, filename)
160+
plt.savefig(filename)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""Plot 4D reconstruction results."""
2+
import argparse
3+
import os
4+
import pathlib
5+
6+
import numpy as np
7+
import ment
8+
import matplotlib.pyplot as plt
9+
10+
# local
11+
from tools.utils import list_paths
12+
13+
plt.style.use("tools/style.mplstyle")
14+
15+
16+
# Arguments
17+
# --------------------------------------------------------------------------------------
18+
19+
parser = argparse.ArgumentParser()
20+
parser.add_argument("--nsamp", type=int, default=1_000_000)
21+
args = parser.parse_args()
22+
23+
24+
# Setup
25+
# --------------------------------------------------------------------------------------
26+
27+
input_dir = "./outputs/train/"
28+
29+
path = pathlib.Path(__file__)
30+
output_dir = os.path.join("outputs", path.stem)
31+
if not os.path.exists(output_dir):
32+
os.makedirs(output_dir)
33+
34+
35+
# Load model
36+
# --------------------------------------------------------------------------------------
37+
38+
checkpoints_folder = os.path.join(input_dir, "checkpoints")
39+
checkpoint_filenames = list_paths(os.path.join(input_dir, "checkpoints"), sort=True)
40+
checkpoint_filename = checkpoint_filenames[-1]
41+
42+
model = ment.MENT(
43+
ndim=4,
44+
transforms=None,
45+
projections=None,
46+
prior=None,
47+
sampler=None,
48+
)
49+
model.load(checkpoint_filenames[-1])
50+
model.sampler.noise = 1.0
51+
52+
53+
# Sample particles from distribution
54+
# --------------------------------------------------------------------------------------
55+
56+
x_pred = model.unnormalize(model.sample(args.nsamp))
57+
58+
59+
# Simulate data
60+
# --------------------------------------------------------------------------------------
61+
62+
projections_meas = model.projections
63+
projections_pred = ment.simulate(x_pred, model.transforms, model.diagnostics)
64+
65+
66+
# Plot measured vs. simulated projections.
67+
# --------------------------------------------------------------------------------------
68+
69+
fig, axs = plt.subplots(
70+
nrows=6,
71+
ncols=6,
72+
figsize=(6.0, 4.5),
73+
sharex=True,
74+
sharey=True,
75+
constrained_layout=True,
76+
)
77+
78+
index = 0
79+
for i in range(len(projections_meas)):
80+
for j in range(len(projections_meas[i])):
81+
proj_meas = projections_meas[i][j].copy()
82+
proj_pred = projections_pred[i][j].copy()
83+
scale = np.max(proj_meas.values)
84+
85+
ax = axs.flat[index]
86+
ax.plot(
87+
proj_pred.coords,
88+
proj_pred.values / scale,
89+
label="pred",
90+
color="red",
91+
alpha=0.3,
92+
)
93+
ax.plot(
94+
proj_meas.coords,
95+
proj_meas.values / scale,
96+
label="meas",
97+
color="black",
98+
lw=0,
99+
marker=".",
100+
ms=1.0,
101+
)
102+
ax.annotate(
103+
"{:02.0f}".format(index // 3),
104+
xy=(0.03, 0.96),
105+
xycoords="axes fraction",
106+
horizontalalignment="left",
107+
verticalalignment="top",
108+
)
109+
index += 1
110+
111+
for ax in axs.flat:
112+
ax.set_ylim(-0.05, 1.15)
113+
ax.set_yticks([])
114+
ax.set_xticks([-35.0, 0.0, 35.0])
115+
116+
axs[-1, 0].set_xlabel(r"$x$ (mm)")
117+
axs[-1, 1].set_xlabel(r"$y$ (mm)")
118+
axs[-1, 2].set_xlabel(r"$u$ (mm)")
119+
axs[-1, 3].set_xlabel(r"$x$ (mm)")
120+
axs[-1, 4].set_xlabel(r"$y$ (mm)")
121+
axs[-1, 5].set_xlabel(r"$u$ (mm)")
122+
123+
filename = "fig_profiles.pdf"
124+
filename = os.path.join(output_dir, filename)
125+
plt.savefig(filename, dpi=250)
126+
plt.close()
127+
128+
129+
# Plot 2D projections of 4D distribution
130+
# --------------------------------------------------------------------------------------
131+
132+
axes_proj = [(0, 1), (2, 3), (0, 2), (0, 3), (2, 1), (1, 3)]
133+
dims = [r"$x$", "$x'$", "$y$", "$y'$"]
134+
units = ["mm", "mrad", "mm", "mrad"]
135+
136+
xmax = 3.5 * np.std(x_pred, axis=0)
137+
limits = list(zip(-xmax, xmax))
138+
139+
fig, axs = plt.subplots(
140+
ncols=3,
141+
nrows=2,
142+
figsize=(7.0, 4.0),
143+
sharex=False,
144+
sharey=False,
145+
constrained_layout=True,
146+
)
147+
for j, axis in enumerate(axes_proj):
148+
values, edges = np.histogramdd(
149+
x_pred[:, axis],
150+
bins=64,
151+
range=[limits[k] for k in axis],
152+
density=True,
153+
)
154+
155+
ax = axs.flat[j]
156+
ax.pcolormesh(
157+
edges[0],
158+
edges[1],
159+
values.T,
160+
rasterized=True,
161+
edgecolor="None",
162+
linewidth=0.0,
163+
)
164+
165+
for j, axis in enumerate(axes_proj):
166+
ax = axs.flat[j]
167+
ax.set_xlabel(f"{dims[axis[0]]} ({units[axis[0]]})")
168+
ax.set_ylabel(f"{dims[axis[1]]} ({units[axis[1]]})")
169+
170+
filename = f"fig_proj2d.pdf"
171+
filename = os.path.join(output_dir, filename)
172+
plt.savefig(filename, dpi=250)
173+
plt.close("all")
871 KB
Loading
588 KB
Loading

examples/sns_ring_4d_1d/tools/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)