Skip to content

Commit 24fd743

Browse files
authored
Particle Iterator Shortcuts (#497)
* Allow to iterate on particles of level N without a level loop. * Allow to access named pure SoA attributes directly on the particle iterator (`pti`).
1 parent a0fd36f commit 24fd743

File tree

7 files changed

+143
-18
lines changed

7 files changed

+143
-18
lines changed

docs/source/usage/compute.rst

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,31 @@ Here is the general structure for computing on particles:
9292

9393
.. tab-set::
9494

95-
.. tab-item:: Simple: Pandas (read-only)
95+
.. tab-item:: Simple
9696

9797
.. literalinclude:: ../../../tests/test_particleContainer.py
9898
:language: python3
9999
:dedent: 4
100-
:start-after: # Manual: Pure SoA Compute PC Pandas START
101-
:end-before: # Manual: Pure SoA Compute PC Pandas END
100+
:start-after: # Manual: Pure SoA Compute PC Simple pti START
101+
:end-before: # Manual: Pure SoA Compute PC Simple pti END
102102

103-
.. tab-item:: Detailed (read and write)
103+
.. tab-item:: Detailed
104104

105105
.. literalinclude:: ../../../tests/test_particleContainer.py
106106
:language: python3
107107
:dedent: 4
108108
:start-after: # Manual: Pure SoA Compute PC Detailed START
109109
:end-before: # Manual: Pure SoA Compute PC Detailed END
110110

111+
.. tab-item:: Pandas (read-only)
112+
113+
.. literalinclude:: ../../../tests/test_particleContainer.py
114+
:language: python3
115+
:dedent: 4
116+
:start-after: # Manual: Pure SoA Compute PC Pandas START
117+
:end-before: # Manual: Pure SoA Compute PC Pandas END
118+
119+
111120
.. tab-item:: Legacy (AoS + SoA) Layout
112121

113122
.. literalinclude:: ../../../tests/test_particleContainer.py

src/Particle/ParticleContainer.H

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,9 @@ void make_ParticleContainer_and_Iterators (py::module &m, std::string allocstr)
477477

478478
// simpler particle iterator loops: return types of this particle box
479479
py_pc
480-
.def_property_readonly_static("iterator", [](py::object /* pc */){ return py::type::of<iterator>(); },
480+
.def_property_readonly_static("Iterator", [](py::object /* pc */){ return py::type::of<iterator>(); },
481481
"amrex iterator for particle boxes")
482-
.def_property_readonly_static("const_iterator", [](py::object /* pc */){ return py::type::of<const_iterator>(); },
482+
.def_property_readonly_static("ConstIterator", [](py::object /* pc */){ return py::type::of<const_iterator>(); },
483483
"amrex constant iterator for particle boxes (read-only)")
484484
;
485485
}

src/Particle/StructOfArrays.H

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ void make_StructOfArrays(py::module &m, std::string allocstr)
5555
py::return_value_policy::reference_internal,
5656
py::arg("index"),
5757
"Get access to a particle Real component Array (compile-time and runtime component)")
58+
.def("get_real_data", py::overload_cast<std::string const &>(&SOAType::GetRealData),
59+
py::return_value_policy::reference_internal,
60+
py::arg("name"),
61+
"Get access to a particle Real component Array (compile-time and runtime component)")
62+
.def("get_int_data", py::overload_cast<std::string const &>(&SOAType::GetIntData),
63+
py::return_value_policy::reference_internal,
64+
py::arg("name"),
65+
"Get access to a particle Real component Array (compile-time and runtime component)")
5866

5967
// names
6068
.def_property_readonly("real_names",

src/amrex/extensions/Iterator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,18 @@ def next(self):
3838
raise StopIteration
3939

4040
return self
41+
42+
43+
def getitem(self, name):
44+
"""Access (read/write) particle vectors."""
45+
if not self.is_soa_particle:
46+
raise ValueError("Only pure SoA particle containers support pti.__get__")
47+
48+
if name == "idcpu":
49+
return self.soa().get_idcpu_data().to_xp(copy=False)
50+
elif name in self.soa().real_names:
51+
return self.soa().get_real_data(name).to_xp(copy=False)
52+
elif name in self.soa().int_names:
53+
return self.soa().get_int_data(name).to_xp(copy=False)
54+
else:
55+
raise KeyError(f"Unknown particle attribute name: {name}")

src/amrex/extensions/ParticleContainer.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,65 @@
66
License: BSD-3-Clause-LBNL
77
"""
88

9-
from .Iterator import next
9+
import warnings
10+
11+
from .Iterator import getitem, next
12+
13+
14+
def iterator(self, *args, level=None):
15+
"""Create an iterator over all particle tiles
16+
17+
Parameters
18+
----------
19+
self : amrex.ParticleContainer_*
20+
A ParticleContainer class in pyAMReX
21+
args : deprecated positional argument
22+
level : int | str, optional
23+
The MR level. Allowed values are [0:self.finest_level+1) and "all".
24+
If there is more than one MR level, the argument is required.
25+
26+
Returns
27+
-------
28+
Iterator over all particle tiles at the specified level.
29+
30+
Examples
31+
--------
32+
>>> pc.iterator(level="all")
33+
>>> pc.iterator(level=0) # only particles on the the coarsest MR level
34+
"""
35+
# Warn if a second positional argument is provided (ignored argument)
36+
if len(args) > 0:
37+
if len(args) == 1 and isinstance(args[0], int) and level is None:
38+
level = args[0]
39+
else:
40+
warnings.warn(
41+
"The second positional argument to iterator() is deprecated and ignored. "
42+
"Please update your code to use iterator(self, level=...) instead.",
43+
DeprecationWarning,
44+
stacklevel=2,
45+
)
46+
47+
has_mr = self.finest_level > 0
48+
49+
if level is None:
50+
if has_mr:
51+
raise ValueError(
52+
"level must be specified for multi-level ParticleContainers"
53+
)
54+
else:
55+
level = 0
56+
57+
if level == "all":
58+
raise ValueError("level='all' is not yet supported for ParticleContainers")
59+
# TODO: This does not work
60+
# for lvl in range(self.finest_level + 1):
61+
# yield self.Iterator(self, level=lvl)
62+
elif isinstance(level, int) and level >= 0:
63+
return self.Iterator(self, level=level)
64+
else:
65+
raise ValueError(
66+
f"level must be an integer in [0:{self.finest_level + 1}) or 'all', but got: {level}"
67+
)
1068

1169

1270
def pc_to_df(self, local=True, comm=None, root_rank=0):
@@ -45,7 +103,7 @@ def pc_to_df(self, local=True, comm=None, root_rank=0):
45103
# local DataFrame(s)
46104
dfs_local = []
47105
for lvl in range(self.finest_level + 1):
48-
for pti in self.const_iterator(self, level=lvl):
106+
for pti in self.const_iterator(level=lvl):
49107
if pti.size == 0:
50108
continue
51109

@@ -130,6 +188,7 @@ def register_ParticleContainer_extension(amr):
130188
):
131189
ParIter_type.__next__ = next
132190
ParIter_type.__iter__ = lambda self: self
191+
ParIter_type.__getitem__ = getitem
133192

134193
# register member functions for every ParticleContainer_* type
135194
for _, ParticleContainer_type in inspect.getmembers(
@@ -138,4 +197,8 @@ def register_ParticleContainer_extension(amr):
138197
and member.__module__ == amr.__name__
139198
and member.__name__.startswith("ParticleContainer_"),
140199
):
200+
ParticleContainer_type.iterator = iterator
201+
ParticleContainer_type.const_iterator = (
202+
iterator # TODO: simplified, code duplication
203+
)
141204
ParticleContainer_type.to_df = pc_to_df

tests/test_particleContainer.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def particle_container(Npart, std_geometry, distmap, boxarr, std_real_box):
6363

6464
# assign some values to runtime components
6565
for lvl in range(pc.finest_level + 1):
66-
for pti in pc.iterator(pc, level=lvl):
66+
for pti in pc.iterator(level=lvl):
6767
soa = pti.soa()
6868
soa.get_real_data(2).assign(1.2345)
6969
soa.get_int_data(1).assign(42)
@@ -97,7 +97,7 @@ def soa_particle_container(Npart, std_geometry, distmap, boxarr, std_real_box):
9797

9898
# assign some values to runtime components
9999
for lvl in range(pc.finest_level + 1):
100-
for pti in pc.iterator(pc, level=lvl):
100+
for pti in pc.iterator(level=lvl):
101101
soa = pti.soa()
102102
soa.get_real_data(8).assign(1.2345)
103103
soa.get_int_data(0).assign(42)
@@ -212,7 +212,7 @@ def test_pc_init():
212212
# lvl = 0
213213
for lvl in range(pc.finest_level + 1):
214214
print(f"at level {lvl}:")
215-
for pti in pc.iterator(pc, level=lvl):
215+
for pti in pc.iterator(level=lvl):
216216
print("...")
217217
assert pti.num_particles == 1
218218
assert pti.num_real_particles == 1
@@ -243,7 +243,7 @@ def test_pc_init():
243243

244244
# read-only
245245
for lvl in range(pc.finest_level + 1):
246-
for pti in pc.const_iterator(pc, level=lvl):
246+
for pti in pc.const_iterator(level=lvl):
247247
assert pti.num_particles == 1
248248
assert pti.num_real_particles == 1
249249
assert pti.num_neighbor_particles == 0
@@ -383,17 +383,17 @@ class Config:
383383
# iterate over mesh-refinement levels
384384
for lvl in range(pc.finest_level + 1):
385385
# loop local tiles of particles
386-
for pti in pc.iterator(pc, level=lvl):
386+
for pti in pc.iterator(level=lvl):
387387
# compile-time and runtime attributes
388388
soa = pti.soa().to_xp()
389389

390-
# print all particle ids in the chunk
390+
# print all particle ids in the tile
391391
print("idcpu =", soa.idcpu)
392392

393393
x = soa.real["x"]
394394
y = soa.real["y"]
395395

396-
# write to all particles in the chunk
396+
# write to all particles in the tile
397397
# note: careful, if you change particle positions, you might need to
398398
# redistribute particles before continuing the simulation step
399399
soa.real["x"][:] = 0.30
@@ -410,6 +410,36 @@ class Config:
410410
soa_int[:] = 12
411411
# Manual: Pure SoA Compute PC Detailed END
412412

413+
# Manual: Pure SoA Compute PC Simple pti START
414+
# code-specific getter function, e.g.:
415+
# pc = sim.get_particles()
416+
# Config = sim.extension.Config
417+
418+
# iterate over particles on level 0
419+
for pti in pc.iterator(level=0):
420+
# print all particle ids in the tile
421+
print("idcpu =", pti["idcpu"])
422+
423+
x = pti["x"] # this is automatically a cupy or numpy
424+
y = pti["y"] # array, depending on Config.have_gpu
425+
426+
# write to all particles in the chunk
427+
# note: careful, if you change particle positions, you might need to
428+
# redistribute particles before continuing the simulation step
429+
pti["x"][:] = 0.30
430+
pti["y"][:] = 0.35
431+
pti["z"][:] = 0.40
432+
433+
pti["a"][:] = x[:] ** 2
434+
pti["b"][:] = x[:] + y[:]
435+
pti["c"][:] = 0.50
436+
# ...
437+
438+
# int attributes
439+
pti["i1"][:] = 12
440+
pti["i2"][:] = 13
441+
# Manual: Pure SoA Compute PC Simple pti END
442+
413443

414444
def test_pc_numpy(particle_container, Npart):
415445
"""Used in docs/source/usage/compute.rst"""
@@ -427,7 +457,7 @@ class Config:
427457
# iterate over mesh-refinement levels
428458
for lvl in range(pc.finest_level + 1):
429459
# loop local tiles of particles
430-
for pti in pc.iterator(pc, level=lvl):
460+
for pti in pc.iterator(level=lvl):
431461
# default layout: AoS with positions and idcpu
432462
# note: not part of the new PureSoA particle container layout
433463
aos = (

tests/test_plotfileparticledata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def particle_container(Rpart, std_geometry, distmap, boxarr, std_real_box):
6565
particles_tile_ct = 0
6666
# assign some values to runtime components
6767
for lvl in range(pc.finest_level + 1):
68-
for pti in pc.iterator(pc, level=lvl):
68+
for pti in pc.iterator(level=lvl):
6969
aos = pti.aos()
7070
aos_numpy = aos.to_numpy(copy=False)
7171
for i, p in enumerate(aos_numpy):
@@ -81,7 +81,7 @@ def check_particles_container(pc, reference_particles):
8181
Checks the contents of `pc` against `reference_particles`
8282
"""
8383
for lvl in range(pc.finest_level + 1):
84-
for i, pti in enumerate(pc.iterator(pc, level=lvl)):
84+
for i, pti in enumerate(pc.iterator(level=lvl)):
8585
aos = pti.aos()
8686
for p in aos.to_numpy(copy=True):
8787
ref = reference_particles[p["idata_0"]]

0 commit comments

Comments
 (0)