Skip to content

Commit 76781d5

Browse files
committed
Add BunchHistogram diagnostic
1 parent b9c1bcb commit 76781d5

File tree

4 files changed

+213
-1
lines changed

4 files changed

+213
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ PyORBIT.egg-info
99
*.so
1010
.*.swp
1111
.eggs
12+
.ipynb_checkpoints

py/orbit/diagnostics/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from .TeapotDiagnosticsNode import TeapotStatLatsNode, TeapotStatLatsNodeSetMember
1414
from .TeapotDiagnosticsNode import TeapotMomentsNode, TeapotMomentsNodeSetMember
1515
from .TeapotDiagnosticsNode import TeapotTuneAnalysisNode
16+
from .histogram import BunchHistogram
17+
from .histogram import BunchHistogram1D
18+
from .histogram import BunchHistogram2D
19+
from .histogram import BunchHistogram3D
1620

1721

1822
__all__ = []
@@ -30,3 +34,4 @@
3034
__all__.append("addTeapotMomentsNodeSet")
3135
__all__.append("TeapotTuneAnalysisNode")
3236
__all__.append("profiles")
37+
__all__.append("histogram")

py/orbit/diagnostics/histogram.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import os
2+
import sys
3+
import time
4+
from typing import Any
5+
from typing import Callable
6+
from typing import Union
7+
8+
import numpy as np
9+
import xarray as xr
10+
11+
from orbit.core import orbit_mpi
12+
from orbit.core.bunch import Bunch
13+
from orbit.core.spacecharge import Grid1D
14+
from orbit.core.spacecharge import Grid2D
15+
from orbit.core.spacecharge import Grid3D
16+
from orbit.lattice import AccLattice
17+
from orbit.lattice import AccNode
18+
19+
20+
Grid = Union[Grid1D, Grid2D, Grid3D]
21+
22+
23+
def get_grid_points(grid_coords: list[np.ndarray]) -> np.ndarray:
24+
if len(grid_coords) == 1:
25+
return grid_coords[0]
26+
return np.vstack([c.ravel() for c in np.meshgrid(*grid_coords, indexing="ij")]).T
27+
28+
29+
def grid_edges_to_coords(grid_edges: np.ndarray) -> np.ndarray:
30+
return 0.5 * (grid_edges[:-1] + grid_edges[1:])
31+
32+
33+
def make_grid(shape: tuple[int, ...], limits: list[tuple[float, float]]) -> Grid:
34+
35+
ndim = len(shape)
36+
37+
grid = None
38+
if ndim == 1:
39+
grid = Grid1D(shape[0] + 1, limits[0][0], limits[0][1])
40+
elif ndim == 2:
41+
grid = Grid2D(
42+
shape[0] + 1,
43+
shape[1] + 1,
44+
limits[0][0],
45+
limits[0][1],
46+
limits[1][0],
47+
limits[1][1],
48+
)
49+
elif ndim == 3:
50+
grid = Grid3D(
51+
shape[0] + 1,
52+
shape[1] + 1,
53+
shape[2] + 1,
54+
)
55+
grid.setGridX(limits[0][0], limits[0][1])
56+
grid.setGridY(limits[1][0], limits[1][1])
57+
grid.setGridZ(limits[2][0], limits[2][1])
58+
else:
59+
raise ValueError
60+
61+
return grid
62+
63+
64+
class BunchHistogram:
65+
"""MPI-compatible bunch histogram."""
66+
def __init__(
67+
self,
68+
axis: tuple[int, ...],
69+
shape: tuple[int, ...],
70+
limits: list[tuple[float, float]],
71+
method: str = None,
72+
transform: Callable = None,
73+
normalize: bool = True,
74+
output_dir: str = None,
75+
verbose: int = 2,
76+
**kwargs
77+
) -> None:
78+
"""Constructor.
79+
80+
Args:
81+
axis: Axis on which to compute the histogram.
82+
shape: Number of bins along each axis.
83+
limits: Min/max coordinates along each axis.
84+
method: Smoothing method {"bilinear", "nine-point", None}.
85+
transform: Transforms bunch before histogram is calculated.
86+
Call signature is `bunch_new = transform(bunch)`.
87+
normalize: Whehter to normalize values to PDF.
88+
output_dir: Output directory for saved files.
89+
verbose: Whether to print update messages.
90+
"""
91+
self.mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
92+
self.mpi_rank = orbit_mpi.MPI_Comm_rank(self.mpi_comm)
93+
self.output_dir = output_dir
94+
self.verbose = verbose
95+
96+
self.axis = axis
97+
self.ndim = len(axis)
98+
self.method = method
99+
self.transform = transform
100+
self.normalize = normalize
101+
102+
self.index = 0 # number of calls to `track` method
103+
self.node = None
104+
105+
if self.ndim > 2:
106+
raise NotImplementedError(
107+
"BunchHistogram does not yet support 3D grids. See "
108+
"https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/46"
109+
" and "
110+
"https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/47"
111+
)
112+
113+
# Dimension names
114+
self.dims = ["x", "xp", "y", "yp", "z", "dE"]
115+
self.dims = [self.dims[i] for i in self.axis]
116+
117+
# Create grid
118+
self.grid_shape = shape
119+
self.grid_limits = limits
120+
self.grid_edges = [
121+
np.linspace(self.grid_limits[i][0], self.grid_limits[i][1], self.grid_shape[i] + 1)
122+
for i in range(self.ndim)
123+
]
124+
self.grid_coords = [grid_edges_to_coords(e) for e in self.grid_edges]
125+
self.grid_values = np.zeros(shape)
126+
self.grid_points = get_grid_points(self.grid_coords)
127+
self.grid = make_grid(self.grid_shape, self.grid_limits)
128+
129+
# Store cell volume for normalization
130+
self.cell_volume = np.prod([e[1] - e[0] for e in self.grid_edges])
131+
132+
def sync_mpi(self) -> None:
133+
self.grid.synchronizeMPI(self.mpi_comm)
134+
135+
def bin_bunch(self, bunch: Bunch) -> None:
136+
macrosize = bunch.macroSize()
137+
if macrosize == 0:
138+
bunch.macroSize(1.0)
139+
140+
if self.method == "bilinear":
141+
self.grid.binBunchBilinear(bunch, *self.axis)
142+
else:
143+
self.grid.binBunch(bunch, *self.axis)
144+
145+
bunch.macroSize(macrosize)
146+
147+
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
148+
self.bin_bunch(bunch)
149+
self.sync_mpi()
150+
151+
values = np.zeros(self.grid_points.shape[0])
152+
if self.method == "bilinear":
153+
for i, point in enumerate(self.grid_points):
154+
values[i] = self.grid.getValueBilinear(*point)
155+
elif self.method == "nine-point":
156+
for i, point in enumerate(self.grid_points):
157+
values[i] = self.grid.getValue(*point)
158+
else:
159+
for i, indices in enumerate(np.ndindex(*self.grid_shape)):
160+
values[i] = self.grid.getValueOnGrid(*indices)
161+
values = np.reshape(values, self.grid_shape)
162+
163+
if self.normalize:
164+
values_sum = np.sum(values)
165+
if values_sum > 0.0:
166+
values /= values_sum
167+
values /= self.cell_volume
168+
return values
169+
170+
def track(self, bunch: Bunch) -> None:
171+
bunch_copy = Bunch()
172+
bunch.copyBunchTo(bunch_copy)
173+
if self.transform is not None:
174+
bunch_copy = self.transform(bunch_copy)
175+
176+
self.grid.setZero()
177+
self.grid_values = self.compute_histogram(bunch_copy)
178+
179+
if self.output_dir is not None:
180+
array = xr.DataArray(self.grid_values, coords=self.grid_coords, dims=self.dims)
181+
array.to_netcdf(path=self.get_filename())
182+
183+
self.index += 1
184+
185+
def get_filename(self) -> str:
186+
filename = "hist_" + "-".join([str(i) for i in self.axis])
187+
filename = "{}_{:04.0f}".format(filename, self.index)
188+
filename = "{}.nc".format(filename)
189+
filename = os.path.join(self.output_dir, filename)
190+
return filename
191+
192+
193+
class BunchHistogram1D(BunchHistogram):
194+
def __init__(self, **kwargs) -> None:
195+
super().__init__(**kwargs)
196+
197+
198+
class BunchHistogram2D(BunchHistogram):
199+
def __init__(self, **kwargs) -> None:
200+
super().__init__(**kwargs)
201+
202+
203+
class BunchHistogram3D(BunchHistogram):
204+
def __init__(self, **kwargs) -> None:
205+
super().__init__(**kwargs)

py/orbit/diagnostics/meson.build

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ py_sources = files([
66
'TeapotDiagnosticsNode.py',
77
'diagnosticsLatticeModifications.py',
88
'__init__.py',
9-
'profiles.py'
9+
'profiles.py',
10+
'histogram.py',
1011
])
1112

1213
python.install_sources(

0 commit comments

Comments
 (0)