|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import pathlib |
| 4 | +import shutil |
| 5 | +import sys |
| 6 | +import matplotlib as mpl |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import ment |
| 9 | +import numpy as np |
| 10 | +from scipy.ndimage import gaussian_filter |
| 11 | + |
| 12 | +# local |
| 13 | +from utils import make_dist |
| 14 | +from utils import get_actions |
| 15 | + |
| 16 | + |
| 17 | +# Arguments |
| 18 | +# -------------------------------------------------------------------------------------- |
| 19 | + |
| 20 | +parser = argparse.ArgumentParser() |
| 21 | +parser.add_argument("--nsamp", type=int, default=500_000) |
| 22 | +parser.add_argument("--nbins", type=int, default=90) |
| 23 | +parser.add_argument("--blur", type=int, default=0.0) |
| 24 | +parser.add_argument("--cmap", type=str, default="plasma") |
| 25 | +args = parser.parse_args() |
| 26 | + |
| 27 | + |
| 28 | +# Setup |
| 29 | +# -------------------------------------------------------------------------------------- |
| 30 | + |
| 31 | +path = pathlib.Path(__file__) |
| 32 | +output_dir = os.path.join("outputs", path.stem) |
| 33 | + |
| 34 | +if os.path.exists(output_dir): |
| 35 | + shutil.rmtree(output_dir) |
| 36 | +os.makedirs(output_dir) |
| 37 | + |
| 38 | + |
| 39 | +# Load model |
| 40 | +# -------------------------------------------------------------------------------------- |
| 41 | + |
| 42 | +input_dir = "outputs/train/checkpoints" |
| 43 | +filenames = os.listdir(input_dir) |
| 44 | +filenames = sorted(filenames) |
| 45 | +filenames = [f for f in filenames if f.endswith(".pt")] |
| 46 | +filenames = [os.path.join(input_dir, f) for f in filenames] |
| 47 | + |
| 48 | +filename = filenames[-1] |
| 49 | + |
| 50 | +model = ment.MENT( |
| 51 | + ndim=4, |
| 52 | + transforms=None, |
| 53 | + projections=None, |
| 54 | + sampler=None, |
| 55 | + prior=None, |
| 56 | +) |
| 57 | +model.load(filename) |
| 58 | + |
| 59 | + |
| 60 | +# Sample particles and simulate data |
| 61 | +# -------------------------------------------------------------------------------------- |
| 62 | + |
| 63 | +nsamp = args.nsamp |
| 64 | +x_true = make_dist(nsamp) |
| 65 | +x_pred = model.unnormalize(model.sample(nsamp)) |
| 66 | +projections_pred = ment.unravel(ment.simulate(x_pred, model.transforms, model.diagnostics)) |
| 67 | +projections_true = ment.unravel(ment.simulate(x_true, model.transforms, model.diagnostics)) |
| 68 | + |
| 69 | + |
| 70 | +# Plot distribution of actions Jx-Jy |
| 71 | +# -------------------------------------------------------------------------------------- |
| 72 | + |
| 73 | +ncols = len(projections_pred) // 2 |
| 74 | +cmap = args.cmap |
| 75 | + |
| 76 | +for log in [False, True]: |
| 77 | + fig, axs = plt.subplots( |
| 78 | + nrows=2, |
| 79 | + ncols=ncols, |
| 80 | + figsize=(ncols * 2.0, 4.0), |
| 81 | + constrained_layout=True, |
| 82 | + sharex=True, |
| 83 | + sharey=True, |
| 84 | + ) |
| 85 | + for j, transform in enumerate(model.transforms): |
| 86 | + x_true_out = transform(x_true) |
| 87 | + x_pred_out = transform(x_pred) |
| 88 | + |
| 89 | + values_list = [] |
| 90 | + for i, x_out in enumerate([x_pred_out, x_true_out]): |
| 91 | + actions = get_actions(x_out) |
| 92 | + sqrt_actions = np.sqrt(actions) |
| 93 | + |
| 94 | + xmax = 4.0 |
| 95 | + limits = 2 * [(0.0, xmax)] |
| 96 | + values, edges = np.histogramdd(sqrt_actions, bins=args.nbins, range=limits) |
| 97 | + values_list.append(values) |
| 98 | + |
| 99 | + scale = np.max([np.max(values) for values in values_list]) |
| 100 | + for i, values in enumerate(values_list): |
| 101 | + if args.blur: |
| 102 | + values = gaussian_filter(values, args.blur) |
| 103 | + #values = values / scale |
| 104 | + values = values / np.max(values) |
| 105 | + if log: |
| 106 | + values = np.log10(values + 1.00e-15) |
| 107 | + |
| 108 | + vmax = 1.0 |
| 109 | + vmin = 0.0 |
| 110 | + if log: |
| 111 | + vmax = 0.0 |
| 112 | + vmin = -3.0 |
| 113 | + |
| 114 | + ax = axs[i, j] |
| 115 | + mesh = ax.pcolormesh( |
| 116 | + edges[0], |
| 117 | + edges[1], |
| 118 | + values.T, |
| 119 | + cmap=cmap, |
| 120 | + vmax=vmax, |
| 121 | + vmin=vmin, |
| 122 | + linewidth=0.0, |
| 123 | + rasterized=True, |
| 124 | + shading="auto", |
| 125 | + ) |
| 126 | + |
| 127 | + axs[1, 0].set_xlabel("Jx") |
| 128 | + axs[1, 0].set_ylabel("Jy") |
| 129 | + |
| 130 | + |
| 131 | + filename = "fig_action" |
| 132 | + if log: |
| 133 | + filename = filename + "_log" |
| 134 | + filename = filename + ".pdf" |
| 135 | + filename = os.path.join(output_dir, filename) |
| 136 | + plt.savefig(filename, dpi=300) |
| 137 | + plt.close() |
0 commit comments