Skip to content

Commit 6a598cb

Browse files
Merge pull request #4 from austin-hoover/longitudinal-example
Longitudinal example
2 parents 13fe19d + 72c0f60 commit 6a598cb

File tree

3 files changed

+281
-4
lines changed

3 files changed

+281
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ __pycache__
33
.DS_Store
44
.idea
55
*.egg-info
6+
outputs

examples/rec_long.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""Reconstruct longitudinal phase space distribution from turn-by-turn projections.
2+
3+
This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice
4+
model consisting of a harmonic RF cavity surrounded by two drifts. Things are a bit slow
5+
because we have to repeatedly convert between NumPy arrays and Bunch objects, but it works.
6+
7+
Note that one MENT iteration requires simulating all projectionos. If projectiono k
8+
is measured after k turns, then we must first track the bunch 1 turn, then resample
9+
and track 2 turns, then resample and track 3 turns, etc. In total, we must track
10+
n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better
11+
option.
12+
"""
13+
import os
14+
import pathlib
15+
16+
import numpy as np
17+
import matplotlib.pyplot as plt
18+
import ment
19+
20+
from orbit.core.bunch import Bunch
21+
from orbit.lattice import AccLattice
22+
from orbit.lattice import AccNode
23+
from orbit.rf_cavities import Harmonic_RFNode
24+
from orbit.teapot import DriftTEAPOT
25+
from orbit.teapot import TEAPOT_Ring
26+
27+
28+
# Setup
29+
# --------------------------------------------------------------------------------------
30+
31+
ndim = 2
32+
nmeas = 7
33+
seed = 0
34+
size = 100_000
35+
36+
path = pathlib.Path(__file__)
37+
output_dir = os.path.join("outputs", path.stem)
38+
os.makedirs(output_dir, exist_ok=True)
39+
40+
41+
# Distribution
42+
# --------------------------------------------------------------------------------------
43+
44+
rng = np.random.default_rng(seed)
45+
46+
x_true = np.zeros((size, 2))
47+
x_true[:, 0] = 0.60 * rng.uniform(-124.0, 124.0, x_true.shape[0])
48+
x_true[:, 1] = rng.normal(scale=0.0025, size=x_true.shape[0])
49+
50+
51+
# Forward model
52+
# --------------------------------------------------------------------------------------
53+
54+
def get_part_coords(bunch: Bunch, index: int) -> list[float]:
55+
x = bunch.x(index)
56+
y = bunch.y(index)
57+
z = bunch.z(index)
58+
xp = bunch.xp(index)
59+
yp = bunch.yp(index)
60+
de = bunch.dE(index)
61+
return [x, xp, y, yp, z, de]
62+
63+
64+
def set_part_coords(bunch: Bunch, index: int, coords: list[float]) -> Bunch:
65+
(x, xp, y, yp, z, de) = coords
66+
bunch.x(index, x)
67+
bunch.y(index, y)
68+
bunch.z(index, z)
69+
bunch.xp(index, xp)
70+
bunch.yp(index, yp)
71+
bunch.dE(index, de)
72+
return bunch
73+
74+
75+
def get_bunch_coords(bunch: Bunch, axis: tuple[int, ...] = None) -> np.ndarray:
76+
x = np.zeros((bunch.getSize(), 6))
77+
for i in range(bunch.getSize()):
78+
x[i, :] = get_part_coords(bunch, i)
79+
if axis is not None:
80+
x = x[:, axis]
81+
return x
82+
83+
84+
def set_bunch_coords(bunch: Bunch, x: np.ndarray, axis: tuple[int, ...] = None) -> Bunch:
85+
if axis is None:
86+
axis = tuple(range(6))
87+
88+
# Resize
89+
size = x.shape[0]
90+
size_error = size - bunch.getSize()
91+
if size_error > 0:
92+
for _ in range(size_error):
93+
bunch.addParticle(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
94+
else:
95+
for i in range(size, bunch.getSize()):
96+
bunch.deleteParticleFast(i)
97+
bunch.compress()
98+
99+
for i in range(bunch.getSize()):
100+
coords = get_part_coords(bunch, i)
101+
for j in range(len(axis)):
102+
coords[axis[j]] = x[i, j]
103+
bunch = set_part_coords(bunch, i, coords)
104+
return bunch
105+
106+
107+
class ORBITTransform:
108+
def __init__(
109+
self,
110+
lattice: AccLattice,
111+
bunch: Bunch,
112+
axis: tuple[int, ...],
113+
nturns: int = 1,
114+
) -> None:
115+
self.lattice = lattice
116+
self.bunch = bunch
117+
self.axis = axis
118+
self.nturns = nturns
119+
120+
def track_bunch(self) -> Bunch:
121+
bunch = Bunch()
122+
self.bunch.copyBunchTo(bunch)
123+
for _ in range(self.nturns):
124+
self.lattice.trackBunch(bunch)
125+
return bunch
126+
127+
def __call__(self, x: np.ndarray) -> np.ndarray:
128+
set_bunch_coords(self.bunch, x, axis=self.axis)
129+
bunch = self.track_bunch()
130+
x_out = get_bunch_coords(bunch, axis=self.axis)
131+
return x_out
132+
133+
134+
# Create accelerator lattice (drift, rf, drift)
135+
drift_node_1 = DriftTEAPOT()
136+
drift_node_2 = DriftTEAPOT()
137+
drift_node_1.setLength(124.0)
138+
drift_node_2.setLength(124.0)
139+
140+
141+
z_to_phi = 2.0 * np.pi / 248.0
142+
rf_hnum = 1.0
143+
rf_length = 0.0
144+
rf_synchronous_de = 0.0
145+
rf_voltage = 300.0e-06
146+
rf_phase = 0.0
147+
rf_node = Harmonic_RFNode(z_to_phi, rf_synchronous_de, rf_hnum, rf_voltage, rf_phase, rf_length)
148+
149+
lattice = TEAPOT_Ring()
150+
lattice.addNode(drift_node_1)
151+
lattice.addNode(rf_node)
152+
lattice.addNode(drift_node_2)
153+
lattice.initialize()
154+
155+
156+
# Create bunch
157+
bunch = Bunch()
158+
bunch.mass(0.938)
159+
bunch.getSyncParticle().kinEnergy(1.000)
160+
161+
for i in range(x_true.shape[0]):
162+
bunch.addParticle(0.0, 0.0, 0.0, 0.0, x_true[i, 0], x_true[i, 1])
163+
164+
165+
# Create transform functions
166+
turn_min = 0
167+
turn_max = 500
168+
turn_step = int((turn_max - turn_min) / nmeas)
169+
turns = list(range(turn_min, turn_max + turn_step, turn_step))
170+
171+
transforms = []
172+
for nturns in turns:
173+
transform = ORBITTransform(lattice, bunch, nturns=nturns, axis=(4, 5))
174+
transforms.append(transform)
175+
176+
limits = [
177+
(-0.5 * lattice.getLength(), +0.5 * lattice.getLength()),
178+
(-0.030, 0.030)
179+
]
180+
181+
# Create a list of histogram diagnostics for each transform.
182+
bin_edges = np.linspace(limits[0][0], limits[0][1], 100)
183+
diagnostics = []
184+
for transform in transforms:
185+
diagnostic = ment.diag.Histogram1D(axis=0, edges=bin_edges)
186+
diagnostics.append([diagnostic])
187+
188+
189+
# Training data
190+
# --------------------------------------------------------------------------------------
191+
192+
# Here we simulate the projections; in real life the projections would be measured.
193+
projections = ment.sim.simulate(x_true, transforms, diagnostics)
194+
195+
196+
# Reconstruction model
197+
# --------------------------------------------------------------------------------------
198+
199+
# Define prior distribution for relative entropy calculation
200+
prior = ment.prior.GaussianPrior(ndim=2, scale=[200.0, 0.020])
201+
202+
# Define particle sampler (if mode="sample")
203+
sampler = ment.samp.GridSampler(
204+
grid_limits=limits,
205+
grid_shape=(128, 128),
206+
)
207+
208+
# Set up MENT model
209+
model = ment.MENT(
210+
ndim=ndim,
211+
transforms=transforms,
212+
projections=projections,
213+
prior=prior,
214+
sampler=sampler,
215+
mode="sample",
216+
verbose=2,
217+
)
218+
219+
220+
# Training
221+
# --------------------------------------------------------------------------------------
222+
223+
224+
def plot_model(model):
225+
# Sample particles
226+
x_pred = model.sample(size)
227+
228+
# Plot sim vs. measured profiles
229+
projections_true = ment.sim.copy_histograms(model.projections)
230+
projections_true = ment.utils.unravel(projections_true)
231+
232+
projections_pred = ment.sim.copy_histograms(model.diagnostics)
233+
projections_pred = ment.sim.simulate(x_pred, transforms, projections_pred)
234+
projections_pred = ment.utils.unravel(projections_pred)
235+
236+
fig, axs = plt.subplots(
237+
ncols=nmeas, figsize=(11.0, 1.0), sharey=True, sharex=True, constrained_layout=True
238+
)
239+
for i, ax in enumerate(axs):
240+
values_pred = projections_pred[i].values
241+
values_true = projections_true[i].values
242+
ax.plot(values_pred / values_true.max(), color="lightgray")
243+
ax.plot(values_true / values_true.max(), color="black", lw=0.0, marker=".", ms=2.0)
244+
return fig
245+
246+
247+
for epoch in range(4):
248+
print("epoch =", epoch)
249+
250+
if epoch > 0:
251+
model.gauss_seidel_step(learning_rate=0.90)
252+
253+
fig = plot_model(model)
254+
fig.savefig(os.path.join(output_dir, f"fig_proj_{epoch:02.0f}.png"))
255+
plt.close()
256+
257+
# Plot final distribution
258+
x_pred = model.sample(x_true.shape[0])
259+
260+
fig, axs = plt.subplots(ncols=2, constrained_layout=True)
261+
for ax, x in zip(axs, [x_pred, x_true]):
262+
ax.hist2d(x[:, 0], x[:, 1], bins=100, range=limits)
263+
fig.savefig(os.path.join(output_dir, f"fig_dist_{epoch:02.0f}.png"))
264+
plt.close()
265+

examples/rec_simple.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import pathlib
3+
14
import numpy as np
25
import matplotlib.pyplot as plt
36

@@ -9,6 +12,10 @@
912
nmeas = 7
1013
seed = 0
1114

15+
path = pathlib.Path(__file__)
16+
output_dir = os.path.join("outputs", path.stem)
17+
os.makedirs(output_dir, exist_ok=True)
18+
1219

1320
# Ground truth distribution
1421
# --------------------------------------------------------------------------------------
@@ -113,15 +120,18 @@ def plot_model(model):
113120
values_true = projections_true[i].values
114121
ax.plot(values_pred / values_true.max(), color="lightgray")
115122
ax.plot(values_true / values_true.max(), color="black", lw=0.0, marker=".", ms=2.0)
116-
plt.show()
123+
return fig
117124

118125

119126
for epoch in range(4):
127+
print("epoch =", epoch)
128+
120129
if epoch > 0:
121130
model.gauss_seidel_step(learning_rate=0.90)
122131

123-
plot_model(model)
124-
132+
fig = plot_model(model)
133+
fig.savefig(os.path.join(output_dir, f"fig_proj_{epoch:02.0f}.png"))
134+
plt.close("all")
125135

126136
# Plot final distribution
127137
x_pred = model.sample(x_true.shape[0])
@@ -130,4 +140,5 @@ def plot_model(model):
130140
for ax, X in zip(axs, [x_pred, x_true]):
131141
ax.hist2d(X[:, 0], X[:, 1], bins=55, range=[(-4.0, 4.0), (-4.0, 4.0)])
132142
ax.set_aspect(1.0)
133-
plt.show()
143+
fig.savefig(os.path.join(output_dir, f"fig_dist_{epoch:02.0f}.png"))
144+
plt.close("all")

0 commit comments

Comments
 (0)