
Commit b9c1bcb

Merge pull request PyORBIT-Collaboration#72 from woodtp/feature/py-bunch-collect
feat: function to collect bunch across MPI ranks into a single Python dictionary
2 parents 7671a02 + a433036 commit b9c1bcb
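
The new helpers are collective calls: every MPI rank must invoke them, and only rank 0 receives the assembled dictionary. A minimal sketch of the intended call pattern, based on the diff below (the single-particle bunch is illustrative; a real script would fill the bunch by tracking):

    from orbit.core.bunch import Bunch
    from orbit.bunch_utils import collect_bunch

    bunch = Bunch()
    bunch.addParticle(0.001, 0.0, 0.002, 0.0, 0.0, 0.0)  # x, xp, y, yp, z, dE

    bunch_dict = collect_bunch(bunch, output_dir="/tmp")
    if bunch_dict is not None:               # only rank 0 gets the dictionary
        print(bunch_dict["coords"].shape)    # (N, 6)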

4 files changed: +405 -2 lines

py/orbit/bunch_utils/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -7,5 +7,12 @@

 from orbit.bunch_utils.particleidnumber import ParticleIdNumber

+# This guards against missing numpy.
+# Should be improved with some meaningful (and MPI-friendly?) warning printed out.
+try:
+    from orbit.bunch_utils.serialize import collect_bunch, save_bunch, load_bunch
+except ImportError:
+    pass
+
 __all__ = []
 __all__.append("addParticleIdNumbers")
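
Because the import is wrapped in try/except, the serialize helpers are simply absent from orbit.bunch_utils when numpy is unavailable, rather than breaking the package import. A downstream script can probe for the feature with a guard of its own, for example (a hypothetical check, not part of this commit):

    try:
        from orbit.bunch_utils import collect_bunch, save_bunch, load_bunch
        HAVE_BUNCH_SERIALIZE = True
    except ImportError:
        HAVE_BUNCH_SERIALIZE = False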

py/orbit/bunch_utils/meson.build

Lines changed: 3 additions & 2 deletions
@@ -3,11 +3,12 @@

 py_sources = files([
     '__init__.py',
-    'particleidnumber.py'
+    'particleidnumber.py',
+    'serialize.py',
 ])

 python.install_sources(
     py_sources,
     subdir: 'orbit/bunch_utils',
     # pure: true,
-)
+)

py/orbit/bunch_utils/serialize.py

Lines changed: 317 additions & 0 deletions
@@ -0,0 +1,317 @@
import os
import pathlib
from typing import Any, Protocol, TypedDict

import numpy as np
from numpy.typing import NDArray

from orbit.core import orbit_mpi
from orbit.core.bunch import Bunch


class SyncPartDict(TypedDict):
    coords: NDArray[np.float64]
    kin_energy: np.float64
    momentum: np.float64
    beta: np.float64
    gamma: np.float64
    time: np.float64


class BunchDict(TypedDict):
    coords: NDArray[np.float64]
    sync_part: SyncPartDict
    attributes: dict[str, np.float64 | np.int32]


class FileHandler(Protocol):
    """Protocol for file handlers to read/write bunch data."""

    def __init__(self, *args: Any, **kwargs: Any) -> None: ...

    def read(self) -> BunchDict: ...

    def write(self, bunch: BunchDict) -> None: ...

class NumPyHandler:
    """Handler implementing the FileHandler protocol for NumPy binary files.
    This handler will create two files in the directory passed to the constructor:
    - coords.npy: A memory-mapped NumPy array containing the bunch coordinates.
    - attributes.npz: A NumPy archive containing data related to the synchronous particle and other bunch attributes.
    """

    _coords_fname = "coords.npy"
    _attributes_fname = "attributes.npz"

    def __init__(self, dir_name: str | pathlib.Path):
        if isinstance(dir_name, str):
            dir_name = pathlib.Path(dir_name)
        self._dir_name = dir_name
        self._coords_path = dir_name / self._coords_fname
        self._attributes_path = dir_name / self._attributes_fname

    def read(self) -> BunchDict:
        if not self._coords_path.exists() or not self._attributes_path.exists():
            raise FileNotFoundError(
                f"Required files not found in directory: {self._dir_name}"
            )

        coords = np.load(self._coords_path, mmap_mode="r")

        attr_data = np.load(self._attributes_path, allow_pickle=True)

        sync_part = attr_data["sync_part"].item()
        attributes = attr_data["attributes"].item()

        return BunchDict(coords=coords, sync_part=sync_part, attributes=attributes)

    def write(self, bunch: BunchDict) -> None:
        self._dir_name.mkdir(parents=True, exist_ok=True)
        np.save(self._coords_path, bunch["coords"])
        np.savez(
            self._attributes_path,
            sync_part=bunch["sync_part"],
            attributes=bunch["attributes"],
        )

def collect_bunch(
    bunch: Bunch, output_dir: str | pathlib.Path = "/tmp", return_memmap: bool = True
) -> BunchDict | None:
    """Collects attributes from a PyOrbit Bunch across all MPI ranks and returns them as a dictionary.
    Parameters
    ----------
    bunch : Bunch
        The PyOrbit::Bunch object from which to collect attributes.
    output_dir : str | pathlib.Path, optional
        The directory to use for temporary storage of the bunch coordinates on each MPI rank. Default is "/tmp".
        Note: take care that the temporary files are created in a directory where all MPI ranks have write access.
    return_memmap : bool, optional
        Return the bunch coordinates as a memory-mapped NumPy array, otherwise the
        entire array is copied into RAM and returned as a normal NDArray. Default is True.
    Returns
    -------
    BunchDict | None
        A dictionary containing the collected bunch attributes. Returns None if not on the root MPI rank or if the global bunch size is 0.
        BunchDict structure:
        {
            "coords": NDArray[np.float64] of shape (N, 6), where N is the total number of macroparticles
                and the 6 columns correspond to [x, xp, y, yp, z, dE] in units of [m, rad, m, rad, m, GeV], respectively.
            "sync_part": {
                "coords": NDArray[np.float64] of shape (3,),
                "kin_energy": np.float64,
                "momentum": np.float64,
                "beta": np.float64,
                "gamma": np.float64,
                "time": np.float64
            },
            "attributes": {
                <bunch attribute name>: <attribute value (np.float64 or np.int32)>,
                ...
            }
        }
    Raises
    ------
    FileNotFoundError
        If the temporary files created by non-root MPI ranks could not be found by the root rank during
        the collection process.
    """

    global_size = bunch.getSizeGlobal()

    if global_size == 0:
        return None

    mpi_comm = bunch.getMPIComm()
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

    coords_shape = (bunch.getSizeGlobal(), 6)

    local_rows = bunch.getSize()

    if isinstance(output_dir, str):
        output_dir = pathlib.Path(output_dir)

    fname = output_dir / f"collect_bunch_tmpfile_{mpi_rank}.dat"

    # Each rank writes its local particle coordinates to a per-rank temporary file.
    local_shape = (local_rows, coords_shape[1])
    dtype = np.float64
    coords_memmap = np.memmap(fname, dtype=dtype, mode="w+", shape=local_shape)

    for i in range(local_rows):
        coords_memmap[i, :] = (
            bunch.x(i),
            bunch.xp(i),
            bunch.y(i),
            bunch.yp(i),
            bunch.z(i),
            bunch.dE(i),
        )

    coords_memmap.flush()

    bunch_dict: BunchDict = {"coords": None, "sync_part": {}, "attributes": {}}

    if mpi_rank == 0:
        sync_part = bunch.getSyncParticle()

        bunch_dict["sync_part"] |= {
            "coords": np.array(sync_part.pVector()),
            "kin_energy": np.float64(sync_part.kinEnergy()),
            "momentum": np.float64(sync_part.momentum()),
            "beta": np.float64(sync_part.beta()),
            "gamma": np.float64(sync_part.gamma()),
            "time": np.float64(sync_part.time()),
        }

        for attr in bunch.bunchAttrDoubleNames():
            bunch_dict["attributes"][attr] = np.float64(bunch.bunchAttrDouble(attr))

        for attr in bunch.bunchAttrIntNames():
            bunch_dict["attributes"][attr] = np.int32(bunch.bunchAttrInt(attr))

    orbit_mpi.MPI_Barrier(mpi_comm)

    if mpi_rank != 0:
        return None

    # Rank 0 reopens its own file with the global shape and appends the
    # other ranks' temporary files to it, removing them as it goes.
    coords_memmap = np.memmap(fname, dtype=dtype, mode="r+", shape=coords_shape)

    start_row = local_rows

    for r in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)):
        src_fname = output_dir / f"collect_bunch_tmpfile_{r}.dat"

        if not os.path.exists(src_fname):
            raise FileNotFoundError(
                f"Expected temporary file '{src_fname}' not found. Something went wrong."
            )

        src_memmap = np.memmap(src_fname, dtype=dtype, mode="r")
        src_memmap = src_memmap.reshape((-1, coords_shape[1]))

        stop_row = start_row + src_memmap.shape[0]

        coords_memmap[start_row:stop_row, :] = src_memmap[:, :]
        coords_memmap.flush()

        del src_memmap
        os.remove(src_fname)
        start_row = stop_row

    bunch_dict["coords"] = coords_memmap if return_memmap else np.array(coords_memmap)

    return bunch_dict

def save_bunch(
    bunch: Bunch | BunchDict,
    output_dir: str | pathlib.Path = "bunch_data/",
    Handler: type[FileHandler] = NumPyHandler,
) -> None:
    """Saves the collected bunch attributes to a specified directory.
    Parameters
    ----------
    bunch : Bunch | BunchDict
        The PyOrbit::Bunch object or the dictionary containing the collected bunch attributes.
    output_dir : str | pathlib.Path, optional
        The directory where the bunch data files will be saved. Default is "bunch_data/".
    Handler : type[FileHandler], optional
        The file handler class to use for writing the bunch data. Default is NumPyHandler.
    Returns
    -------
    None
    Raises
    ------
    ValueError
        If the provided `bunch` is neither a Bunch instance nor a BunchDict.
    """

    if isinstance(bunch, Bunch):
        mpi_comm = bunch.getMPIComm()
        bunch = collect_bunch(bunch)
    else:
        mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD

    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

    if mpi_rank != 0 or bunch is None:
        return

    if bunch["coords"].shape[0] == 0:
        print("No particles in the bunch to save.")
        return

    if isinstance(output_dir, str):
        output_dir = pathlib.Path(output_dir)

    handler = Handler(output_dir)
    handler.write(bunch)

def load_bunch(
    input_dir: str | pathlib.Path, Handler: type[FileHandler] = NumPyHandler
) -> tuple[Bunch, BunchDict]:
    """Loads the bunch attributes from a specified directory containing NumPy binary files.
    Parameters
    ----------
    input_dir : str | pathlib.Path
        The directory from which to load the bunch data files.
    Handler : type[FileHandler], optional
        The file handler class to use for reading the bunch data. Default is NumPyHandler.
        See `orbit.bunch_utils.serialize` for available handlers.
    Returns
    -------
    tuple[Bunch, BunchDict]
        The reconstructed Bunch, with particles distributed across the MPI ranks, and a dictionary
        containing the loaded bunch attributes.
    Raises
    ------
    FileNotFoundError
        If the required files are not found in the specified directory.
    TypeError
        If an attribute in the loaded bunch has an unsupported type.
    """
    mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
    mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)

    handler = Handler(input_dir)

    bunch_dict = handler.read()

    coords = bunch_dict["coords"]

    global_size = coords.shape[0]

    # Partition the particle rows as evenly as possible across the MPI ranks.
    local_size = global_size // mpi_size
    remainder = global_size % mpi_size
    if mpi_rank < remainder:
        local_size += 1
        start_row = mpi_rank * local_size
    else:
        start_row = mpi_rank * local_size + remainder
    stop_row = start_row + local_size

    local_coords = coords[start_row:stop_row, :]

    bunch = Bunch()

    for i in range(local_size):
        bunch.addParticle(*local_coords[i, :])

    for attr, value in bunch_dict["attributes"].items():
        if np.issubdtype(value, np.floating):
            bunch.bunchAttrDouble(attr, value)
        elif np.issubdtype(value, np.integer):
            bunch.bunchAttrInt(attr, value)
        else:
            raise TypeError(f"Unsupported attribute type for '{attr}': {type(value)}")

    sync_part_obj = bunch.getSyncParticle()
    sync_part_obj.rVector(tuple(bunch_dict["sync_part"]["coords"]))
    sync_part_obj.kinEnergy(bunch_dict["sync_part"]["kin_energy"])
    sync_part_obj.time(bunch_dict["sync_part"]["time"])

    return bunch, bunch_dict
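
Taken together, collect_bunch, save_bunch, and load_bunch form a file-based round trip for a distributed bunch. A sketch of that round trip, assuming the script runs under MPI (it also works on a single rank); the particle values, sync-particle energy, and output directory are illustrative only:

    from orbit.core.bunch import Bunch
    from orbit.bunch_utils import save_bunch, load_bunch

    # Build a small bunch on every rank (a real script would fill it by tracking).
    bunch = Bunch()
    bunch.getSyncParticle().kinEnergy(1.0)  # GeV, illustrative value
    for i in range(10):
        bunch.addParticle(1.0e-3 * i, 0.0, 0.0, 0.0, 0.0, 0.0)

    # Collective call: every rank participates; rank 0 writes
    # bunch_data/coords.npy and bunch_data/attributes.npz via the default NumPyHandler.
    save_bunch(bunch, output_dir="bunch_data/")

    # Collective call: every rank reads the archive and receives its slice of the particles.
    new_bunch, bunch_dict = load_bunch("bunch_data/")
    print(new_bunch.getSize(), bunch_dict["sync_part"]["kin_energy"])

Any class with matching read() and write() methods satisfies the FileHandler protocol, so alternative storage formats can be plugged in through the Handler argument of save_bunch and load_bunch.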
