Skip to content

Commit d1d57ee

Browse files
committed
Fix bug in histogram class
- 3D histograms not working; only accept 1D/2D for now; see PyORBIT-Collaboration/PyORBIT3#47 and PyORBIT-Collaboration/PyORBIT3#46. - Clear old histogram values before tracking bunch.
1 parent 60ff234 commit d1d57ee

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

orbit_tools/diag.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
2121

2222

2323
def make_grid(
24-
axis: tuple[int, ...], shape: tuple[int, ...], limits: list[tuple[float, float]]
24+
shape: tuple[int, ...],
25+
limits: list[tuple[float, float]]
2526
) -> Union[Grid1D, Grid2D, Grid3D]:
2627

27-
ndim = len(axis)
28+
ndim = len(shape)
2829

2930
grid = None
3031
if ndim == 1:
@@ -43,13 +44,10 @@ def make_grid(
4344
shape[0] + 1,
4445
shape[1] + 1,
4546
shape[2] + 1,
46-
limits[0][0],
47-
limits[0][1],
48-
limits[1][0],
49-
limits[1][1],
50-
limits[2][0],
51-
limits[2][1],
5247
)
48+
grid.setGridX(limits[0][0], limits[0][1])
49+
grid.setGridY(limits[1][0], limits[1][1])
50+
grid.setGridZ(limits[2][0], limits[2][1])
5351
else:
5452
raise ValueError
5553

@@ -89,13 +87,17 @@ def __init__(
8987
limits: list[tuple[float, float]],
9088
method: str = None,
9189
transform: Callable = None,
90+
normalize: bool = True,
9291
**kwargs
9392
) -> None:
9493
super().__init__(**kwargs)
9594

9695
self.axis = axis
9796
self.ndim = len(axis)
9897

98+
if self.ndim > 2:
99+
raise NotImplementedError("BunchHistogram does not yet support 3D grids. See https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/46 and https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/47.")
100+
99101
self.dims = ["x", "xp", "y", "yp", "z", "dE"]
100102
self.dims = [self.dims[i] for i in self.axis]
101103

@@ -111,22 +113,29 @@ def __init__(
111113
self.points = get_grid_points(self.coords)
112114
self.cell_volume = np.prod([e[1] - e[0] for e in self.edges])
113115

114-
self.grid = make_grid(axis=self.axis, shape=self.shape, limits=self.limits)
116+
self.grid = make_grid(self.shape, self.limits)
115117
self.method = method
116118
self.transform = transform
119+
self.normalize = normalize
117120

118121
def reset(self) -> None:
119122
self.grid.setZero()
120123

121124
def sync_mpi(self) -> None:
122125
self.grid.synchronizeMPI(self.mpi_comm)
123126

124-
def bin_bunch(self, bunch: Bunch) -> None:
127+
def bin_bunch(self, bunch: Bunch) -> None:
128+
macrosize = bunch.macroSize()
129+
if macrosize == 0:
130+
bunch.macroSize(1.0)
131+
125132
if self.method == "bilinear":
126133
self.grid.binBunchBilinear(bunch, *self.axis)
127134
else:
128135
self.grid.binBunch(bunch, *self.axis)
129136

137+
bunch.macroSize(macrosize)
138+
130139
def compute_histogram(self, bunch: Bunch) -> np.ndarray:
131140
self.bin_bunch(bunch)
132141
self.sync_mpi()
@@ -142,6 +151,12 @@ def compute_histogram(self, bunch: Bunch) -> np.ndarray:
142151
for i, indices in enumerate(np.ndindex(*self.shape)):
143152
values[i] = self.grid.getValueOnGrid(*indices)
144153
values = np.reshape(values, self.shape)
154+
155+
if self.normalize:
156+
values_sum = np.sum(values)
157+
if values_sum > 0.0:
158+
values /= values_sum
159+
values /= self.cell_volume
145160
return values
146161

147162
def track(self, params_dict: dict) -> None:
@@ -152,13 +167,8 @@ def track(self, params_dict: dict) -> None:
152167
if self.transform is not None:
153168
bunch_copy = self.transform(bunch_copy)
154169

155-
values = self.compute_histogram(bunch_copy)
156-
values_sum = np.sum(values)
157-
if values_sum > 0.0:
158-
values /= values_sum
159-
values /= self.cell_volume
160-
161-
self.values = values
170+
self.reset()
171+
self.values = self.compute_histogram(bunch_copy)
162172

163173
if self.output_dir is not None:
164174
array = xr.DataArray(self.values, coords=self.coords, dims=self.dims)

0 commit comments

Comments
 (0)