Skip to content

Commit 18ee81e

Browse files
Merge pull request #8 from austin-hoover/example-nonlinear-ring
Example 4D nonlinear ring example
2 parents fab8513 + fc4db78 commit 18ee81e

File tree

11 files changed

+426
-0
lines changed

11 files changed

+426
-0
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# 4D reconstruction in nonlinear ring
2+
3+
<img src="saved/fig_action.png" width="600px">
4+
5+
In this example we measure the $x-p_x$ and $y-p_y$ distribution after every 20 turns in a linear periodic lattice + nonlinear rotationally symmetric kick [1]. MENT finds an initial 4D phase space distribution consistent with these projections.
6+
7+
[1] https://arxiv.org/abs/2405.05657
8+
9+

examples/nonlinear_ring_4d/eval.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import argparse
2+
import os
3+
import pathlib
4+
import shutil
5+
import sys
6+
import matplotlib as mpl
7+
import matplotlib.pyplot as plt
8+
import ment
9+
import numpy as np
10+
from scipy.ndimage import gaussian_filter
11+
12+
# local
13+
from utils import make_dist
14+
from utils import get_actions
15+
16+
17+
# Arguments
18+
# --------------------------------------------------------------------------------------
19+
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument("--nsamp", type=int, default=500_000)
22+
parser.add_argument("--nbins", type=int, default=90)
23+
parser.add_argument("--blur", type=int, default=0.0)
24+
parser.add_argument("--cmap", type=str, default="plasma")
25+
args = parser.parse_args()
26+
27+
28+
# Setup
29+
# --------------------------------------------------------------------------------------
30+
31+
path = pathlib.Path(__file__)
32+
output_dir = os.path.join("outputs", path.stem)
33+
34+
if os.path.exists(output_dir):
35+
shutil.rmtree(output_dir)
36+
os.makedirs(output_dir)
37+
38+
39+
# Load model
40+
# --------------------------------------------------------------------------------------
41+
42+
input_dir = "outputs/train/checkpoints"
43+
filenames = os.listdir(input_dir)
44+
filenames = sorted(filenames)
45+
filenames = [f for f in filenames if f.endswith(".pt")]
46+
filenames = [os.path.join(input_dir, f) for f in filenames]
47+
48+
filename = filenames[-1]
49+
50+
model = ment.MENT(
51+
ndim=4,
52+
transforms=None,
53+
projections=None,
54+
sampler=None,
55+
prior=None,
56+
)
57+
model.load(filename)
58+
59+
60+
# Sample particles and simulate data
61+
# --------------------------------------------------------------------------------------
62+
63+
nsamp = args.nsamp
64+
x_true = make_dist(nsamp)
65+
x_pred = model.unnormalize(model.sample(nsamp))
66+
projections_pred = ment.unravel(ment.simulate(x_pred, model.transforms, model.diagnostics))
67+
projections_true = ment.unravel(ment.simulate(x_true, model.transforms, model.diagnostics))
68+
69+
70+
# Plot distribution of actions Jx-Jy
71+
# --------------------------------------------------------------------------------------
72+
73+
ncols = len(projections_pred) // 2
74+
cmap = args.cmap
75+
76+
for log in [False, True]:
77+
fig, axs = plt.subplots(
78+
nrows=2,
79+
ncols=ncols,
80+
figsize=(ncols * 2.0, 4.0),
81+
constrained_layout=True,
82+
sharex=True,
83+
sharey=True,
84+
)
85+
for j, transform in enumerate(model.transforms):
86+
x_true_out = transform(x_true)
87+
x_pred_out = transform(x_pred)
88+
89+
values_list = []
90+
for i, x_out in enumerate([x_pred_out, x_true_out]):
91+
actions = get_actions(x_out)
92+
sqrt_actions = np.sqrt(actions)
93+
94+
xmax = 4.0
95+
limits = 2 * [(0.0, xmax)]
96+
values, edges = np.histogramdd(sqrt_actions, bins=args.nbins, range=limits)
97+
values_list.append(values)
98+
99+
scale = np.max([np.max(values) for values in values_list])
100+
for i, values in enumerate(values_list):
101+
if args.blur:
102+
values = gaussian_filter(values, args.blur)
103+
#values = values / scale
104+
values = values / np.max(values)
105+
if log:
106+
values = np.log10(values + 1.00e-15)
107+
108+
vmax = 1.0
109+
vmin = 0.0
110+
if log:
111+
vmax = 0.0
112+
vmin = -3.0
113+
114+
ax = axs[i, j]
115+
mesh = ax.pcolormesh(
116+
edges[0],
117+
edges[1],
118+
values.T,
119+
cmap=cmap,
120+
vmax=vmax,
121+
vmin=vmin,
122+
linewidth=0.0,
123+
rasterized=True,
124+
shading="auto",
125+
)
126+
127+
axs[1, 0].set_xlabel("Jx")
128+
axs[1, 0].set_ylabel("Jy")
129+
130+
131+
filename = "fig_action"
132+
if log:
133+
filename = filename + "_log"
134+
filename = filename + ".pdf"
135+
filename = os.path.join(output_dir, filename)
136+
plt.savefig(filename, dpi=300)
137+
plt.close()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
from ment.sim import Transform
3+
4+
5+
class AxiallySymmetricNonlinearKick(Transform):
6+
def __init__(self, alpha: float, beta: float, phi: float, A: float, E: float, T: float) -> None:
7+
super().__init__()
8+
self.alpha = alpha
9+
self.beta = beta
10+
self.phi = phi
11+
self.A = A
12+
self.E = E
13+
self.T = T
14+
15+
def forward(self, x: np.ndarray) -> np.ndarray:
16+
r = np.sqrt(x[:, 0]**2 + x[:, 2]**2)
17+
t = np.arctan2(x[:, 2], x[:, 0])
18+
19+
alpha = self.alpha
20+
beta = self.beta
21+
phi = self.phi
22+
E = self.E
23+
A = self.A
24+
T = self.T
25+
26+
dr = -(1.0 / (beta * np.sin(phi))) * ((E * r) / (A * r**2 + T)) - ((2.0 * r) / (beta * np.tan(phi)))
27+
28+
x_out = np.copy(x)
29+
x_out[:, 1] += dr * np.cos(t)
30+
x_out[:, 3] += dr * np.sin(t)
31+
return x_out
32+
33+
def inverse(self, x: np.ndarray) -> np.ndarray:
34+
x[:, 1] *= -1.0
35+
X = self.forward(X)
36+
X[:, 1] *= -1.0
37+
return X
579 KB
Loading
725 KB
Loading
117 KB
Loading
118 KB
Loading
113 KB
Loading
112 KB
Loading

0 commit comments

Comments
 (0)