11import os
22import sys
33import time
4+ from typing import Any
5+ from typing import Callable
46
57import numpy as np
68
79from orbit .core import orbit_mpi
810from orbit .core .bunch import Bunch
11+ from orbit .core .spacecharge import Grid1D
12+ from orbit .core .spacecharge import Grid2D
13+ from orbit .core .spacecharge import Grid3D
914from orbit .lattice import AccLattice
1015from orbit .lattice import AccNode
1116
1217
18+ def get_grid_points (coords : list [np .ndarray ]) -> np .ndarray :
19+ return np .vstack ([C .ravel () for C in np .meshgrid (* coords , indexing = "ij" )]).T
20+
21+
1322class Diagnostic :
14- def __init__ (self , output_dir : str , verbose : bool = True ) -> None :
23+ def __init__ (self , output_dir : str = None , verbose : bool = True ) -> None :
1524 self ._mpi_comm = orbit_mpi .mpi_comm .MPI_COMM_WORLD
1625 self ._mpi_rank = orbit_mpi .MPI_Comm_rank (self ._mpi_comm )
1726 self .output_dir = output_dir
@@ -33,3 +42,169 @@ def __call__(self, params_dict: dict) -> None:
3342 if not self .should_skip ():
3443 self .track (params_dict )
3544 self .update ()
45+
46+
47+ class BunchHistogram (Diagnostic ):
48+ def __init__ (
49+ self ,
50+ axis : tuple [int , ...],
51+ shape : tuple [int , ...],
52+ limits : list [tuple [float , float ]],
53+ transform : Callable = None ,
54+ ** kwargs
55+ ) -> None :
56+ super ().__init__ (** kwargs )
57+
58+ self .axis = axis
59+ self .ndim = len (axis )
60+
61+ self .dims = ["x" , "xp" , "y" , "yp" , "z" , "dE" ]
62+ self .dims = [self .dims [i ] for i in self .axis ]
63+
64+ self .shape = shape
65+ self .limits = limits
66+ self .edges = [
67+ np .linspace (self .limits [i ][0 ], self .limits [i ][1 ], self .shape [i ] + 1 )
68+ for i in range (self .ndim )
69+ ]
70+ self .coords = [0.5 * (e [:- 1 ] + e [1 :]) for e in self .edges ]
71+ self .values = np .zeros (shape )
72+
73+ self .points = get_grid_points (self .coords )
74+ self .cell_volume = np .prod ([e [1 ] - e [0 ] for e in self .edges ])
75+
76+ self .transform = transform
77+
78+ def get_filename (self ) -> str :
79+ filename = "hist_" + "-" .join ([str (i ) for i in self .axis ])
80+ filename = "{}_{:04.0f}" .format (filename , self .index )
81+ filename = "{}_{}" .format (filename , self .node .getName ())
82+ filename = "{}.nc" .format (filename )
83+ filename = os .path .join (self .output_dir , filename )
84+ return filename
85+
86+ def compute_histogram (self , bunch : Bunch ) -> np .ndarray :
87+ raise NotImplementedError
88+
89+ def __call__ (self , params_dict : dict ) -> np .ndarray :
90+ bunch_copy = Bunch ()
91+
92+ bunch = params_dict ["bunch" ]
93+ bunch .copyBunchTo (bunch_copy )
94+
95+ if self .transform is not None :
96+ bunch_copy = self .transform (bunch_copy )
97+
98+ self .values = self .compute_histogram (bunch_copy )
99+ values_sum = np .sum (self .values )
100+ if values_sum > 0.0 :
101+ self .values = self .values / values_sum
102+ self .values = self .values / self .cell_volume
103+
104+ if self .output_dir is not None :
105+ array = xr .DataArray (values , coords = self .coords , dims = self .dims )
106+ array .to_netcdf (path = self .get_filename (params_dict ))
107+
108+ return self .values
109+
110+ def track (self , bunch : Bunch ) -> np .ndarray :
111+ params_dict = {"bunch" : bunch }
112+ return self .__call__ (params_dict )
113+
114+
115+ class BunchHistogram2D (BunchHistogram ):
116+ def __init__ (self , method : str = None , ** kwargs ) -> None :
117+ super ().__init__ (** kwargs )
118+
119+ self ._grid = Grid2D (
120+ self .shape [0 ] + 1 ,
121+ self .shape [1 ] + 1 ,
122+ self .limits [0 ][0 ],
123+ self .limits [0 ][1 ],
124+ self .limits [1 ][0 ],
125+ self .limits [1 ][1 ],
126+ )
127+ self .method = method
128+
129+ def reset (self ) -> None :
130+ self ._grid .setZero ()
131+
132+ def compute_histogram (self , bunch : Bunch ) -> np .ndarray :
133+ # Bin coordinates on grid
134+ if self .method == "bilinear" :
135+ self ._grid .binBunchBilinear (bunch , self .axis [0 ], self .axis [1 ])
136+ else :
137+ self ._grid .binBunch (bunch , self .axis [0 ], self .axis [1 ])
138+
139+ # Synchronize MPI
140+ comm = orbit_mpi .mpi_comm .MPI_COMM_WORLD
141+ self ._grid .synchronizeMPI (comm )
142+
143+ # Extract grid values as numpy array
144+ values = np .zeros (self .points .shape [0 ])
145+ if self .method == "bilinear" :
146+ for i , point in enumerate (self .points ):
147+ values [i ] = self ._grid .getValueBilinear (* point )
148+ elif self .method == "nine-point" :
149+ for i , point in enumerate (self .points ):
150+ values [i ] = self ._grid .getValue (* point )
151+ else :
152+ index = 0
153+ for i in range (self .shape [0 ]):
154+ for j in range (self .shape [1 ]):
155+ values [index ] = self ._grid .getValueOnGrid (i , j )
156+ index += 1
157+
158+ values = np .reshape (values , self .shape )
159+ return values
160+
161+
162+ class BunchHistogram3D (BunchHistogram ):
163+ def __init__ (self , method : str = None , ** kwargs ) -> None :
164+ super ().__init__ (** kwargs )
165+
166+ self ._grid = Grid3D (
167+ self .shape [0 ] + 1 ,
168+ self .shape [1 ] + 1 ,
169+ self .shape [2 ] + 1 ,
170+ self .limits [0 ][0 ],
171+ self .limits [0 ][1 ],
172+ self .limits [1 ][0 ],
173+ self .limits [1 ][1 ],
174+ self .limits [2 ][0 ],
175+ self .limits [2 ][1 ],
176+ )
177+ self .method = method
178+
179+ def reset (self ) -> None :
180+ self ._grid .setZero ()
181+
182+ def compute_histogram (self , bunch : Bunch ) -> np .ndarray :
183+ # Bin coordinates on grid
184+ if self .method == "bilinear" :
185+ self ._grid .binBunchBilinear (bunch , self .axis [0 ], self .axis [1 ], self .axis [2 ])
186+ else :
187+ self ._grid .binBunch (bunch , self .axis [0 ], self .axis [1 ], self .axis [2 ])
188+
189+ # Synchronize MPI
190+ comm = orbit_mpi .mpi_comm .MPI_COMM_WORLD
191+ self ._grid .synchronizeMPI (comm )
192+
193+ # Extract grid values as numpy array
194+ values = np .zeros (self .points .shape [0 ])
195+ if self .method == "bilinear" :
196+ for i , point in enumerate (self .points ):
197+ values [i ] = self ._grid .getValueBilinear (* point )
198+ elif self .method == "nine-point" :
199+ for i , point in enumerate (self .points ):
200+ values [i ] = self ._grid .getValue (* point )
201+ else :
202+ index = 0
203+ for i in range (self .shape [0 ]):
204+ for j in range (self .shape [1 ]):
205+ for k in range (self .shape [2 ]):
206+ values [index ] = self ._grid .getValueOnGrid (i , j , k )
207+ index += 1
208+
209+ values = np .reshape (values , self .shape )
210+ return values
0 commit comments