Skip to content

Commit 16802d8

Browse files
authored
pyAMReX: Bind Particle Iterators (#87)
* pyAMReX: Bind Particle Iterators * First Tests
1 parent 7eb82a9 commit 16802d8

File tree

4 files changed

+213
-27
lines changed

4 files changed

+213
-27
lines changed

src/Base/Iterator.H

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/* Copyright 2021-2022 The AMReX Community
2+
*
3+
* Authors: Axel Huebl
4+
* License: BSD-3-Clause-LBNL
5+
*/
6+
#include <pybind11/pybind11.h>
7+
#include <pybind11/stl.h>
8+
9+
#include <AMReX_Config.H>
10+
#include <AMReX_BoxArray.H>
11+
#include <AMReX_DistributionMapping.H>
12+
#include <AMReX_FArrayBox.H>
13+
#include <AMReX_FabArray.H>
14+
#include <AMReX_FabArrayBase.H>
15+
#include <AMReX_MultiFab.H>
16+
17+
#include <memory>
18+
#include <string>
19+
20+
namespace py = pybind11;
21+
using namespace amrex;
22+
23+
namespace pyAMReX
24+
{
25+
/** This is a helper function for the C++ equivalent of void operator++()
26+
*
27+
* In Python, iterators always are called with __next__, even for the
28+
* first access. This means we need to handle the first iterator element
29+
* explicitly, otherwise we will jump directly to the 2nd element. We do
30+
* this the same way as pybind11 does this, via a little state:
31+
* https://github.com/AMReX-Codes/pyamrex/pull/50
32+
* https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282
33+
*
34+
* To avoid unnecessary (and expensive) copies, remember to only call this
35+
* helper always with py::return_value_policy::reference_internal!
36+
*
37+
*
38+
* @tparam T_Iterator This is usally MFIter or Par(Const)Iter or derived classes
39+
* @param it the current iterator
40+
* @return the updated iterator
41+
*/
42+
template< typename T_Iterator >
43+
T_Iterator &
44+
iterator_next( T_Iterator & it )
45+
{
46+
py::object self = py::cast(it);
47+
if (!py::hasattr(self, "first_or_done"))
48+
self.attr("first_or_done") = true;
49+
50+
bool first_or_done = self.attr("first_or_done").cast<bool>();
51+
if (first_or_done) {
52+
first_or_done = false;
53+
self.attr("first_or_done") = first_or_done;
54+
}
55+
else
56+
++it;
57+
if( !it.isValid() )
58+
{
59+
first_or_done = true;
60+
it.Finalize();
61+
throw py::stop_iteration();
62+
}
63+
return it;
64+
}
65+
}

src/Base/MultiFab.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <pybind11/pybind11.h>
77
#include <pybind11/stl.h>
88

9+
#include "Base/Iterator.H"
10+
911
#include <AMReX_Config.H>
1012
#include <AMReX_BoxArray.H>
1113
#include <AMReX_DistributionMapping.H>
@@ -79,26 +81,7 @@ void init_MultiFab(py::module &m) {
7981

8082
// eq. to void operator++()
8183
.def("__next__",
82-
[](MFIter & mfi) -> MFIter & {
83-
py::object self = py::cast(mfi);
84-
if (!py::hasattr(self, "first_or_done"))
85-
self.attr("first_or_done") = true;
86-
87-
bool first_or_done = self.attr("first_or_done").cast<bool>();
88-
if (first_or_done) {
89-
first_or_done = false;
90-
self.attr("first_or_done") = first_or_done;
91-
}
92-
else
93-
++mfi;
94-
if( !mfi.isValid() )
95-
{
96-
first_or_done = true;
97-
mfi.Finalize();
98-
throw py::stop_iteration();
99-
}
100-
return mfi;
101-
},
84+
&pyAMReX::iterator_next<MFIter>,
10285
py::return_value_policy::reference_internal
10386
)
10487

src/Particle/ParticleContainer.cpp

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
#include <pybind11/pybind11.h>
77
#include <pybind11/stl.h>
88

9+
#include "Base/Iterator.H"
10+
911
#include <AMReX_BoxArray.H>
1012
#include <AMReX_IntVect.H>
13+
#include <AMReX_ParIter.H>
1114
#include <AMReX_Particles.H>
1215
#include <AMReX_ParticleContainer.H>
1316
#include <AMReX_ParticleTile.H>
@@ -19,9 +22,88 @@
1922
namespace py = pybind11;
2023
using namespace amrex;
2124

25+
26+
template <bool is_const, typename T_ParIterBase>
27+
void make_Base_Iterators (py::module &m)
28+
{
29+
using iterator_base = T_ParIterBase;
30+
using container = typename iterator_base::ContainerType;
31+
constexpr int NStructReal = container::NStructReal;
32+
constexpr int NStructInt = container::NStructInt;
33+
constexpr int NArrayReal = container::NArrayReal;
34+
constexpr int NArrayInt = container::NArrayInt;
35+
36+
std::string particle_it_base_name = std::string("ParIterBase_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt));
37+
if (is_const) particle_it_base_name = "Const" + particle_it_base_name;
38+
py::class_<iterator_base, MFIter>(m, particle_it_base_name.c_str())
39+
.def(py::init<container&, int>(),
40+
py::arg("particle_container"), py::arg("level"))
41+
.def(py::init<container&, int, MFItInfo&>(),
42+
py::arg("particle_container"), py::arg("level"), py::arg("info"))
43+
44+
.def("particle_tile", &iterator_base::GetParticleTile,
45+
py::return_value_policy::reference_internal)
46+
.def("aos", &iterator_base::GetArrayOfStructs,
47+
py::return_value_policy::reference_internal)
48+
.def("soa", &iterator_base::GetStructOfArrays,
49+
py::return_value_policy::reference_internal)
50+
51+
.def_property_readonly("num_particles", &iterator_base::numParticles)
52+
.def_property_readonly("num_real_particles", &iterator_base::numRealParticles)
53+
.def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles)
54+
.def_property_readonly("level", &iterator_base::GetLevel)
55+
.def_property_readonly("pair_index", &iterator_base::GetPairIndex)
56+
.def("geom", &iterator_base::Geom, py::arg("level"))
57+
58+
// eq. to void operator++()
59+
.def("__next__",
60+
&pyAMReX::iterator_next<iterator_base>,
61+
py::return_value_policy::reference_internal
62+
)
63+
.def("__iter__",
64+
[](iterator_base & it) -> iterator_base & {
65+
return it;
66+
},
67+
py::return_value_policy::reference_internal
68+
)
69+
;
70+
}
71+
72+
template <bool is_const, typename T_ParIter, template<class> class Allocator=DefaultAllocator>
73+
void make_Iterators (py::module &m)
74+
{
75+
using iterator = T_ParIter;
76+
using container = typename iterator::ContainerType;
77+
constexpr int NStructReal = container::NStructReal;
78+
constexpr int NStructInt = container::NStructInt;
79+
constexpr int NArrayReal = container::NArrayReal;
80+
constexpr int NArrayInt = container::NArrayInt;
81+
82+
using iterator_base = amrex::ParIterBase<is_const, NStructReal, NStructInt, NArrayReal, NArrayInt, Allocator>;
83+
make_Base_Iterators< is_const, iterator_base >(m);
84+
85+
auto particle_it_name = std::string("Par");
86+
if (is_const) particle_it_name += "Const";
87+
particle_it_name += std::string("Iter_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt));
88+
py::class_<iterator, iterator_base>(m, particle_it_name.c_str())
89+
.def("__repr__",
90+
[particle_it_name](iterator const & pti) {
91+
std::string r = "<amrex." + particle_it_name + " (";
92+
if( !pti.isValid() ) { r.append("in"); }
93+
r.append("valid)>");
94+
return r;
95+
}
96+
)
97+
.def(py::init<container&, int>(),
98+
py::arg("particle_container"), py::arg("level"))
99+
.def(py::init<container&, int, MFItInfo&>(),
100+
py::arg("particle_container"), py::arg("level"), py::arg("info"))
101+
;
102+
}
103+
22104
template <int T_NStructReal, int T_NStructInt=0, int T_NArrayReal=0, int T_NArrayInt=0,
23105
template<class> class Allocator=DefaultAllocator>
24-
void make_ParticleContainer(py::module &m)
106+
void make_ParticleContainer_and_Iterators (py::module &m)
25107
{
26108
using ParticleContainerType = ParticleContainer<
27109
T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt,
@@ -237,17 +319,21 @@ void make_ParticleContainer(py::module &m)
237319
// m_particles[lev][index].define(NumRuntimeRealComps(), NumRuntimeIntComps());
238320
// return ParticlesAt(lev, iter);
239321
// }
240-
241322
;
323+
324+
using iterator = amrex::ParIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>;
325+
make_Iterators< false, iterator, Allocator >(m);
326+
using const_iterator = amrex::ParConstIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>;
327+
make_Iterators< true, const_iterator, Allocator >(m);
242328
}
243329

244330

245331
void init_ParticleContainer(py::module& m) {
246332
// TODO: we might need to move all or most of the defines in here into a
247333
// test/example submodule, so they do not collide with downstream projects
248-
make_ParticleContainer< 1, 1, 2, 1> (m);
249-
make_ParticleContainer< 0, 0, 4, 0> (m); // HiPACE++ 22.07
250-
make_ParticleContainer< 0, 0, 5, 0> (m); // ImpactX 22.07
251-
make_ParticleContainer< 0, 0, 7, 0> (m);
252-
make_ParticleContainer< 0, 0, 37, 1> (m); // HiPACE++ 22.07
334+
make_ParticleContainer_and_Iterators< 1, 1, 2, 1> (m);
335+
make_ParticleContainer_and_Iterators< 0, 0, 4, 0> (m); // HiPACE++ 22.07
336+
make_ParticleContainer_and_Iterators< 0, 0, 5, 0> (m); // ImpactX 22.07
337+
make_ParticleContainer_and_Iterators< 0, 0, 7, 0> (m);
338+
make_ParticleContainer_and_Iterators< 0, 0, 37, 1> (m); // HiPACE++ 22.07
253339
}

tests/test_particleContainer.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,58 @@ def test_pc_init():
121121
assert pc.TotalNumberOfParticles() == pc.NumberOfParticlesAtLevel(0) == npart
122122
assert pc.OK()
123123

124+
print("Iterate particle boxes & set values")
125+
lvl = 0
126+
for pti in amrex.ParIter_1_1_2_1(pc, level=lvl):
127+
print("...")
128+
assert pti.num_particles == 1
129+
assert pti.num_real_particles == 1
130+
assert pti.num_neighbor_particles == 0
131+
assert pti.level == lvl
132+
print(pti.pair_index)
133+
print(pti.geom(level=lvl))
134+
135+
aos = pti.aos()
136+
aos_arr = np.array(aos, copy=False)
137+
aos_arr[0]["x"] = 0.30
138+
aos_arr[0]["y"] = 0.35
139+
aos_arr[0]["z"] = 0.40
140+
141+
# TODO: this seems to write into a copy of the data
142+
soa = pti.soa()
143+
real_arrays = soa.GetRealData()
144+
int_arrays = soa.GetIntData()
145+
real_arrays[0] = [0.55]
146+
real_arrays[1] = [0.22]
147+
int_arrays[0] = [2]
148+
149+
assert np.allclose(real_arrays[0], np.array([0.55]))
150+
assert np.allclose(real_arrays[1], np.array([0.22]))
151+
assert np.allclose(int_arrays[0], np.array([2]))
152+
153+
# read-only
154+
for pti in amrex.ParConstIter_1_1_2_1(pc, level=lvl):
155+
assert pti.num_particles == 1
156+
assert pti.num_real_particles == 1
157+
assert pti.num_neighbor_particles == 0
158+
assert pti.level == lvl
159+
160+
aos = pti.aos()
161+
aos_arr = np.array(aos, copy=False)
162+
assert aos[0].x == 0.30
163+
assert aos[0].y == 0.35
164+
assert aos[0].z == 0.40
165+
assert aos_arr[0]["z"] == 0.40
166+
167+
soa = pti.soa()
168+
real_arrays = soa.GetRealData()
169+
int_arrays = soa.GetIntData()
170+
print(real_arrays[0])
171+
# TODO: this does not work yet and is still the original data
172+
# assert np.allclose(real_arrays[0], np.array([0.55]))
173+
# assert np.allclose(real_arrays[1], np.array([0.22]))
174+
# assert np.allclose(int_arrays[0], np.array([2]))
175+
124176

125177
def test_particle_init(Npart, particle_container):
126178
pc = particle_container

0 commit comments

Comments
 (0)