Skip to content

Commit 79ee8cb

Browse files
author
Wood, Tony
committed
Implementation of a pure-Python function to collect a bunch into a dictionary of NumPy objects
Each MPI rank writes its local bunch coordinates to a memory-mapped NumPy array in /tmp. The primary rank concatenates them into a single memory-mapped array, and the other ranks' temporary files are removed from disk.
1 parent f99113c commit 79ee8cb

File tree

4 files changed

+242
-2
lines changed

4 files changed

+242
-2
lines changed

py/orbit/bunch_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77

88
from orbit.bunch_utils.particleidnumber import ParticleIdNumber
9+
from orbit.bunch_utils.collect_bunch import collect_bunch
910

1011
__all__ = []
1112
__all__.append("addParticleIdNumbers")
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import os
2+
# import tempfile
3+
from enum import IntEnum
4+
from typing import Optional, TypedDict
5+
6+
from orbit.core.bunch import Bunch
7+
from orbit.core import orbit_mpi
8+
9+
import numpy as np
10+
from numpy.typing import NDArray
11+
12+
13+
class SyncPartDict(TypedDict):
    """Typed layout of the ``"sync_part"`` entry produced by ``collect_bunch``.

    Each field mirrors one accessor of the bunch's synchronous particle.
    """

    # Array built from the synchronous particle's pVector().
    coords: NDArray[np.float64]

    # Scalar quantities, each wrapped as np.float64 by collect_bunch.
    kin_energy: np.float64
    momentum: np.float64
    beta: np.float64
    gamma: np.float64
    time: np.float64
20+
21+
22+
class BunchDict(TypedDict):
    """Typed layout of the dictionary returned by ``collect_bunch``.

    The key names match the ones ``collect_bunch`` actually populates
    (``"coords"``, ``"sync_part"`` and ``"attributes"``), so that static type
    checkers accept the real return value.
    """

    # Macroparticle coordinates, one row per particle and one column per
    # BunchCoord member.
    coords: NDArray[np.float64]

    # Properties of the synchronous particle.
    sync_part: SyncPartDict

    # Bunch attributes by name: doubles as np.float64, integers as np.int32.
    attributes: dict[str, np.float64 | np.int32]
26+
27+
28+
class BunchCoord(IntEnum):
    """Column index of each coordinate in a collected bunch coordinate array."""

    # Column order: x, xp, y, yp, z, dE -> 0 .. 5.
    X, XP, Y, YP, Z, DE = range(6)
35+
36+
37+
def _collect_bunch_tmp_fname(rank: int) -> str:
    """Return the fixed scratch-file path used by ``collect_bunch`` for ``rank``."""
    return f"/tmp/collect_bunch_tmpfile_{rank}.dat"


def _fill_local_coords(bunch: Bunch, dest, n_rows: int) -> None:
    """Copy the first ``n_rows`` local macroparticles of ``bunch`` into ``dest``."""
    for i in range(n_rows):
        dest[i, BunchCoord.X] = bunch.x(i)
        dest[i, BunchCoord.XP] = bunch.xp(i)
        dest[i, BunchCoord.Y] = bunch.y(i)
        dest[i, BunchCoord.YP] = bunch.yp(i)
        dest[i, BunchCoord.Z] = bunch.z(i)
        dest[i, BunchCoord.DE] = bunch.dE(i)


def collect_bunch(
    bunch: Bunch,
    return_memmap: bool = True,
    output_fname: Optional[str] = None,
) -> BunchDict | None:
    """Collect a PyOrbit Bunch from all MPI ranks into a dictionary on rank 0.

    Every rank writes its local coordinates to a scratch file under ``/tmp``.
    After a barrier, rank 0 appends the other ranks' rows (in rank order) to
    its own file and deletes their scratch files.

    Parameters
    ----------
    bunch : Bunch
        The PyOrbit::Bunch object from which to collect attributes.
    return_memmap : bool, optional
        If True (default), ``"coords"`` is a memory-mapped NumPy array backed
        by the rank-0 scratch file. If False, the array is copied into RAM,
        returned as a normal NDArray, and the scratch file is removed.
    output_fname : str, optional
        Currently ignored; the scratch files always use fixed names under
        ``/tmp``. NOTE(review): either implement or drop this parameter.

    Returns
    -------
    BunchDict | None
        None if the global bunch size is 0 or if the calling rank is not
        rank 0. Otherwise a dictionary with this structure:
        {
            "coords": NDArray[np.float64] of shape (N, 6), N being the global
                number of macroparticles; columns are [x, xp, y, yp, z, dE]
                (the original documentation gives the units as
                [m, rad, m, rad, m, eV] -- TODO confirm),
            "sync_part": {
                "coords": NDArray[np.float64] built from pVector()
                    (documented as shape (3,)),
                "kin_energy": np.float64,
                "momentum": np.float64,
                "beta": np.float64,
                "gamma": np.float64,
                "time": np.float64
            },
            "attributes": {
                <bunch attribute name>: <attribute value (np.float64 or np.int32)>,
                ...
            }
        }

    Raises
    ------
    FileNotFoundError
        On rank 0, if the scratch file of another rank is missing.

    Notes
    -----
    The scratch file names are fixed, so concurrent calls on the same machine
    would overwrite each other's files. Rank 0 reads the other ranks' files by
    path, which presumably requires all ranks to see the same ``/tmp``.
    """
    global_size = bunch.getSizeGlobal()
    if global_size == 0:
        return None

    mpi_comm = bunch.getMPIComm()
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)

    n_cols = 6
    dtype = np.float64
    local_rows = bunch.getSize()

    # TODO: use a unique name (tempfile.mkstemp on rank 0) shared with the other
    # ranks. Broadcasting it with orbit_mpi.MPI_Bcast(..., MPI_CHAR, ...) failed
    # with "SystemError: PY_SSIZE_T_CLEAN macro must be defined for '#' formats",
    # so fixed per-rank names are used for now.
    fname = _collect_bunch_tmp_fname(mpi_rank)

    if mpi_rank != 0:
        if local_rows > 0:
            local_memmap = np.memmap(
                fname, dtype=dtype, mode="w+", shape=(local_rows, n_cols)
            )
            _fill_local_coords(bunch, local_memmap, local_rows)
            local_memmap.flush()
            # Release the mapping before rank 0 reads and deletes the file.
            del local_memmap
        else:
            # A zero-row array cannot be memory-mapped reliably, so an empty
            # file marks "this rank holds no particles".
            with open(fname, "wb"):
                pass

        orbit_mpi.MPI_Barrier(mpi_comm)
        return None

    # Rank 0 maps the full-size output at once. global_size > 0 here, so the
    # mapping is never empty even when rank 0 itself holds no particles.
    coords_memmap = np.memmap(
        fname, dtype=dtype, mode="w+", shape=(global_size, n_cols)
    )
    _fill_local_coords(bunch, coords_memmap, local_rows)

    sync_part = bunch.getSyncParticle()
    bunch_dict = {
        "coords": None,
        "sync_part": {
            "coords": np.array(sync_part.pVector()),
            "kin_energy": np.float64(sync_part.kinEnergy()),
            "momentum": np.float64(sync_part.momentum()),
            "beta": np.float64(sync_part.beta()),
            "gamma": np.float64(sync_part.gamma()),
            "time": np.float64(sync_part.time()),
        },
        "attributes": {},
    }

    for attr in bunch.bunchAttrDoubleNames():
        bunch_dict["attributes"][attr] = np.float64(bunch.bunchAttrDouble(attr))

    for attr in bunch.bunchAttrIntNames():
        bunch_dict["attributes"][attr] = np.int32(bunch.bunchAttrInt(attr))

    # Wait until every rank has finished writing its scratch file.
    orbit_mpi.MPI_Barrier(mpi_comm)

    start_row = local_rows
    for r in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)):
        src_fname = _collect_bunch_tmp_fname(r)

        if not os.path.exists(src_fname):
            raise FileNotFoundError(f"Expected temporary file '{src_fname}' not found. Something went wrong.")

        # An empty file is the marker of a rank without particles.
        if os.path.getsize(src_fname) > 0:
            src_memmap = np.memmap(src_fname, dtype=dtype, mode="r")
            src_memmap = src_memmap.reshape((-1, n_cols))

            stop_row = start_row + src_memmap.shape[0]
            coords_memmap[start_row:stop_row, :] = src_memmap[:, :]

            del src_memmap
            start_row = stop_row

        os.remove(src_fname)

    coords_memmap.flush()

    if return_memmap:
        bunch_dict["coords"] = coords_memmap
    else:
        # The data now lives in RAM, so the rank-0 scratch file is not needed.
        bunch_dict["coords"] = np.array(coords_memmap)
        del coords_memmap
        os.remove(fname)

    return bunch_dict

py/orbit/bunch_utils/meson.build

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
py_sources = files([
55
'__init__.py',
6-
'particleidnumber.py'
6+
'particleidnumber.py',
7+
'collect_bunch.py'
78
])
89

910
python.install_sources(
1011
py_sources,
1112
subdir: 'orbit/bunch_utils',
1213
# pure: true,
13-
)
14+
)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from orbit.core.bunch import Bunch
2+
from orbit.bunch_generators import GaussDist3D
3+
from orbit.bunch_utils import collect_bunch
4+
5+
from pytest import fixture
6+
7+
8+
@fixture
def bunch():
    """Provide a Bunch holding 10 macroparticles sampled from GaussDist3D."""
    test_bunch = Bunch()
    test_bunch.mass(0.939294)
    test_bunch.charge(-1.0)
    test_bunch.getSyncParticle().kinEnergy(0.0025)

    sampler = GaussDist3D()
    for _ in range(10):
        # getCoordinates() yields (x, xp, y, yp, z, dE) in addParticle order.
        test_bunch.addParticle(*sampler.getCoordinates())

    test_bunch.macroSize(10)
    return test_bunch
20+
21+
22+
def test_collect_bunch(bunch):
    """Check keys, shape, coordinate values, attributes and sync particle."""
    d = collect_bunch(bunch, return_memmap=False)

    n_particles = bunch.getSize()

    toplevel_keys = {"coords", "sync_part", "attributes"}
    attribute_keys = {"charge", "classical_radius", "mass", "macro_size"}
    sync_part_keys = {"coords", "kin_energy", "momentum", "beta", "gamma", "time"}

    # Expected rows, read with the same accessors and column order that
    # collect_bunch uses. The original test gathered these values but never
    # compared them with the collected array.
    expected_coords = [
        [bunch.x(i), bunch.xp(i), bunch.y(i), bunch.yp(i), bunch.z(i), bunch.dE(i)]
        for i in range(n_particles)
    ]

    assert set(d.keys()) == toplevel_keys
    assert set(d["sync_part"].keys()) == sync_part_keys
    assert set(d["attributes"].keys()) == attribute_keys

    assert d["coords"].shape == (n_particles, 6)
    # Values are copied without arithmetic, so exact equality is expected.
    assert d["coords"].tolist() == expected_coords

    assert d["attributes"]["charge"] == bunch.bunchAttrDouble("charge")
    assert d["attributes"]["classical_radius"] == bunch.bunchAttrDouble(
        "classical_radius"
    )
    assert d["attributes"]["mass"] == bunch.bunchAttrDouble("mass")
    assert d["attributes"]["macro_size"] == bunch.bunchAttrDouble("macro_size")

    sync_part = bunch.getSyncParticle()
    assert (d["sync_part"]["coords"] == sync_part.pVector()).all()
    assert d["sync_part"]["kin_energy"] == sync_part.kinEnergy()
    assert d["sync_part"]["momentum"] == sync_part.momentum()
    assert d["sync_part"]["beta"] == sync_part.beta()
    assert d["sync_part"]["gamma"] == sync_part.gamma()
    assert d["sync_part"]["time"] == sync_part.time()
61+
62+
63+
def test_collect_empty_bunch():
    """Collecting a bunch that holds no macroparticles returns None."""
    empty_bunch = Bunch()
    assert collect_bunch(empty_bunch) is None
67+
68+
69+
def test_collect_arbitrary_bunch_attr(bunch):
    """User-defined double and int bunch attributes appear in the collected dict."""
    custom_attrs = {"arbitrary_dbl_attr": 42.0, "arbitrary_int_attr": 42}

    bunch.bunchAttrDouble("arbitrary_dbl_attr", custom_attrs["arbitrary_dbl_attr"])
    bunch.bunchAttrInt("arbitrary_int_attr", custom_attrs["arbitrary_int_attr"])

    collected = collect_bunch(bunch)
    attributes = collected["attributes"]

    # Checked in insertion order: the double attribute first, then the int.
    for attr_name, attr_value in custom_attrs.items():
        assert attr_name in attributes
        assert attributes[attr_name] == attr_value

0 commit comments

Comments
 (0)