|
6 | 6 | #include <pybind11/pybind11.h> |
7 | 7 | #include <pybind11/stl.h> |
8 | 8 |
|
| 9 | +#include "Base/Iterator.H" |
| 10 | + |
9 | 11 | #include <AMReX_BoxArray.H> |
10 | 12 | #include <AMReX_IntVect.H> |
| 13 | +#include <AMReX_ParIter.H> |
11 | 14 | #include <AMReX_Particles.H> |
12 | 15 | #include <AMReX_ParticleContainer.H> |
13 | 16 | #include <AMReX_ParticleTile.H> |
|
19 | 22 | namespace py = pybind11; |
20 | 23 | using namespace amrex; |
21 | 24 |
|
| 25 | + |
| 26 | +template <bool is_const, typename T_ParIterBase> |
| 27 | +void make_Base_Iterators (py::module &m) |
| 28 | +{ |
| 29 | + using iterator_base = T_ParIterBase; |
| 30 | + using container = typename iterator_base::ContainerType; |
| 31 | + constexpr int NStructReal = container::NStructReal; |
| 32 | + constexpr int NStructInt = container::NStructInt; |
| 33 | + constexpr int NArrayReal = container::NArrayReal; |
| 34 | + constexpr int NArrayInt = container::NArrayInt; |
| 35 | + |
| 36 | + std::string particle_it_base_name = std::string("ParIterBase_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt)); |
| 37 | + if (is_const) particle_it_base_name = "Const" + particle_it_base_name; |
| 38 | + py::class_<iterator_base, MFIter>(m, particle_it_base_name.c_str()) |
| 39 | + .def(py::init<container&, int>(), |
| 40 | + py::arg("particle_container"), py::arg("level")) |
| 41 | + .def(py::init<container&, int, MFItInfo&>(), |
| 42 | + py::arg("particle_container"), py::arg("level"), py::arg("info")) |
| 43 | + |
| 44 | + .def("particle_tile", &iterator_base::GetParticleTile, |
| 45 | + py::return_value_policy::reference_internal) |
| 46 | + .def("aos", &iterator_base::GetArrayOfStructs, |
| 47 | + py::return_value_policy::reference_internal) |
| 48 | + .def("soa", &iterator_base::GetStructOfArrays, |
| 49 | + py::return_value_policy::reference_internal) |
| 50 | + |
| 51 | + .def_property_readonly("num_particles", &iterator_base::numParticles) |
| 52 | + .def_property_readonly("num_real_particles", &iterator_base::numRealParticles) |
| 53 | + .def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles) |
| 54 | + .def_property_readonly("level", &iterator_base::GetLevel) |
| 55 | + .def_property_readonly("pair_index", &iterator_base::GetPairIndex) |
| 56 | + .def("geom", &iterator_base::Geom, py::arg("level")) |
| 57 | + |
| 58 | + // eq. to void operator++() |
| 59 | + .def("__next__", |
| 60 | + &pyAMReX::iterator_next<iterator_base>, |
| 61 | + py::return_value_policy::reference_internal |
| 62 | + ) |
| 63 | + .def("__iter__", |
| 64 | + [](iterator_base & it) -> iterator_base & { |
| 65 | + return it; |
| 66 | + }, |
| 67 | + py::return_value_policy::reference_internal |
| 68 | + ) |
| 69 | + ; |
| 70 | +} |
| 71 | + |
| 72 | +template <bool is_const, typename T_ParIter, template<class> class Allocator=DefaultAllocator> |
| 73 | +void make_Iterators (py::module &m) |
| 74 | +{ |
| 75 | + using iterator = T_ParIter; |
| 76 | + using container = typename iterator::ContainerType; |
| 77 | + constexpr int NStructReal = container::NStructReal; |
| 78 | + constexpr int NStructInt = container::NStructInt; |
| 79 | + constexpr int NArrayReal = container::NArrayReal; |
| 80 | + constexpr int NArrayInt = container::NArrayInt; |
| 81 | + |
| 82 | + using iterator_base = amrex::ParIterBase<is_const, NStructReal, NStructInt, NArrayReal, NArrayInt, Allocator>; |
| 83 | + make_Base_Iterators< is_const, iterator_base >(m); |
| 84 | + |
| 85 | + auto particle_it_name = std::string("Par"); |
| 86 | + if (is_const) particle_it_name += "Const"; |
| 87 | + particle_it_name += std::string("Iter_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt)); |
| 88 | + py::class_<iterator, iterator_base>(m, particle_it_name.c_str()) |
| 89 | + .def("__repr__", |
| 90 | + [particle_it_name](iterator const & pti) { |
| 91 | + std::string r = "<amrex." + particle_it_name + " ("; |
| 92 | + if( !pti.isValid() ) { r.append("in"); } |
| 93 | + r.append("valid)>"); |
| 94 | + return r; |
| 95 | + } |
| 96 | + ) |
| 97 | + .def(py::init<container&, int>(), |
| 98 | + py::arg("particle_container"), py::arg("level")) |
| 99 | + .def(py::init<container&, int, MFItInfo&>(), |
| 100 | + py::arg("particle_container"), py::arg("level"), py::arg("info")) |
| 101 | + ; |
| 102 | +} |
| 103 | + |
22 | 104 | template <int T_NStructReal, int T_NStructInt=0, int T_NArrayReal=0, int T_NArrayInt=0, |
23 | 105 | template<class> class Allocator=DefaultAllocator> |
24 | | -void make_ParticleContainer(py::module &m) |
| 106 | +void make_ParticleContainer_and_Iterators (py::module &m) |
25 | 107 | { |
26 | 108 | using ParticleContainerType = ParticleContainer< |
27 | 109 | T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, |
@@ -237,17 +319,21 @@ void make_ParticleContainer(py::module &m) |
237 | 319 | // m_particles[lev][index].define(NumRuntimeRealComps(), NumRuntimeIntComps()); |
238 | 320 | // return ParticlesAt(lev, iter); |
239 | 321 | // } |
240 | | - |
241 | 322 | ; |
| 323 | + |
| 324 | + using iterator = amrex::ParIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>; |
| 325 | + make_Iterators< false, iterator, Allocator >(m); |
| 326 | + using const_iterator = amrex::ParConstIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>; |
| 327 | + make_Iterators< true, const_iterator, Allocator >(m); |
242 | 328 | } |
243 | 329 |
|
244 | 330 |
|
245 | 331 | void init_ParticleContainer(py::module& m) { |
246 | 332 | // TODO: we might need to move all or most of the defines in here into a |
247 | 333 | // test/example submodule, so they do not collide with downstream projects |
248 | | - make_ParticleContainer< 1, 1, 2, 1> (m); |
249 | | - make_ParticleContainer< 0, 0, 4, 0> (m); // HiPACE++ 22.07 |
250 | | - make_ParticleContainer< 0, 0, 5, 0> (m); // ImpactX 22.07 |
251 | | - make_ParticleContainer< 0, 0, 7, 0> (m); |
252 | | - make_ParticleContainer< 0, 0, 37, 1> (m); // HiPACE++ 22.07 |
| 334 | + make_ParticleContainer_and_Iterators< 1, 1, 2, 1> (m); |
| 335 | + make_ParticleContainer_and_Iterators< 0, 0, 4, 0> (m); // HiPACE++ 22.07 |
| 336 | + make_ParticleContainer_and_Iterators< 0, 0, 5, 0> (m); // ImpactX 22.07 |
| 337 | + make_ParticleContainer_and_Iterators< 0, 0, 7, 0> (m); |
| 338 | + make_ParticleContainer_and_Iterators< 0, 0, 37, 1> (m); // HiPACE++ 22.07 |
253 | 339 | } |
0 commit comments