Skip to content

Commit 105db9a

Browse files
committed
Fix histogram diagnostics and move them to diag module
1 parent c544828 commit 105db9a

File tree

2 files changed

+176
-161
lines changed

2 files changed

+176
-161
lines changed

orbit_tools/diag.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
import os
22
import sys
33
import time
4+
from typing import Any
5+
from typing import Callable
46

57
import numpy as np
68

79
from orbit.core import orbit_mpi
810
from orbit.core.bunch import Bunch
11+
from orbit.core.spacecharge import Grid1D
12+
from orbit.core.spacecharge import Grid2D
13+
from orbit.core.spacecharge import Grid3D
914
from orbit.lattice import AccLattice
1015
from orbit.lattice import AccNode
1116

1217

18+
def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
19+
return np.vstack([C.ravel() for C in np.meshgrid(*coords, indexing="ij")]).T
20+
21+
1322
class Diagnostic:
14-
def __init__(self, output_dir: str, verbose: bool = True) -> None:
23+
def __init__(self, output_dir: str = None, verbose: bool = True) -> None:
1524
self._mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
1625
self._mpi_rank = orbit_mpi.MPI_Comm_rank(self._mpi_comm)
1726
self.output_dir = output_dir
@@ -33,3 +42,169 @@ def __call__(self, params_dict: dict) -> None:
3342
if not self.should_skip():
3443
self.track(params_dict)
3544
self.update()
45+
46+
47+
class BunchHistogram(Diagnostic):
48+
def __init__(
49+
self,
50+
axis: tuple[int, ...],
51+
shape: tuple[int, ...],
52+
limits: list[tuple[float, float]],
53+
transform: Callable = None,
54+
**kwargs
55+
) -> None:
56+
super().__init__(**kwargs)
57+
58+
self.axis = axis
59+
self.ndim = len(axis)
60+
61+
self.dims = ["x", "xp", "y", "yp", "z", "dE"]
62+
self.dims = [self.dims[i] for i in self.axis]
63+
64+
self.shape = shape
65+
self.limits = limits
66+
self.edges = [
67+
np.linspace(self.limits[i][0], self.limits[i][1], self.shape[i] + 1)
68+
for i in range(self.ndim)
69+
]
70+
self.coords = [0.5 * (e[:-1] + e[1:]) for e in self.edges]
71+
self.values = np.zeros(shape)
72+
73+
self.points = get_grid_points(self.coords)
74+
self.cell_volume = np.prod([e[1] - e[0] for e in self.edges])
75+
76+
self.transform = transform
77+
78+
def get_filename(self) -> str:
79+
filename = "hist_" + "-".join([str(i) for i in self.axis])
80+
filename = "{}_{:04.0f}".format(filename, self.index)
81+
filename = "{}_{}".format(filename, self.node.getName())
82+
filename = "{}.nc".format(filename)
83+
filename = os.path.join(self.output_dir, filename)
84+
return filename
85+
86+
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
87+
raise NotImplementedError
88+
89+
def __call__(self, params_dict: dict) -> np.ndarray:
90+
bunch_copy = Bunch()
91+
92+
bunch = params_dict["bunch"]
93+
bunch.copyBunchTo(bunch_copy)
94+
95+
if self.transform is not None:
96+
bunch_copy = self.transform(bunch_copy)
97+
98+
self.values = self.compute_histogram(bunch_copy)
99+
values_sum = np.sum(self.values)
100+
if values_sum > 0.0:
101+
self.values = self.values / values_sum
102+
self.values = self.values / self.cell_volume
103+
104+
if self.output_dir is not None:
105+
array = xr.DataArray(values, coords=self.coords, dims=self.dims)
106+
array.to_netcdf(path=self.get_filename(params_dict))
107+
108+
return self.values
109+
110+
def track(self, bunch: Bunch) -> np.ndarray:
111+
params_dict = {"bunch": bunch}
112+
return self.__call__(params_dict)
113+
114+
115+
class BunchHistogram2D(BunchHistogram):
116+
def __init__(self, method: str = None, **kwargs) -> None:
117+
super().__init__(**kwargs)
118+
119+
self._grid = Grid2D(
120+
self.shape[0] + 1,
121+
self.shape[1] + 1,
122+
self.limits[0][0],
123+
self.limits[0][1],
124+
self.limits[1][0],
125+
self.limits[1][1],
126+
)
127+
self.method = method
128+
129+
def reset(self) -> None:
130+
self._grid.setZero()
131+
132+
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
133+
# Bin coordinates on grid
134+
if self.method == "bilinear":
135+
self._grid.binBunchBilinear(bunch, self.axis[0], self.axis[1])
136+
else:
137+
self._grid.binBunch(bunch, self.axis[0], self.axis[1])
138+
139+
# Synchronize MPI
140+
comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
141+
self._grid.synchronizeMPI(comm)
142+
143+
# Extract grid values as numpy array
144+
values = np.zeros(self.points.shape[0])
145+
if self.method == "bilinear":
146+
for i, point in enumerate(self.points):
147+
values[i] = self._grid.getValueBilinear(*point)
148+
elif self.method == "nine-point":
149+
for i, point in enumerate(self.points):
150+
values[i] = self._grid.getValue(*point)
151+
else:
152+
index = 0
153+
for i in range(self.shape[0]):
154+
for j in range(self.shape[1]):
155+
values[index] = self._grid.getValueOnGrid(i, j)
156+
index += 1
157+
158+
values = np.reshape(values, self.shape)
159+
return values
160+
161+
162+
class BunchHistogram3D(BunchHistogram):
163+
def __init__(self, method: str = None, **kwargs) -> None:
164+
super().__init__(**kwargs)
165+
166+
self._grid = Grid3D(
167+
self.shape[0] + 1,
168+
self.shape[1] + 1,
169+
self.shape[2] + 1,
170+
self.limits[0][0],
171+
self.limits[0][1],
172+
self.limits[1][0],
173+
self.limits[1][1],
174+
self.limits[2][0],
175+
self.limits[2][1],
176+
)
177+
self.method = method
178+
179+
def reset(self) -> None:
180+
self._grid.setZero()
181+
182+
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
183+
# Bin coordinates on grid
184+
if self.method == "bilinear":
185+
self._grid.binBunchBilinear(bunch, self.axis[0], self.axis[1], self.axis[2])
186+
else:
187+
self._grid.binBunch(bunch, self.axis[0], self.axis[1], self.axis[2])
188+
189+
# Synchronize MPI
190+
comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
191+
self._grid.synchronizeMPI(comm)
192+
193+
# Extract grid values as numpy array
194+
values = np.zeros(self.points.shape[0])
195+
if self.method == "bilinear":
196+
for i, point in enumerate(self.points):
197+
values[i] = self._grid.getValueBilinear(*point)
198+
elif self.method == "nine-point":
199+
for i, point in enumerate(self.points):
200+
values[i] = self._grid.getValue(*point)
201+
else:
202+
index = 0
203+
for i in range(self.shape[0]):
204+
for j in range(self.shape[1]):
205+
for k in range(self.shape[2]):
206+
values[index] = self._grid.getValueOnGrid(i, j, k)
207+
index += 1
208+
209+
values = np.reshape(values, self.shape)
210+
return values

orbit_tools/linac/diag.py

Lines changed: 0 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -82,166 +82,6 @@ def track(self, params_dict: dict) -> None:
8282
bunch.dumpBunch(filename)
8383

8484

85-
class BunchHistogram(Diagnostic):
86-
def __init__(
87-
self,
88-
axis: tuple[int, ...],
89-
shape: tuple[int, ...],
90-
limits: list[tuple[float, float]],
91-
transform: Callable = None,
92-
**kwargs
93-
) -> None:
94-
super().__init__(**kwargs)
95-
96-
self.axis = axis
97-
self.ndim = len(axis)
98-
99-
self.dims = ["x", "xp", "y", "yp", "z", "dE"]
100-
self.dims = [self.dims[i] for i in self.axis]
101-
102-
self.shape = shape
103-
self.limits = limits
104-
self.edges = [
105-
np.linspace(self.limits[i][0], self.limits[i][1], self.shape[i] + 1)
106-
for i in range(self.ndim)
107-
]
108-
self.coords = [0.5 * (e[:-1] + e[1:]) for e in self.edges]
109-
self.points = get_grid_points(self.coords)
110-
self.cell_volume = np.prod([e[1] - e[0] for e in self.edges])
111-
112-
self.transform = transform
113-
114-
def get_filename(self) -> str:
115-
filename = "hist_" + "-".join([str(i) for i in self.axis])
116-
filename = "{}_{:04.0f}".format(filename, self.index)
117-
filename = "{}_{}".format(filename, self.node.getName())
118-
filename = "{}.nc".format(filename)
119-
filename = os.path.join(self.output_dir, filename)
120-
return filename
121-
122-
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
123-
raise NotImplementedError
124-
125-
def track(self, params_dict: dict) -> None:
126-
bunch_copy = Bunch()
127-
128-
bunch = self.bunch
129-
bunch.copyBunchTo(bunch_copy)
130-
131-
if self.transform is not None:
132-
bunch_copy = self.transform(bunch_copy)
133-
134-
values = self.compute_histogram(bunch_copy)
135-
values = values / self.cell_volume
136-
values_sum = np.sum(values)
137-
if values_sum > 0.0:
138-
values = values / values_sum
139-
140-
if self.output_dir is not None:
141-
array = xr.DataArray(values, coords=self.coords, dims=self.dims)
142-
array.to_netcdf(path=self.get_filename(params_dict))
143-
144-
return values
145-
146-
147-
class BunchHistogram2D(BunchHistogram):
148-
def __init__(self, method: str, **kwargs) -> None:
149-
super().__init__(**kwargs)
150-
151-
self._grid = Grid2D(
152-
self.shape[0] + 1,
153-
self.shape[1] + 1,
154-
self.limits[0][0],
155-
self.limits[0][1],
156-
self.limits[1][0],
157-
self.limits[1][1],
158-
)
159-
self.method = method
160-
161-
def reset(self) -> None:
162-
self._grid.setZero()
163-
164-
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
165-
# Bin coordinates on grid
166-
if self.method == "bilinear":
167-
self._grid.binBunchBilinear(bunch, self.axis[0], self.axis[1])
168-
else:
169-
self._grid.binBunch(bunch, self.axis[0], self.axis[1])
170-
171-
# Synchronize MPI
172-
comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
173-
self._grid.synchronizeMPI(comm)
174-
175-
# Extract grid values as numpy array
176-
values = np.zeros(self.points.shape[0])
177-
if self.method == "bilinear":
178-
for i, point in enumerate(self.points):
179-
values[i] = self._grid.getValueBilinear(*point)
180-
elif self.method == "nine-point":
181-
for i, point in enumerate(self.points):
182-
values[i] = self._grid.getValue(*point)
183-
else:
184-
index = 0
185-
for i in range(self.shape[0]):
186-
for j in range(self.shape[1]):
187-
values[index] = self._grid.getValueOnGrid(i, j)
188-
index += 1
189-
190-
values = np.reshape(values, self.shape)
191-
return values
192-
193-
194-
class BunchHistogram3D(BunchHistogram):
195-
def __init__(self, method: str, **kwargs) -> None:
196-
super().__init__(**kwargs)
197-
198-
self._grid = Grid3D(
199-
self.shape[0] + 1,
200-
self.shape[1] + 1,
201-
self.shape[2] + 1,
202-
self.limits[0][0],
203-
self.limits[0][1],
204-
self.limits[1][0],
205-
self.limits[1][1],
206-
self.limits[2][0],
207-
self.limits[2][1],
208-
)
209-
self.method = method
210-
211-
def reset(self) -> None:
212-
self._grid.setZero()
213-
214-
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
215-
# Bin coordinates on grid
216-
if self.method == "bilinear":
217-
self._grid.binBunchBilinear(bunch, self.axis[0], self.axis[1], self.axis[2])
218-
else:
219-
self._grid.binBunch(bunch, self.axis[0], self.axis[1], self.axis[2])
220-
221-
# Synchronize MPI
222-
comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
223-
self._grid.synchronizeMPI(comm)
224-
225-
# Extract grid values as numpy array
226-
values = np.zeros(self.points.shape[0])
227-
if self.method == "bilinear":
228-
for i, point in enumerate(self.points):
229-
values[i] = self._grid.getValueBilinear(*point)
230-
elif self.method == "nine-point":
231-
for i, point in enumerate(self.points):
232-
values[i] = self._grid.getValue(*point)
233-
else:
234-
index = 0
235-
for i in range(self.shape[0]):
236-
for j in range(self.shape[1]):
237-
for k in range(self.shape[2]):
238-
values[index] = self._grid.getValueOnGrid(i, j, k)
239-
index += 1
240-
241-
values = np.reshape(values, self.shape)
242-
return values
243-
244-
24585
class ScalarBunchMonitor(LinacDiagnostic):
24686
def __init__(self, rf_frequency: float = 402.5e06, **kwargs) -> None:
24787
kwargs.setdefault("stride", 0.1)

0 commit comments

Comments
 (0)