Skip to content

Commit 0102923

Browse files
author
Wood, Tony
committed
preliminary implementation of a pure-python function to collect bunch into dictionary of numpy objects
1 parent f99113c commit 0102923

File tree

4 files changed

+196
-2
lines changed

4 files changed

+196
-2
lines changed

py/orbit/bunch_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77

88
from orbit.bunch_utils.particleidnumber import ParticleIdNumber
9+
from orbit.bunch_utils.collect_bunch import collect_bunch
910

1011
__all__ = []
1112
__all__.append("addParticleIdNumbers")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from orbit.core.bunch import Bunch
2+
from orbit.core import orbit_mpi
3+
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
7+
8+
def collect_bunch(
9+
bunch: Bunch,
10+
) -> dict[str, np.float64 | np.int32 | NDArray[np.float64]]:
11+
"""Collects attributes from a PyOrbit Bunch across all MPI ranks and returns it as a dictionary.
12+
Parameters
13+
----------
14+
bunch : Bunch
15+
The PyOrbit::Bunch object from which to collect attributes.
16+
Returns
17+
-------
18+
dict[str, np.float64 | np.int32 | NDArray[np.float64]]:
19+
By default this returns a dictionary containing the following keys and their corresponding values:
20+
- "x": particle x-coordinates [m]
21+
- "xp": particle x-momenta [rad]
22+
- "y": particle y-coordinates [m]
23+
- "yp": particle y-momenta [rad]
24+
- "z": particle longitudinal coordinates [m]
25+
- "dE": particle energy deviations [GeV]
26+
- "sync_part_coords": coordinates of the synchronous particle (x,y,z) [m]
27+
- "sync_part_kin_energy": kinetic energy of the synchronous particle [GeV]
28+
- "sync_part_momentum": momentum of the synchronous particle [GeV/c]
29+
- "sync_part_beta": beta of the synchronous particle
30+
- "sync_part_gamma": gamma of the synchronous particle
31+
- "sync_part_time": time of the synchronous particle [s]
32+
- Any additional attributes defined in the bunch.
33+
"""
34+
n_particles = bunch.getSize()
35+
36+
if n_particles == 0:
37+
return {}
38+
39+
mpi_comm = bunch.getMPIComm() # orbit_mpi.mpi_comm.MPI_COMM_WORLD
40+
mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
41+
mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)
42+
43+
if mpi_rank == 0:
44+
bunch_dict = {"x": [], "xp": [], "y": [], "yp": [], "z": [], "dE": []}
45+
for attr in bunch.bunchAttrDoubleNames():
46+
bunch_dict[attr] = np.float64(bunch.bunchAttrDouble(attr))
47+
48+
for attr in bunch.bunchAttrIntNames():
49+
bunch_dict[attr] = np.int32(bunch.bunchAttrInt(attr))
50+
51+
sync_part = bunch.getSyncParticle()
52+
53+
bunch_dict |= {
54+
"sync_part_coords": np.array(sync_part.pVector()),
55+
"sync_part_kin_energy": np.float64(sync_part.kinEnergy()),
56+
"sync_part_momentum": np.float64(sync_part.momentum()),
57+
"sync_part_beta": np.float64(sync_part.beta()),
58+
"sync_part_gamma": np.float64(sync_part.gamma()),
59+
"sync_part_time": np.float64(sync_part.time()),
60+
}
61+
62+
for i in range(n_particles):
63+
bunch_dict["x"].append(bunch.x(i))
64+
bunch_dict["xp"].append(bunch.xp(i))
65+
bunch_dict["y"].append(bunch.y(i))
66+
bunch_dict["yp"].append(bunch.yp(i))
67+
bunch_dict["z"].append(bunch.z(i))
68+
bunch_dict["dE"].append(bunch.dE(i))
69+
70+
mpi_tag = 42 # not sure this is necessary; seems like it can be any integer
71+
for cpu_idx in range(1, mpi_size):
72+
for i in range(n_particles):
73+
if mpi_rank == cpu_idx:
74+
coord_arr = (
75+
bunch.x(i),
76+
bunch.xp(i),
77+
bunch.y(i),
78+
bunch.yp(i),
79+
bunch.z(i),
80+
bunch.dE(i),
81+
)
82+
orbit_mpi.MPI_Send(
83+
coord_arr, orbit_mpi.mpi_datatype.MPI_DOUBLE, 0, mpi_tag, mpi_comm
84+
)
85+
elif mpi_rank == 0:
86+
coord_arr = orbit_mpi.MPI_Recv(
87+
orbit_mpi.mpi_datatype.MPI_DOUBLE, cpu_idx, mpi_tag, mpi_comm
88+
)
89+
bunch_dict["x"].append(coord_arr[0])
90+
bunch_dict["xp"].append(coord_arr[1])
91+
bunch_dict["y"].append(coord_arr[2])
92+
bunch_dict["yp"].append(coord_arr[3])
93+
bunch_dict["z"].append(coord_arr[4])
94+
bunch_dict["dE"].append(coord_arr[5])
95+
96+
if mpi_rank == 0:
97+
for k, v in bunch_dict.items():
98+
if isinstance(v, list):
99+
bunch_dict[k] = np.array(v, dtype=np.float64)
100+
return bunch_dict

py/orbit/bunch_utils/meson.build

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
py_sources = files([
55
'__init__.py',
6-
'particleidnumber.py'
6+
'particleidnumber.py',
7+
'collect_bunch.py'
78
])
89

910
python.install_sources(
1011
py_sources,
1112
subdir: 'orbit/bunch_utils',
1213
# pure: true,
13-
)
14+
)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from orbit.core.bunch import Bunch
2+
from orbit.bunch_generators import GaussDist3D
3+
from orbit.bunch_utils import collect_bunch
4+
5+
from pytest import fixture
6+
7+
8+
@fixture
9+
def bunch():
10+
bunch = Bunch()
11+
bunch.mass(0.939294)
12+
bunch.charge(-1.0)
13+
bunch.getSyncParticle().kinEnergy(0.0025)
14+
gauss_dist = GaussDist3D()
15+
for i in range(10):
16+
x, xp, y, yp, z, dE = gauss_dist.getCoordinates()
17+
bunch.addParticle(x, xp, y, yp, z, dE)
18+
bunch.macroSize(10)
19+
return bunch
20+
21+
22+
def test_collect_bunch(bunch):
23+
d = collect_bunch(bunch)
24+
25+
n_particles = bunch.getSize()
26+
27+
expected_keys = {
28+
"x",
29+
"xp",
30+
"y",
31+
"yp",
32+
"z",
33+
"dE",
34+
"charge",
35+
"classical_radius",
36+
"mass",
37+
"macro_size",
38+
"sync_part_coords",
39+
"sync_part_kin_energy",
40+
"sync_part_momentum",
41+
"sync_part_beta",
42+
"sync_part_gamma",
43+
"sync_part_time",
44+
}
45+
46+
x, xp, y, yp, z, dE = [], [], [], [], [], []
47+
for i in range(n_particles):
48+
x.append(bunch.x(i))
49+
xp.append(bunch.px(i))
50+
y.append(bunch.y(i))
51+
yp.append(bunch.py(i))
52+
z.append(bunch.z(i))
53+
dE.append(bunch.dE(i))
54+
55+
assert set(d.keys()) == expected_keys
56+
assert (d["x"] == x).all()
57+
assert (d["xp"] == xp).all()
58+
assert (d["y"] == y).all()
59+
assert (d["yp"] == yp).all()
60+
assert (d["z"] == z).all()
61+
assert (d["dE"] == dE).all()
62+
63+
assert d["charge"] == bunch.bunchAttrDouble("charge")
64+
assert d["classical_radius"] == bunch.bunchAttrDouble("classical_radius")
65+
assert d["mass"] == bunch.bunchAttrDouble("mass")
66+
assert d["macro_size"] == bunch.bunchAttrDouble("macro_size")
67+
68+
sync_part = bunch.getSyncParticle()
69+
assert (d["sync_part_coords"] == sync_part.pVector()).all()
70+
assert d["sync_part_kin_energy"] == sync_part.kinEnergy()
71+
assert d["sync_part_momentum"] == sync_part.momentum()
72+
assert d["sync_part_beta"] == sync_part.beta()
73+
assert d["sync_part_gamma"] == sync_part.gamma()
74+
assert d["sync_part_time"] == sync_part.time()
75+
76+
77+
def test_collect_empty_bunch():
78+
bunch = Bunch()
79+
d = collect_bunch(bunch)
80+
assert len(d) == 0
81+
82+
83+
def test_collect_arbitrary_bunch_attr(bunch):
84+
bunch.bunchAttrDouble("arbitrary_dbl_attr", 42.0)
85+
bunch.bunchAttrInt("arbitrary_int_attr", 42)
86+
87+
d = collect_bunch(bunch)
88+
89+
assert "arbitrary_dbl_attr" in d
90+
assert d["arbitrary_dbl_attr"] == 42.0
91+
assert "arbitrary_int_attr" in d
92+
assert d["arbitrary_int_attr"] == 42

0 commit comments

Comments
 (0)