@@ -21,10 +21,11 @@ def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
2121
2222
2323def make_grid (
24- axis : tuple [int , ...], shape : tuple [int , ...], limits : list [tuple [float , float ]]
24+ shape : tuple [int , ...],
25+ limits : list [tuple [float , float ]]
2526) -> Union [Grid1D , Grid2D , Grid3D ]:
2627
27- ndim = len (axis )
28+ ndim = len (shape )
2829
2930 grid = None
3031 if ndim == 1 :
@@ -43,13 +44,10 @@ def make_grid(
4344 shape [0 ] + 1 ,
4445 shape [1 ] + 1 ,
4546 shape [2 ] + 1 ,
46- limits [0 ][0 ],
47- limits [0 ][1 ],
48- limits [1 ][0 ],
49- limits [1 ][1 ],
50- limits [2 ][0 ],
51- limits [2 ][1 ],
5247 )
48+ grid .setGridX (limits [0 ][0 ], limits [0 ][1 ])
49+ grid .setGridY (limits [1 ][0 ], limits [1 ][1 ])
50+ grid .setGridZ (limits [2 ][0 ], limits [2 ][1 ])
5351 else :
5452 raise ValueError
5553
@@ -89,13 +87,17 @@ def __init__(
8987 limits : list [tuple [float , float ]],
9088 method : str = None ,
9189 transform : Callable = None ,
90+ normalize : bool = True ,
9291 ** kwargs
9392 ) -> None :
9493 super ().__init__ (** kwargs )
9594
9695 self .axis = axis
9796 self .ndim = len (axis )
9897
98+ if self .ndim > 2 :
99+ raise NotImplementedError ("BunchHistogram does not yet support 3D grids. See https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/46 and https://github.com/PyORBIT-Collaboration/PyORBIT3/issues/47." )
100+
99101 self .dims = ["x" , "xp" , "y" , "yp" , "z" , "dE" ]
100102 self .dims = [self .dims [i ] for i in self .axis ]
101103
@@ -111,22 +113,29 @@ def __init__(
111113 self .points = get_grid_points (self .coords )
112114 self .cell_volume = np .prod ([e [1 ] - e [0 ] for e in self .edges ])
113115
114- self .grid = make_grid (axis = self .axis , shape = self . shape , limits = self .limits )
116+ self .grid = make_grid (self .shape , self .limits )
115117 self .method = method
116118 self .transform = transform
119+ self .normalize = normalize
117120
118121 def reset (self ) -> None :
119122 self .grid .setZero ()
120123
121124 def sync_mpi (self ) -> None :
122125 self .grid .synchronizeMPI (self .mpi_comm )
123126
124- def bin_bunch (self , bunch : Bunch ) -> None :
127+ def bin_bunch (self , bunch : Bunch ) -> None :
128+ macrosize = bunch .macroSize ()
129+ if macrosize == 0 :
130+ bunch .macroSize (1.0 )
131+
125132 if self .method == "bilinear" :
126133 self .grid .binBunchBilinear (bunch , * self .axis )
127134 else :
128135 self .grid .binBunch (bunch , * self .axis )
129136
137+ bunch .macroSize (macrosize )
138+
130139 def compute_histogram (self , bunch : Bunch ) -> np .ndarray :
131140 self .bin_bunch (bunch )
132141 self .sync_mpi ()
@@ -142,6 +151,12 @@ def compute_histogram(self, bunch: Bunch) -> np.ndarray:
142151 for i , indices in enumerate (np .ndindex (* self .shape )):
143152 values [i ] = self .grid .getValueOnGrid (* indices )
144153 values = np .reshape (values , self .shape )
154+
155+ if self .normalize :
156+ values_sum = np .sum (values )
157+ if values_sum > 0.0 :
158+ values /= values_sum
159+ values /= self .cell_volume
145160 return values
146161
147162 def track (self , params_dict : dict ) -> None :
@@ -152,13 +167,8 @@ def track(self, params_dict: dict) -> None:
152167 if self .transform is not None :
153168 bunch_copy = self .transform (bunch_copy )
154169
155- values = self .compute_histogram (bunch_copy )
156- values_sum = np .sum (values )
157- if values_sum > 0.0 :
158- values /= values_sum
159- values /= self .cell_volume
160-
161- self .values = values
170+ self .reset ()
171+ self .values = self .compute_histogram (bunch_copy )
162172
163173 if self .output_dir is not None :
164174 array = xr .DataArray (self .values , coords = self .coords , dims = self .dims )
0 commit comments