|
| 1 | +"""Reconstruct longitudinal phase space distribution from turn-by-turn projections. |
| 2 | +
|
| 3 | +This script uses a PyORBIT [https://github.com/PyORBIT-Collaboration/PyORBIT3] lattice |
| 4 | +model consisting of a harmonic RF cavity surrounded by two drifts. Things are a bit slow |
| 5 | +because we have to repeatedly convert between NumPy arrays and Bunch objects, but it works. |
| 6 | +
|
| 7 | +Note that one MENT iteration requires simulating all projectionos. If projectiono k |
| 8 | +is measured after k turns, then we must first track the bunch 1 turn, then resample |
| 9 | +and track 2 turns, then resample and track 3 turns, etc. In total, we must track |
| 10 | +n * (n + 1) / 2 turns. For a significant number of turns, ART may be the better |
| 11 | +option. |
| 12 | +""" |
| 13 | +import os |
| 14 | +import pathlib |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +import ment |
| 19 | + |
| 20 | +from orbit.core.bunch import Bunch |
| 21 | +from orbit.lattice import AccLattice |
| 22 | +from orbit.lattice import AccNode |
| 23 | +from orbit.rf_cavities import Harmonic_RFNode |
| 24 | +from orbit.teapot import DriftTEAPOT |
| 25 | +from orbit.teapot import TEAPOT_Ring |
| 26 | + |
| 27 | + |
| 28 | +# Setup |
| 29 | +# -------------------------------------------------------------------------------------- |
| 30 | + |
| 31 | +ndim = 2 |
| 32 | +nmeas = 7 |
| 33 | +seed = 0 |
| 34 | +size = 100_000 |
| 35 | + |
| 36 | +path = pathlib.Path(__file__) |
| 37 | +output_dir = os.path.join("outputs", path.stem) |
| 38 | +os.makedirs(output_dir, exist_ok=True) |
| 39 | + |
| 40 | + |
| 41 | +# Distribution |
| 42 | +# -------------------------------------------------------------------------------------- |
| 43 | + |
| 44 | +rng = np.random.default_rng(seed) |
| 45 | + |
| 46 | +x_true = np.zeros((size, 2)) |
| 47 | +x_true[:, 0] = 0.60 * rng.uniform(-124.0, 124.0, x_true.shape[0]) |
| 48 | +x_true[:, 1] = rng.normal(scale=0.0025, size=x_true.shape[0]) |
| 49 | + |
| 50 | + |
| 51 | +# Forward model |
| 52 | +# -------------------------------------------------------------------------------------- |
| 53 | + |
| 54 | +def get_part_coords(bunch: Bunch, index: int) -> list[float]: |
| 55 | + x = bunch.x(index) |
| 56 | + y = bunch.y(index) |
| 57 | + z = bunch.z(index) |
| 58 | + xp = bunch.xp(index) |
| 59 | + yp = bunch.yp(index) |
| 60 | + de = bunch.dE(index) |
| 61 | + return [x, xp, y, yp, z, de] |
| 62 | + |
| 63 | + |
| 64 | +def set_part_coords(bunch: Bunch, index: int, coords: list[float]) -> Bunch: |
| 65 | + (x, xp, y, yp, z, de) = coords |
| 66 | + bunch.x(index, x) |
| 67 | + bunch.y(index, y) |
| 68 | + bunch.z(index, z) |
| 69 | + bunch.xp(index, xp) |
| 70 | + bunch.yp(index, yp) |
| 71 | + bunch.dE(index, de) |
| 72 | + return bunch |
| 73 | + |
| 74 | + |
| 75 | +def get_bunch_coords(bunch: Bunch, axis: tuple[int, ...] = None) -> np.ndarray: |
| 76 | + x = np.zeros((bunch.getSize(), 6)) |
| 77 | + for i in range(bunch.getSize()): |
| 78 | + x[i, :] = get_part_coords(bunch, i) |
| 79 | + if axis is not None: |
| 80 | + x = x[:, axis] |
| 81 | + return x |
| 82 | + |
| 83 | + |
| 84 | +def set_bunch_coords(bunch: Bunch, x: np.ndarray, axis: tuple[int, ...] = None) -> Bunch: |
| 85 | + if axis is None: |
| 86 | + axis = tuple(range(6)) |
| 87 | + |
| 88 | + # Resize |
| 89 | + size = x.shape[0] |
| 90 | + size_error = size - bunch.getSize() |
| 91 | + if size_error > 0: |
| 92 | + for _ in range(size_error): |
| 93 | + bunch.addParticle(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) |
| 94 | + else: |
| 95 | + for i in range(size, bunch.getSize()): |
| 96 | + bunch.deleteParticleFast(i) |
| 97 | + bunch.compress() |
| 98 | + |
| 99 | + for i in range(bunch.getSize()): |
| 100 | + coords = get_part_coords(bunch, i) |
| 101 | + for j in range(len(axis)): |
| 102 | + coords[axis[j]] = x[i, j] |
| 103 | + bunch = set_part_coords(bunch, i, coords) |
| 104 | + return bunch |
| 105 | + |
| 106 | + |
| 107 | +class ORBITTransform: |
| 108 | + def __init__( |
| 109 | + self, |
| 110 | + lattice: AccLattice, |
| 111 | + bunch: Bunch, |
| 112 | + axis: tuple[int, ...], |
| 113 | + nturns: int = 1, |
| 114 | + ) -> None: |
| 115 | + self.lattice = lattice |
| 116 | + self.bunch = bunch |
| 117 | + self.axis = axis |
| 118 | + self.nturns = nturns |
| 119 | + |
| 120 | + def track_bunch(self) -> Bunch: |
| 121 | + bunch = Bunch() |
| 122 | + self.bunch.copyBunchTo(bunch) |
| 123 | + for _ in range(self.nturns): |
| 124 | + self.lattice.trackBunch(bunch) |
| 125 | + return bunch |
| 126 | + |
| 127 | + def __call__(self, x: np.ndarray) -> np.ndarray: |
| 128 | + set_bunch_coords(self.bunch, x, axis=self.axis) |
| 129 | + bunch = self.track_bunch() |
| 130 | + x_out = get_bunch_coords(bunch, axis=self.axis) |
| 131 | + return x_out |
| 132 | + |
| 133 | + |
| 134 | +# Create accelerator lattice (drift, rf, drift) |
| 135 | +drift_node_1 = DriftTEAPOT() |
| 136 | +drift_node_2 = DriftTEAPOT() |
| 137 | +drift_node_1.setLength(124.0) |
| 138 | +drift_node_2.setLength(124.0) |
| 139 | + |
| 140 | + |
| 141 | +z_to_phi = 2.0 * np.pi / 248.0 |
| 142 | +rf_hnum = 1.0 |
| 143 | +rf_length = 0.0 |
| 144 | +rf_synchronous_de = 0.0 |
| 145 | +rf_voltage = 300.0e-06 |
| 146 | +rf_phase = 0.0 |
| 147 | +rf_node = Harmonic_RFNode(z_to_phi, rf_synchronous_de, rf_hnum, rf_voltage, rf_phase, rf_length) |
| 148 | + |
| 149 | +lattice = TEAPOT_Ring() |
| 150 | +lattice.addNode(drift_node_1) |
| 151 | +lattice.addNode(rf_node) |
| 152 | +lattice.addNode(drift_node_2) |
| 153 | +lattice.initialize() |
| 154 | + |
| 155 | + |
| 156 | +# Create bunch |
| 157 | +bunch = Bunch() |
| 158 | +bunch.mass(0.938) |
| 159 | +bunch.getSyncParticle().kinEnergy(1.000) |
| 160 | + |
| 161 | +for i in range(x_true.shape[0]): |
| 162 | + bunch.addParticle(0.0, 0.0, 0.0, 0.0, x_true[i, 0], x_true[i, 1]) |
| 163 | + |
| 164 | + |
| 165 | +# Create transform functions |
| 166 | +turn_min = 0 |
| 167 | +turn_max = 500 |
| 168 | +turn_step = int((turn_max - turn_min) / nmeas) |
| 169 | +turns = list(range(turn_min, turn_max + turn_step, turn_step)) |
| 170 | + |
| 171 | +transforms = [] |
| 172 | +for nturns in turns: |
| 173 | + transform = ORBITTransform(lattice, bunch, nturns=nturns, axis=(4, 5)) |
| 174 | + transforms.append(transform) |
| 175 | + |
| 176 | +limits = [ |
| 177 | + (-0.5 * lattice.getLength(), +0.5 * lattice.getLength()), |
| 178 | + (-0.030, 0.030) |
| 179 | +] |
| 180 | + |
| 181 | +# Create a list of histogram diagnostics for each transform. |
| 182 | +bin_edges = np.linspace(limits[0][0], limits[0][1], 100) |
| 183 | +diagnostics = [] |
| 184 | +for transform in transforms: |
| 185 | + diagnostic = ment.diag.Histogram1D(axis=0, edges=bin_edges) |
| 186 | + diagnostics.append([diagnostic]) |
| 187 | + |
| 188 | + |
| 189 | +# Training data |
| 190 | +# -------------------------------------------------------------------------------------- |
| 191 | + |
| 192 | +# Here we simulate the projections; in real life the projections would be measured. |
| 193 | +projections = ment.sim.simulate(x_true, transforms, diagnostics) |
| 194 | + |
| 195 | + |
| 196 | +# Reconstruction model |
| 197 | +# -------------------------------------------------------------------------------------- |
| 198 | + |
| 199 | +# Define prior distribution for relative entropy calculation |
| 200 | +prior = ment.prior.GaussianPrior(ndim=2, scale=[200.0, 0.020]) |
| 201 | + |
| 202 | +# Define particle sampler (if mode="sample") |
| 203 | +sampler = ment.samp.GridSampler( |
| 204 | + grid_limits=limits, |
| 205 | + grid_shape=(128, 128), |
| 206 | +) |
| 207 | + |
| 208 | +# Set up MENT model |
| 209 | +model = ment.MENT( |
| 210 | + ndim=ndim, |
| 211 | + transforms=transforms, |
| 212 | + projections=projections, |
| 213 | + prior=prior, |
| 214 | + sampler=sampler, |
| 215 | + mode="sample", |
| 216 | + verbose=2, |
| 217 | +) |
| 218 | + |
| 219 | + |
| 220 | +# Training |
| 221 | +# -------------------------------------------------------------------------------------- |
| 222 | + |
| 223 | + |
| 224 | +def plot_model(model): |
| 225 | + # Sample particles |
| 226 | + x_pred = model.sample(size) |
| 227 | + |
| 228 | + # Plot sim vs. measured profiles |
| 229 | + projections_true = ment.sim.copy_histograms(model.projections) |
| 230 | + projections_true = ment.utils.unravel(projections_true) |
| 231 | + |
| 232 | + projections_pred = ment.sim.copy_histograms(model.diagnostics) |
| 233 | + projections_pred = ment.sim.simulate(x_pred, transforms, projections_pred) |
| 234 | + projections_pred = ment.utils.unravel(projections_pred) |
| 235 | + |
| 236 | + fig, axs = plt.subplots( |
| 237 | + ncols=nmeas, figsize=(11.0, 1.0), sharey=True, sharex=True, constrained_layout=True |
| 238 | + ) |
| 239 | + for i, ax in enumerate(axs): |
| 240 | + values_pred = projections_pred[i].values |
| 241 | + values_true = projections_true[i].values |
| 242 | + ax.plot(values_pred / values_true.max(), color="lightgray") |
| 243 | + ax.plot(values_true / values_true.max(), color="black", lw=0.0, marker=".", ms=2.0) |
| 244 | + return fig |
| 245 | + |
| 246 | + |
| 247 | +for epoch in range(4): |
| 248 | + print("epoch =", epoch) |
| 249 | + |
| 250 | + if epoch > 0: |
| 251 | + model.gauss_seidel_step(learning_rate=0.90) |
| 252 | + |
| 253 | + fig = plot_model(model) |
| 254 | + fig.savefig(os.path.join(output_dir, f"fig_proj_{epoch:02.0f}.png")) |
| 255 | + plt.close() |
| 256 | + |
| 257 | +# Plot final distribution |
| 258 | +x_pred = model.sample(x_true.shape[0]) |
| 259 | + |
| 260 | +fig, axs = plt.subplots(ncols=2, constrained_layout=True) |
| 261 | +for ax, x in zip(axs, [x_pred, x_true]): |
| 262 | + ax.hist2d(x[:, 0], x[:, 1], bins=100, range=limits) |
| 263 | +fig.savefig(os.path.join(output_dir, f"fig_dist_{epoch:02.0f}.png")) |
| 264 | +plt.close() |
| 265 | + |
0 commit comments