diff --git a/orbit_tools/bunch.py b/orbit_tools/bunch.py index e238710..a80a49b 100644 --- a/orbit_tools/bunch.py +++ b/orbit_tools/bunch.py @@ -18,6 +18,14 @@ from .cov import normalization_matrix +def set_particle_macrosizes(bunch: Bunch, macrosizes: Iterable) -> Bunch: + bunch.addPartAttr("macrosize") # sets macrosize=0 for all particles + attribute_array_index = 0 + for index in range(bunch.getSize()): + bunch.partAttrValue("macrosize", index, attribute_array_index, macrosizes[index]) + return bunch + + def get_part_coords(bunch: Bunch, index: int) -> list[float]: x = bunch.x(index) y = bunch.y(index) diff --git a/tests/test_bunch.py b/tests/test_bunch.py index 8b0ba1d..72c4d45 100644 --- a/tests/test_bunch.py +++ b/tests/test_bunch.py @@ -126,3 +126,12 @@ def sample(): def test_get_bunch_twiss_containers(): bunch = make_gaussian_bunch() twiss_x, twiss_y, twiss_z = ot.bunch.get_bunch_twiss_containers(bunch) + + +def test_set_particle_macrosizes(): + bunch = make_gaussian_bunch() + macrosizes = list(range(bunch.getSize())) + bunch = ot.bunch.set_particle_macrosizes(bunch, macrosizes) + for i, macrosize in enumerate(macrosizes): + assert macrosize == bunch.partAttrValue("macrosize", i, 0) +