|
| 1 | +"""Plot 4D reconstruction in action space.""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import os |
| 5 | +import pathlib |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import ment |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import scipy.ndimage |
| 11 | + |
| 12 | +from tools.cov import normalization_matrix |
| 13 | +from tools.utils import list_paths |
| 14 | + |
| 15 | + |
| 16 | +plt.style.use("tools/style.mplstyle") |
| 17 | + |
| 18 | + |
| 19 | +# Arguments |
| 20 | +# -------------------------------------------------------------------------------------- |
| 21 | + |
| 22 | +parser = argparse.ArgumentParser() |
| 23 | +parser.add_argument("--nsamp", type=int, default=1_000_000) |
| 24 | +parser.add_argument("--nbins", type=int, default=64) |
| 25 | +parser.add_argument("--jmax", type=float, default=85.0) |
| 26 | +parser.add_argument("--contours", type=int, default=8) |
| 27 | +parser.add_argument("--show", action="store_true") |
| 28 | +args = parser.parse_args() |
| 29 | + |
| 30 | + |
| 31 | +# Setup |
| 32 | +# -------------------------------------------------------------------------------------- |
| 33 | + |
| 34 | +input_dir = "./outputs/train/" |
| 35 | + |
| 36 | +path = pathlib.Path(__file__) |
| 37 | +output_dir = os.path.join("outputs", path.stem) |
| 38 | +if not os.path.exists(output_dir): |
| 39 | + os.makedirs(output_dir) |
| 40 | + |
| 41 | + |
| 42 | +# Load model |
| 43 | +# -------------------------------------------------------------------------------------- |
| 44 | + |
| 45 | +checkpoints_folder = os.path.join(input_dir, "checkpoints") |
| 46 | +checkpoint_filenames = list_paths(os.path.join(input_dir, "checkpoints"), sort=True) |
| 47 | +checkpoint_filename = checkpoint_filenames[-1] |
| 48 | + |
| 49 | +model = ment.MENT( |
| 50 | + ndim=4, |
| 51 | + transforms=None, |
| 52 | + projections=None, |
| 53 | + prior=None, |
| 54 | + sampler=None, |
| 55 | +) |
| 56 | +model.load(checkpoint_filenames[-1]) |
| 57 | +model.sampler.noise = 1.0 |
| 58 | + |
| 59 | + |
| 60 | +# Sample particles from distribution |
| 61 | +# -------------------------------------------------------------------------------------- |
| 62 | + |
| 63 | +x = model.unnormalize(model.sample(args.nsamp)) |
| 64 | + |
| 65 | +cov_matrix = np.cov(x.T) |
| 66 | +norm_matrix = normalization_matrix(cov_matrix, scale=False) |
| 67 | + |
| 68 | +z = np.matmul(x, norm_matrix.T) |
| 69 | +cov_matrix_n = np.cov(z.T) |
| 70 | + |
| 71 | +print("covariance matrix:") |
| 72 | +print(np.round(cov_matrix, 5)) |
| 73 | + |
| 74 | +print("normalized covariance matrix:") |
| 75 | +print(np.round(cov_matrix_n, 5)) |
| 76 | + |
| 77 | +eps1 = np.sqrt(cov_matrix_n[0, 0] * cov_matrix_n[1, 1]) |
| 78 | +eps2 = np.sqrt(cov_matrix_n[2, 2] * cov_matrix_n[3, 3]) |
| 79 | +eps_avg = np.sqrt(eps1 * eps2) # maintains 4D emittance but equal mode amplitudes |
| 80 | + |
| 81 | +z_pred = np.copy(z) |
| 82 | + |
| 83 | + |
| 84 | +# Plot actions |
| 85 | +# -------------------------------------------------------------------------------------- |
| 86 | + |
| 87 | +def make_joint_grid(figwidth=5.0, panel_width=0.33) -> tuple: |
| 88 | + fig, axs = plt.subplots( |
| 89 | + ncols=2, |
| 90 | + nrows=2, |
| 91 | + sharex="col", |
| 92 | + sharey="row", |
| 93 | + figsize=(figwidth, figwidth), |
| 94 | + gridspec_kw=dict( |
| 95 | + width_ratios=[1.0, panel_width], |
| 96 | + height_ratios=[panel_width, 1.0], |
| 97 | + ), |
| 98 | + ) |
| 99 | + axs[0, 1].axis("off") |
| 100 | + return fig, axs |
| 101 | + |
| 102 | + |
| 103 | +def plot_hist(values: np.ndarray, edges: list[np.ndarray], contours: int = 0, **plot_kws) -> tuple: |
| 104 | + fig, axs = make_joint_grid() |
| 105 | + |
| 106 | + ax = axs[1, 0] |
| 107 | + ax.pcolormesh( |
| 108 | + edges[0], |
| 109 | + edges[1], |
| 110 | + values.T, |
| 111 | + linewidth=0.0, |
| 112 | + rasterized=True, |
| 113 | + shading="auto", |
| 114 | + ) |
| 115 | + |
| 116 | + if contours: |
| 117 | + coords = [0.5 * (e[:-1] + e[1:]) for e in edges] |
| 118 | + axs[1, 0].contour( |
| 119 | + coords[0], |
| 120 | + coords[1], |
| 121 | + scipy.ndimage.gaussian_filter(values, 1.0).T, |
| 122 | + levels=np.linspace(0.01, 1.0, contours), |
| 123 | + colors="white", |
| 124 | + linewidths=0.80, |
| 125 | + alpha=0.15, |
| 126 | + ) |
| 127 | + |
| 128 | + ax = axs[0, 0] |
| 129 | + proj_edges = edges[0] |
| 130 | + proj_values = np.sum(values, axis=1) |
| 131 | + proj_values = proj_values / np.max(proj_values) |
| 132 | + ax.stairs(proj_values, proj_edges, lw=1.5, color="black") |
| 133 | + ax.set_ylim(0.0, 1.25) |
| 134 | + |
| 135 | + ax = axs[1, 1] |
| 136 | + proj_edges = edges[1] |
| 137 | + proj_values = np.sum(values, axis=0) |
| 138 | + proj_values = proj_values / np.max(proj_values) |
| 139 | + ax.stairs(proj_values, proj_edges, orientation="horizontal", lw=1.5, color="black") |
| 140 | + ax.set_xlim(0.0, 1.25) |
| 141 | + |
| 142 | + return fig, axs |
| 143 | + |
| 144 | + |
| 145 | +z = z_pred |
| 146 | +j1 = np.sum(np.square(z[:, (0, 1)]), axis=1) |
| 147 | +j2 = np.sum(np.square(z[:, (2, 3)]), axis=1) |
| 148 | + |
| 149 | +edges = np.linspace(0.0, args.jmax, args.nbins + 1) |
| 150 | +edges = [edges, edges] |
| 151 | +values, _, _ = np.histogram2d(j1, j2, bins=edges) |
| 152 | +values = values / np.max(values) |
| 153 | + |
| 154 | +fig, axs = plot_hist(values, edges, contours=args.contours) |
| 155 | +axs[1, 0].set_xlabel(r"$J_1$") |
| 156 | +axs[1, 0].set_ylabel(r"$J_2$") |
| 157 | + |
| 158 | +filename = f"fig_action.pdf" |
| 159 | +filename = os.path.join(output_dir, filename) |
| 160 | +plt.savefig(filename) |
0 commit comments