Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions py/orbit/bunch_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,12 @@

from orbit.bunch_utils.particleidnumber import ParticleIdNumber

# This guards against missing numpy.
# Should be imporved with some meaningful (and MPI friendly?) warning printed out.
try:
from orbit.bunch_utils.serialize import collect_bunch, save_bunch, load_bunch
except:
pass

__all__ = []
__all__.append("addParticleIdNumbers")
5 changes: 3 additions & 2 deletions py/orbit/bunch_utils/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

py_sources = files([
'__init__.py',
'particleidnumber.py'
'particleidnumber.py',
'serialize.py',
])

python.install_sources(
py_sources,
subdir: 'orbit/bunch_utils',
# pure: true,
)
)
317 changes: 317 additions & 0 deletions py/orbit/bunch_utils/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
import os
import pathlib
from typing import Any, Protocol, TypedDict

import numpy as np
from numpy.typing import NDArray

from orbit.core import orbit_mpi
from orbit.core.bunch import Bunch


class SyncPartDict(TypedDict):
coords: NDArray[np.float64]
kin_energy: np.float64
momentum: np.float64
beta: np.float64
gamma: np.float64
time: np.float64


class BunchDict(TypedDict):
coords: NDArray[np.float64]
sync_part: SyncPartDict
attributes: dict[str, np.float64 | np.int32]


class FileHandler(Protocol):
"""Protocol for file handlers to read/write bunch data."""

def __init__(self, *args: Any, **kwargs: Any) -> None: ...

def read(self) -> BunchDict: ...

def write(self, bunch: BunchDict) -> None: ...


class NumPyHandler:
"""Handler implementing the FileHandler protocol for NumPy binary files.
This handler will create two files in the directory passed to the constructor:
- coords.npy: A memory-mapped NumPy array containing the bunch coordinates.
- attributes.npz: A NumPy archive containing data related to the synchronous particle and other bunch attributes.
"""

_coords_fname = "coords.npy"
_attributes_fname = "attributes.npz"

def __init__(self, dir_name: str | pathlib.Path):
if isinstance(dir_name, str):
dir_name = pathlib.Path(dir_name)
self._dir_name = dir_name
self._coords_path = dir_name / self._coords_fname
self._attributes_path = dir_name / self._attributes_fname

def read(self) -> BunchDict:
if not self._coords_path.exists() or not self._attributes_path.exists():
raise FileNotFoundError(
f"Required files not found in directory: {self._dir_name}"
)

coords = np.load(self._coords_path, mmap_mode="r")

attr_data = np.load(self._attributes_path, allow_pickle=True)

sync_part = attr_data["sync_part"].item()
attributes = attr_data["attributes"].item()

return BunchDict(coords=coords, sync_part=sync_part, attributes=attributes)

def write(self, bunch: BunchDict) -> None:
self._dir_name.mkdir(parents=True, exist_ok=True)
np.save(self._coords_path, bunch["coords"])
np.savez(
self._attributes_path,
sync_part=bunch["sync_part"],
attributes=bunch["attributes"],
)


def collect_bunch(
bunch: Bunch, output_dir: str | pathlib.Path = "/tmp", return_memmap: bool = True
) -> BunchDict | None:
"""Collects attributes from a PyOrbit Bunch across all MPI ranks and returns it as a dictionary.
Parameters
----------
bunch : Bunch
The PyOrbit::Bunch object from which to collect attributes.
output_dir : str | pathlib.Path, optional
The director to use for temporary storage of the bunch coordinates on each MPI rank.
If None, the bunch will be stored in "/tmp".
Note: take care that the temporary files are created in a directory where all MPI ranks have write access.
return_memmap : bool, optional
Return the bunch coordinates as a memory-mapped NumPy array, otherwise the
entire array is copied into RAM and returned as normal NDArray. Default is True.
Returns
-------
BunchDict | None
A dictionary containing the collected bunch attributes. Returns None if not on the root MPI rank or if the global bunch size is 0.
BunchDict structure:
{
"coords": NDArray[np.float64] of shape (N, 6) where N is the total number of macroparticles,
and the 6 columns correspond to [x, xp, y, yp, z, dE] in units of [m, rad, m, rad, m, GeV], respectively.
"sync_part": {
"coords": NDArray[np.float64] of shape (3,),
"kin_energy": np.float64,
"momentum": np.float64,
"beta": np.float64,
"gamma": np.float64,
"time": np.float64
},
"attributes": {
<bunch attribute name>: <attribute value (np.float64 or np.int32)>,
...
}
}
Raises
------
FileNotFoundError
If the temporary files created by non-root MPI ranks could not be found by the root rank during
the collection process.
"""

global_size = bunch.getSizeGlobal()

if global_size == 0:
return None

mpi_comm = bunch.getMPIComm()
mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

coords_shape = (bunch.getSizeGlobal(), 6)

local_rows = bunch.getSize()

if isinstance(output_dir, str):
output_dir = pathlib.Path(output_dir)

fname = output_dir / f"collect_bunch_tmpfile_{mpi_rank}.dat"

local_shape = (local_rows, coords_shape[1])
dtype = np.float64
coords_memmap = np.memmap(fname, dtype=dtype, mode="w+", shape=local_shape)

for i in range(local_rows):
coords_memmap[i, :] = (
bunch.x(i),
bunch.xp(i),
bunch.y(i),
bunch.yp(i),
bunch.z(i),
bunch.dE(i),
)

coords_memmap.flush()

bunch_dict: BunchDict = {"coords": None, "sync_part": {}, "attributes": {}}

if mpi_rank == 0:
sync_part = bunch.getSyncParticle()

bunch_dict["sync_part"] |= {
"coords": np.array(sync_part.pVector()),
"kin_energy": np.float64(sync_part.kinEnergy()),
"momentum": np.float64(sync_part.momentum()),
"beta": np.float64(sync_part.beta()),
"gamma": np.float64(sync_part.gamma()),
"time": np.float64(sync_part.time()),
}

for attr in bunch.bunchAttrDoubleNames():
bunch_dict["attributes"][attr] = np.float64(bunch.bunchAttrDouble(attr))

for attr in bunch.bunchAttrIntNames():
bunch_dict["attributes"][attr] = np.int32(bunch.bunchAttrInt(attr))

orbit_mpi.MPI_Barrier(mpi_comm)

if mpi_rank != 0:
return None

coords_memmap = np.memmap(fname, dtype=dtype, mode="r+", shape=coords_shape)

start_row = local_rows

for r in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)):
src_fname = output_dir / f"collect_bunch_tmpfile_{r}.dat"

if not os.path.exists(src_fname):
raise FileNotFoundError(
f"Expected temporary file '{src_fname}' not found. Something went wrong."
)

src_memmap = np.memmap(src_fname, dtype=dtype, mode="r")
src_memmap = src_memmap.reshape((-1, coords_shape[1]))

stop_row = start_row + src_memmap.shape[0]

coords_memmap[start_row:stop_row, :] = src_memmap[:, :]
coords_memmap.flush()

del src_memmap
os.remove(src_fname)
start_row = stop_row

bunch_dict["coords"] = coords_memmap if return_memmap else np.array(coords_memmap)

return bunch_dict


def save_bunch(
bunch: Bunch | BunchDict,
output_dir: str | pathlib.Path = "bunch_data/",
Handler: type[FileHandler] = NumPyHandler,
) -> None:
"""Saves the collected bunch attributes to a specified directory.
Parameters
----------
bunch_dict : Bunch | BunchDict
The PyOrbit::Bunch object or the dictionary containing the collected bunch attributes.
output_dir : str, optional
The directory where the bunch data files will be saved. Default is "bunch_data/".
Handler : FileHandler, optional
The file handler class to use for writing the bunch data. Default is NumPyHandler.
Returns
-------
None
Raises
------
ValueError
If the provided `bunch` is neither a Bunch instance nor a BunchDict.
"""

if isinstance(bunch, Bunch):
mpi_comm = bunch.getMPIComm()
bunch = collect_bunch(bunch)
else:
mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD

mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

if mpi_rank != 0 or bunch is None:
return

if bunch["coords"].shape[0] == 0:
print("No particles in the bunch to save.")
return

if isinstance(output_dir, str):
output_dir = pathlib.Path(output_dir)

handler = Handler(output_dir)
handler.write(bunch)


def load_bunch(
input_dir: str | pathlib.Path, Handler: type[FileHandler] = NumPyHandler
) -> tuple[Bunch, BunchDict]:
"""Loads the bunch attributes from a specified directory containing NumPy binary files.
Parameters
----------
input_dir : str | pathlib.Path
The directory from which to load the bunch data files.
Handler : FileHandler, optional
The file handler class to use for reading the bunch data. Default is NumPyHandler.
See `orbit.bunch_utils.file_handler` for available handlers.
Returns
-------
BunchDict
A dictionary containing the loaded bunch attributes.
Raises
------
FileNotFoundError
If the required files are not found in the specified directory.
TypeError
If an attribute in the loaded bunch has an unsupported type.
"""
mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)

handler = Handler(input_dir)

bunch_dict = handler.read()

coords = bunch_dict["coords"]

global_size = coords.shape[0]

local_size = global_size // mpi_size
remainder = global_size % mpi_size
if mpi_rank < remainder:
local_size += 1
start_row = mpi_rank * local_size
else:
start_row = mpi_rank * local_size + remainder
stop_row = start_row + local_size

local_coords = coords[start_row:stop_row, :]

bunch = Bunch()

for i in range(local_size):
bunch.addParticle(*local_coords[i, :])

for attr, value in bunch_dict["attributes"].items():
if np.issubdtype(value, np.floating):
bunch.bunchAttrDouble(attr, value)
elif np.issubdtype(value, np.integer):
bunch.bunchAttrInt(attr, value)
else:
raise TypeError(f"Unsupported attribute type for '{attr}': {type(value)}")

sync_part_obj = bunch.getSyncParticle()
sync_part_obj.rVector(tuple(bunch_dict["sync_part"]["coords"]))
sync_part_obj.kinEnergy(bunch_dict["sync_part"]["kin_energy"])
sync_part_obj.time(bunch_dict["sync_part"]["time"])

return bunch, bunch_dict
Loading