Skip to content

Commit 5b5615a

Browse files
committed
Add histogram diagnostic examples
1 parent 76781d5 commit 5b5615a

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
axes.linewidth: 1.25
2+
figure.constrained_layout.use: True
3+
xtick.minor.visible: True
4+
ytick.minor.visible: True
5+
6+
savefig.format: "png"
7+
savefig.dpi: 300
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import argparse
2+
import os
3+
import pathlib
4+
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
from orbit.core import orbit_mpi
9+
from orbit.core.bunch import Bunch
10+
from orbit.diagnostics import BunchHistogram
11+
from orbit.diagnostics import BunchHistogram1D
12+
from orbit.diagnostics import BunchHistogram2D
13+
from orbit.diagnostics import BunchHistogram3D
14+
from orbit.bunch_utils import collect_bunch
15+
16+
plt.style.use("style.mplstyle")
17+
18+
19+
# Parse args
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument("--nsamp", type=int, default=10_000)
22+
parser.add_argument("--nbins", type=int, default=64)
23+
args = parser.parse_args()
24+
25+
# Make output directory
26+
path = pathlib.Path(__file__)
27+
output_dir = os.path.join("outputs", path.stem)
28+
os.makedirs(output_dir, exist_ok=True)
29+
30+
# Setup MPI
31+
_mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
32+
_mpi_rank = orbit_mpi.MPI_Comm_rank(_mpi_comm)
33+
34+
# Generate particle distribution
35+
rng = np.random.default_rng(123)
36+
X = rng.normal(size=(args.nsamp, 2))
37+
X = X / np.linalg.norm(X, axis=1)[:, None]
38+
X[:, 0] -= 0.75 * X[:, 1] ** 2
39+
X = X + rng.normal(size=X.shape, scale=0.25)
40+
X = X / np.std(X, axis=0)
41+
X = np.hstack([X, np.zeros((X.shape[0], 4))])
42+
43+
# Create bunch
44+
bunch = Bunch()
45+
for i in range(X.shape[0]):
46+
bunch.addParticle(*X[i, :])
47+
48+
# Create histogram diagnostic
49+
grid_limits = [(-4.0, 4.0)]
50+
grid_shape = (args.nbins,)
51+
52+
axis = (0,)
53+
diag = BunchHistogram(
54+
axis=axis,
55+
shape=grid_shape,
56+
limits=grid_limits,
57+
method=None,
58+
normalize=True,
59+
output_dir=output_dir,
60+
)
61+
62+
# Compute histogram
63+
diag.track(bunch)
64+
65+
# Plot histogram
66+
if _mpi_rank == 0:
67+
fig, ax = plt.subplots(figsize=(5, 2))
68+
ax.plot(diag.grid_coords[0], diag.grid_values, lw=1.5, color=None, label="PyORBIT")
69+
grid_values_np, _ = np.histogram(X[:, axis], bins=diag.grid_edges[0], density=True)
70+
ax.plot(diag.grid_coords[0], grid_values_np, lw=1.5, color=None, label="NumPy")
71+
ax.legend()
72+
73+
filename = os.path.join(output_dir, "fig_hist")
74+
plt.savefig(filename)
75+
plt.show()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import argparse
2+
import os
3+
import pathlib
4+
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
from orbit.core import orbit_mpi
9+
from orbit.core.bunch import Bunch
10+
from orbit.diagnostics import BunchHistogram
11+
from orbit.diagnostics import BunchHistogram1D
12+
from orbit.diagnostics import BunchHistogram2D
13+
from orbit.diagnostics import BunchHistogram3D
14+
from orbit.bunch_utils import collect_bunch
15+
16+
plt.style.use("style.mplstyle")
17+
18+
19+
# Parse args
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument("--nsamp", type=int, default=10_000)
22+
parser.add_argument("--nbins", type=int, default=64)
23+
args = parser.parse_args()
24+
25+
# Make output directory
26+
path = pathlib.Path(__file__)
27+
output_dir = os.path.join("outputs", path.stem)
28+
os.makedirs(output_dir, exist_ok=True)
29+
30+
# Setup MPI
31+
_mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
32+
_mpi_rank = orbit_mpi.MPI_Comm_rank(_mpi_comm)
33+
34+
# Generate particle distribution
35+
rng = np.random.default_rng(123)
36+
X = rng.normal(size=(args.nsamp, 2))
37+
X = X / np.linalg.norm(X, axis=1)[:, None]
38+
X[:, 0] -= 0.75 * X[:, 1] ** 2
39+
X = X + rng.normal(size=X.shape, scale=0.25)
40+
X = X / np.std(X, axis=0)
41+
X = np.hstack([X, np.zeros((X.shape[0], 4))])
42+
43+
# Create bunch
44+
bunch = Bunch()
45+
for i in range(X.shape[0]):
46+
bunch.addParticle(*X[i, :])
47+
48+
# Create histogram diagnostic
49+
grid_limits = 2 * [(-4.0, 4.0)]
50+
grid_shape = (args.nbins, args.nbins)
51+
52+
axis = (0, 1)
53+
diag = BunchHistogram(
54+
axis=axis,
55+
shape=grid_shape,
56+
limits=grid_limits,
57+
method=None,
58+
normalize=True,
59+
output_dir=output_dir,
60+
)
61+
62+
# Compute histogram
63+
diag.track(bunch)
64+
65+
# Plot histogram
66+
if _mpi_rank == 0:
67+
fig, axs = plt.subplots(ncols=2, figsize=(6.0, 3.0), sharex=True, sharey=True)
68+
axs[0].pcolormesh(
69+
diag.grid_coords[0],
70+
diag.grid_coords[1],
71+
diag.grid_values.T,
72+
)
73+
74+
grid_values_np, _ = np.histogramdd(X[:, axis], bins=diag.grid_edges, density=True)
75+
axs[1].pcolormesh(
76+
diag.grid_coords[0],
77+
diag.grid_coords[1],
78+
grid_values_np.T,
79+
)
80+
axs[0].set_title("BunchHistogram", fontsize="medium")
81+
axs[1].set_title("NumPy", fontsize="medium")
82+
83+
filename = os.path.join(output_dir, "fig_hist")
84+
plt.savefig(filename)
85+
plt.show()

0 commit comments

Comments
 (0)