Skip to content

Commit dc24558

Browse files
author
Wood, Tony
committed
Implementation of a pure-python function to collect bunch into dictionary of numpy objects
Each MPI node writes the bunch coordinates to a memory-mapped numpy array in /tmp. The primary rank concatenates them into a single memory-mapped array, and the extras are removed from disk. Also introduces a FileHandler protocol, which can define the schema for handling different filetypes, e.g., numpy binaries, HDF5, etc. The desired FileHandler can be passed as an argument to the functions in `collect_bunch.py`
1 parent 7671a02 commit dc24558

File tree

5 files changed

+416
-2
lines changed

5 files changed

+416
-2
lines changed

py/orbit/bunch_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77

88
from orbit.bunch_utils.particleidnumber import ParticleIdNumber
9+
from orbit.bunch_utils.collect_bunch import collect_bunch, save_bunch, load_bunch
910

1011
__all__ = []
1112
__all__.append("addParticleIdNumbers")
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import os
2+
import pathlib
3+
4+
from orbit.core.bunch import Bunch
5+
from orbit.core import orbit_mpi
6+
from orbit.bunch_utils.file_handler import FileHandler, NumPyHandler, BunchDict
7+
8+
import numpy as np
9+
10+
11+
def _write_local_coords(bunch: Bunch, fname: pathlib.Path, n_rows: int) -> None:
    """Dump this rank's macroparticle coordinates to `fname` as raw float64 rows of 6 values."""
    if n_rows == 0:
        # A zero-length array cannot be memory-mapped (and newer NumPy pads such
        # files with a stray byte), so a rank that holds no particles leaves an
        # empty placeholder file that the root rank recognises and skips.
        fname.write_bytes(b"")
        return

    local = np.memmap(fname, dtype=np.float64, mode="w+", shape=(n_rows, 6))
    for i in range(n_rows):
        local[i, :] = (
            bunch.x(i),
            bunch.xp(i),
            bunch.y(i),
            bunch.yp(i),
            bunch.z(i),
            bunch.dE(i),
        )
    local.flush()
    # Drop the mapping so the file can be re-opened (root) or removed (others).
    del local


def collect_bunch(
    bunch: Bunch,
    output_dir: str | pathlib.Path | None = "/tmp",
    return_memmap: bool = True,
) -> BunchDict | None:
    """Collect a PyOrbit Bunch from all MPI ranks into a dictionary on the root rank.

    Every rank writes its local coordinates to
    ``<output_dir>/collect_bunch_tmpfile_<rank>.dat``. After a barrier, rank 0
    grows its own file to the global size, appends the rows of every other
    rank's file and deletes those files.

    Parameters
    ----------
    bunch : Bunch
        The PyOrbit::Bunch object from which to collect attributes.
    output_dir : str | pathlib.Path | None, optional
        Directory used for the temporary per-rank files. ``None`` is treated
        as "/tmp" (the default). All MPI ranks must be able to write to it and
        rank 0 must be able to read every rank's file.
    return_memmap : bool, optional
        If True (default) the coordinates are returned as a memory-mapped
        array backed by ``collect_bunch_tmpfile_0.dat``, which therefore stays
        on disk. If False the data is copied into RAM and the backing file is
        removed.

    Returns
    -------
    BunchDict | None
        None on non-root ranks or if the global bunch size is 0. On rank 0:
        {
            "coords": array of shape (N, 6) holding
                (x, xp, y, yp, z, dE) for each of the N macroparticles,
            "sync_part": {
                "coords": array of shape (3,) taken from the synchronous
                    particle's rVector(), matching what load_bunch restores,
                "kin_energy", "momentum", "beta", "gamma", "time": np.float64
            },
            "attributes": {<name>: np.float64 or np.int32, ...}
        }

    Raises
    ------
    FileNotFoundError
        If the temporary file of a non-root rank cannot be found by rank 0.
    """
    global_size = bunch.getSizeGlobal()
    if global_size == 0:
        return None

    mpi_comm = bunch.getMPIComm()
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

    n_cols = 6
    dtype = np.float64
    local_rows = bunch.getSize()

    # Honour the documented "None means /tmp" behaviour.
    output_dir = pathlib.Path("/tmp" if output_dir is None else output_dir)
    fname = output_dir / f"collect_bunch_tmpfile_{mpi_rank}.dat"

    _write_local_coords(bunch, fname, local_rows)

    # Every rank must have finished writing before the root starts merging.
    orbit_mpi.MPI_Barrier(mpi_comm)

    if mpi_rank != 0:
        return None

    sync_part = bunch.getSyncParticle()
    attributes: dict = {}
    for attr in bunch.bunchAttrDoubleNames():
        attributes[attr] = np.float64(bunch.bunchAttrDouble(attr))
    for attr in bunch.bunchAttrIntNames():
        attributes[attr] = np.int32(bunch.bunchAttrInt(attr))

    bunch_dict: BunchDict = {
        "coords": None,
        "sync_part": {
            # Position vector: load_bunch writes this back through rVector(),
            # so it has to be read from rVector() for a consistent round trip.
            "coords": np.array(sync_part.rVector()),
            "kin_energy": np.float64(sync_part.kinEnergy()),
            "momentum": np.float64(sync_part.momentum()),
            "beta": np.float64(sync_part.beta()),
            "gamma": np.float64(sync_part.gamma()),
            "time": np.float64(sync_part.time()),
        },
        "attributes": attributes,
    }

    # Re-open the root file with the global shape; "r+" extends the file so the
    # rows of the other ranks can be appended after the root's own rows.
    coords = np.memmap(fname, dtype=dtype, mode="r+", shape=(global_size, n_cols))

    start_row = local_rows
    for rank in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)):
        src_fname = output_dir / f"collect_bunch_tmpfile_{rank}.dat"

        if not src_fname.exists():
            raise FileNotFoundError(
                f"Expected temporary file '{src_fname}' not found. Something went wrong."
            )

        # An empty file is the placeholder of a rank without particles.
        if src_fname.stat().st_size > 0:
            src = np.memmap(src_fname, dtype=dtype, mode="r").reshape((-1, n_cols))
            stop_row = start_row + src.shape[0]
            coords[start_row:stop_row, :] = src
            coords.flush()
            del src
            start_row = stop_row

        src_fname.unlink()

    if return_memmap:
        bunch_dict["coords"] = coords
    else:
        bunch_dict["coords"] = np.array(coords)
        # The in-memory copy no longer needs the backing file.
        del coords
        fname.unlink()

    return bunch_dict
139+
140+
141+
def save_bunch(
    bunch: Bunch | BunchDict | None,
    output_dir: str | pathlib.Path = "bunch_data/",
    Handler: type[FileHandler] = NumPyHandler,
) -> None:
    """Save a bunch to `output_dir` using the given file handler.

    Only MPI rank 0 writes; every other rank returns without doing anything.

    Parameters
    ----------
    bunch : Bunch | BunchDict | None
        A PyOrbit::Bunch (collected with `collect_bunch` first) or an already
        collected dictionary. None is accepted and ignored so that the result
        of `collect_bunch` can be passed unchanged on every rank.
    output_dir : str | pathlib.Path, optional
        The directory where the bunch data files will be saved.
        Default is "bunch_data/".
    Handler : type[FileHandler], optional
        The file handler class used for writing. It is instantiated with
        `output_dir`. Default is NumPyHandler.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `bunch` is neither a Bunch instance, a BunchDict (dict) nor None.
    """
    if isinstance(bunch, Bunch):
        mpi_comm = bunch.getMPIComm()
        bunch = collect_bunch(bunch)
    elif bunch is None or isinstance(bunch, dict):
        mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
    else:
        # Fail early with the documented error instead of an unrelated
        # TypeError/KeyError further down.
        raise ValueError(
            f"Expected a Bunch or a BunchDict, got {type(bunch).__name__}."
        )

    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

    # collect_bunch returns None on non-root ranks and for empty bunches.
    if mpi_rank != 0 or bunch is None:
        return

    if bunch["coords"].shape[0] == 0:
        print("No particles in the bunch to save.")
        return

    handler = Handler(pathlib.Path(output_dir))
    handler.write(bunch)
184+
185+
186+
def _local_row_range(global_size: int, mpi_size: int, mpi_rank: int) -> tuple[int, int]:
    """Return the [start, stop) rows of `mpi_rank` for an even split; low ranks take the remainder."""
    base, remainder = divmod(global_size, mpi_size)
    # Every rank below `remainder` holds one extra row, which shifts the start
    # of each rank by the number of such ranks that precede it.
    start_row = mpi_rank * base + min(mpi_rank, remainder)
    local_size = base + (1 if mpi_rank < remainder else 0)
    return start_row, start_row + local_size


def load_bunch(
    input_dir: str | pathlib.Path, Handler: type[FileHandler] = NumPyHandler
) -> tuple[Bunch, BunchDict]:
    """Load bunch data from a directory and build a Bunch distributed over MPI_COMM_WORLD.

    Every rank reads the data through the handler and adds only its own
    contiguous slice of the coordinate rows to the returned Bunch.

    Parameters
    ----------
    input_dir : str | pathlib.Path
        The directory from which to load the bunch data files.
    Handler : type[FileHandler], optional
        The file handler class used for reading. It is instantiated with
        `input_dir`. Default is NumPyHandler.
        See `orbit.bunch_utils.file_handler` for available handlers.

    Returns
    -------
    tuple[Bunch, BunchDict]
        The Bunch holding this rank's share of the particles, and the full
        dictionary as returned by the handler.

    Raises
    ------
    FileNotFoundError
        If the required files are not found in the specified directory
        (raised by the handler).
    TypeError
        If an attribute in the loaded bunch is not a real floating-point or
        integer scalar.
    """
    mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
    mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)

    bunch_dict = Handler(input_dir).read()
    coords = bunch_dict["coords"]

    start_row, stop_row = _local_row_range(coords.shape[0], mpi_size, mpi_rank)

    bunch = Bunch()
    for row in coords[start_row:stop_row, :]:
        bunch.addParticle(*row)

    for attr, value in bunch_dict["attributes"].items():
        # Classify by dtype of the *value*; this also works for plain Python
        # float/int, which np.issubdtype cannot interpret when given directly.
        as_array = np.asarray(value)
        if as_array.ndim == 0 and np.issubdtype(as_array.dtype, np.floating):
            bunch.bunchAttrDouble(attr, float(value))
        elif as_array.ndim == 0 and np.issubdtype(as_array.dtype, np.integer):
            bunch.bunchAttrInt(attr, int(value))
        else:
            raise TypeError(f"Unsupported attribute type for '{attr}': {type(value)}")

    sync_part = bunch_dict["sync_part"]
    sync_part_obj = bunch.getSyncParticle()
    sync_part_obj.rVector(tuple(sync_part["coords"]))
    sync_part_obj.kinEnergy(sync_part["kin_energy"])
    sync_part_obj.time(sync_part["time"])

    return bunch, bunch_dict
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pathlib
2+
from typing import Any, Protocol, TypedDict
3+
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
7+
8+
class SyncPartDict(TypedDict):
    """Typed dictionary describing the synchronous particle of a bunch."""

    # 3-component vector of the synchronous particle.
    coords: NDArray[np.float64]

    # Scalar kinematic quantities of the synchronous particle.
    kin_energy: np.float64
    momentum: np.float64
    beta: np.float64
    gamma: np.float64

    # Time of the synchronous particle.
    time: np.float64
15+
16+
17+
class BunchDict(TypedDict):
    """Typed dictionary holding a bunch: coordinates, synchronous particle and attributes."""

    # Macroparticle coordinates, one row per particle.
    coords: NDArray[np.float64]

    # Properties of the synchronous particle.
    sync_part: SyncPartDict

    # Bunch attributes keyed by attribute name.
    attributes: dict[str, np.float64 | np.int32]
21+
22+
23+
class FileHandler(Protocol):
    """Structural interface for objects that read and write bunch data.

    Any class providing these three methods satisfies the protocol; it does
    not need to inherit from `FileHandler`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the handler; the accepted arguments are handler-specific."""
        ...

    def read(self) -> BunchDict:
        """Return the stored bunch as a `BunchDict`."""
        ...

    def write(self, bunch: BunchDict) -> None:
        """Persist the given `BunchDict`."""
        ...
31+
32+
33+
class NumPyHandler:
    """Handler implementing the FileHandler protocol for NumPy binary files.

    The handler works on two files inside the directory passed to the constructor:
    - coords.npy: the bunch coordinates (opened memory-mapped, read-only, by `read`).
    - attributes.npz: a NumPy archive holding the synchronous-particle
      dictionary and the bunch-attribute dictionary as pickled objects.
    """

    _coords_fname = "coords.npy"
    _attributes_fname = "attributes.npz"

    def __init__(self, dir_name: str | pathlib.Path):
        """Remember the target directory and derive the two file paths from it."""
        dir_name = pathlib.Path(dir_name)
        self._dir_name = dir_name
        self._coords_path = dir_name / self._coords_fname
        self._attributes_path = dir_name / self._attributes_fname

    def read(self) -> BunchDict:
        """Load the bunch stored in the handler's directory.

        Raises
        ------
        FileNotFoundError
            If coords.npy or attributes.npz is missing.
        """
        if not (self._coords_path.exists() and self._attributes_path.exists()):
            raise FileNotFoundError(
                f"Required files not found in directory: {self._dir_name}"
            )

        coords = np.load(self._coords_path, mmap_mode="r")

        # NOTE(security): the dictionaries are stored as pickled objects, so
        # allow_pickle=True is required. Unpickling can execute arbitrary code;
        # only read archives that come from a trusted source.
        # The context manager closes the archive's file handle once both
        # entries have been extracted.
        with np.load(self._attributes_path, allow_pickle=True) as attr_data:
            sync_part = attr_data["sync_part"].item()
            attributes = attr_data["attributes"].item()

        return BunchDict(coords=coords, sync_part=sync_part, attributes=attributes)

    def write(self, bunch: BunchDict) -> None:
        """Write `bunch` to coords.npy and attributes.npz, creating the directory if needed."""
        self._dir_name.mkdir(parents=True, exist_ok=True)
        np.save(self._coords_path, bunch["coords"])
        np.savez(
            self._attributes_path,
            sync_part=bunch["sync_part"],
            attributes=bunch["attributes"],
        )
73+
74+
75+
class HDF5Handler:
    """Placeholder for a future HDF5-based FileHandler; every method raises NotImplementedError."""

    # TODO: implement reading/writing of bunch data in HDF5 format.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Accept arbitrary arguments (matching the FileHandler protocol) so that
        # `Handler(directory)` reports "not implemented" instead of failing
        # with a TypeError about unexpected arguments.
        raise NotImplementedError("HDF5Handler is not yet implemented.")

    def read(self) -> BunchDict:
        """Not implemented yet."""
        raise NotImplementedError("HDF5Handler is not yet implemented.")

    def write(self, bunch: BunchDict) -> None:
        """Not implemented yet."""
        raise NotImplementedError("HDF5Handler is not yet implemented.")

py/orbit/bunch_utils/meson.build

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
py_sources = files([
55
'__init__.py',
6-
'particleidnumber.py'
6+
'particleidnumber.py',
7+
'collect_bunch.py',
8+
'file_handler.py'
79
])
810

911
python.install_sources(
1012
py_sources,
1113
subdir: 'orbit/bunch_utils',
1214
# pure: true,
13-
)
15+
)

0 commit comments

Comments
 (0)