Skip to content

Commit c3ac96e

Browse files
authored
Python: WarpXParticleContainer.add_particles (#6350)
Move from wrapper to fundamental class.
1 parent 31eda89 commit c3ac96e

File tree

3 files changed

+187
-116
lines changed

3 files changed

+187
-116
lines changed

Python/pywarpx/_libwarpx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,13 @@ def load_library(self):
134134
from .extensions.MultiFabRegister import (
135135
register_warpx_MultiFabRegister_extension,
136136
)
137+
from .extensions.WarpXParticleContainer import (
138+
register_warpx_WarpXParticleContainer_extension,
139+
)
137140

138141
register_warpx_MultiFab_extension(self.amr)
139142
register_warpx_MultiFabRegister_extension(self.libwarpx_so)
143+
register_warpx_WarpXParticleContainer_extension(self.libwarpx_so)
140144

141145
def amrex_init(self, argv, mpi_comm=None):
142146
if mpi_comm is None: # or MPI is None:
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
This file is part of WarpX
3+
4+
Copyright 2024 WarpX community
5+
Authors: Axel Huebl, David Grote
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def add_particles(
11+
self,
12+
x=None,
13+
y=None,
14+
z=None,
15+
ux=None,
16+
uy=None,
17+
uz=None,
18+
w=None,
19+
unique_particles=True,
20+
**kwargs,
21+
):
22+
"""
23+
A function for adding particles to the WarpX simulation.
24+
25+
Parameters
26+
----------
27+
28+
species_name : str
29+
The type of species for which particles will be added
30+
31+
x, y, z : arrays or scalars
32+
The particle positions (m) (default = 0.)
33+
34+
ux, uy, uz : arrays or scalars
35+
The particle proper velocities (m/s) (default = 0.)
36+
37+
w : array or scalars
38+
Particle weights (default = 0.)
39+
40+
unique_particles : bool
41+
True means the added particles are duplicated by each process;
42+
False means the number of added particles is independent of
43+
the number of processes (default = True)
44+
45+
kwargs : dict
46+
Containing an entry for all the extra particle attribute arrays. If
47+
an attribute is not given it will be set to 0.
48+
"""
49+
import numpy as np
50+
51+
from .._libwarpx import libwarpx
52+
53+
# --- Get length of arrays, set to one for scalars
54+
lenx = np.size(x)
55+
leny = np.size(y)
56+
lenz = np.size(z)
57+
lenux = np.size(ux)
58+
lenuy = np.size(uy)
59+
lenuz = np.size(uz)
60+
lenw = np.size(w)
61+
62+
# --- Find the max length of the parameters supplied
63+
maxlen = 0
64+
if x is not None:
65+
maxlen = max(maxlen, lenx)
66+
if y is not None:
67+
maxlen = max(maxlen, leny)
68+
if z is not None:
69+
maxlen = max(maxlen, lenz)
70+
if ux is not None:
71+
maxlen = max(maxlen, lenux)
72+
if uy is not None:
73+
maxlen = max(maxlen, lenuy)
74+
if uz is not None:
75+
maxlen = max(maxlen, lenuz)
76+
if w is not None:
77+
maxlen = max(maxlen, lenw)
78+
79+
# --- Make sure that the lengths of the input parameters are consistent
80+
assert x is None or lenx == maxlen or lenx == 1, (
81+
"Length of x doesn't match len of others"
82+
)
83+
assert y is None or leny == maxlen or leny == 1, (
84+
"Length of y doesn't match len of others"
85+
)
86+
assert z is None or lenz == maxlen or lenz == 1, (
87+
"Length of z doesn't match len of others"
88+
)
89+
assert ux is None or lenux == maxlen or lenux == 1, (
90+
"Length of ux doesn't match len of others"
91+
)
92+
assert uy is None or lenuy == maxlen or lenuy == 1, (
93+
"Length of uy doesn't match len of others"
94+
)
95+
assert uz is None or lenuz == maxlen or lenuz == 1, (
96+
"Length of uz doesn't match len of others"
97+
)
98+
assert w is None or lenw == maxlen or lenw == 1, (
99+
"Length of w doesn't match len of others"
100+
)
101+
for key, val in kwargs.items():
102+
assert np.size(val) == 1 or len(val) == maxlen, (
103+
f"Length of {key} doesn't match len of others"
104+
)
105+
106+
# --- Broadcast scalars into appropriate length arrays
107+
# --- If the parameter was not supplied, use the default value
108+
if lenx == 1:
109+
x = np.full(maxlen, (x or 0.0))
110+
if leny == 1:
111+
y = np.full(maxlen, (y or 0.0))
112+
if lenz == 1:
113+
z = np.full(maxlen, (z or 0.0))
114+
if lenux == 1:
115+
ux = np.full(maxlen, (ux or 0.0))
116+
if lenuy == 1:
117+
uy = np.full(maxlen, (uy or 0.0))
118+
if lenuz == 1:
119+
uz = np.full(maxlen, (uz or 0.0))
120+
if lenw == 1:
121+
w = np.full(maxlen, (w or 0.0))
122+
for key, val in kwargs.items():
123+
if np.size(val) == 1:
124+
kwargs[key] = np.full(maxlen, val)
125+
126+
# --- The number of built in attributes
127+
# --- The positions
128+
built_in_attrs = libwarpx.dim
129+
# --- The three velocities
130+
built_in_attrs += 3
131+
if libwarpx.geometry_dim == "rz":
132+
# --- With RZ, there is also theta
133+
built_in_attrs += 1
134+
135+
# --- The number of extra attributes (including the weight)
136+
nattr = self.num_real_comps - built_in_attrs
137+
attr = np.zeros((maxlen, nattr))
138+
attr[:, 0] = w
139+
140+
# --- Note that the velocities are handled separately and not included in attr
141+
# --- (even though they are stored as attributes in the C++)
142+
for key, vals in kwargs.items():
143+
attr[:, self.get_real_comp_index(key) - built_in_attrs] = vals
144+
145+
nattr_int = 0
146+
attr_int = np.empty([0], dtype=np.int32)
147+
148+
# TODO: expose ParticleReal through pyAMReX
149+
# and cast arrays to the correct types, before calling add_n_particles
150+
# x = x.astype(self._numpy_particlereal_dtype, copy=False)
151+
# y = y.astype(self._numpy_particlereal_dtype, copy=False)
152+
# z = z.astype(self._numpy_particlereal_dtype, copy=False)
153+
# ux = ux.astype(self._numpy_particlereal_dtype, copy=False)
154+
# uy = uy.astype(self._numpy_particlereal_dtype, copy=False)
155+
# uz = uz.astype(self._numpy_particlereal_dtype, copy=False)
156+
157+
self.add_n_particles(
158+
0,
159+
x.size,
160+
x,
161+
y,
162+
z,
163+
ux,
164+
uy,
165+
uz,
166+
nattr,
167+
attr,
168+
nattr_int,
169+
attr_int,
170+
unique_particles,
171+
)
172+
173+
174+
def register_warpx_WarpXParticleContainer_extension(libwarpx_so):
175+
"""WarpXParticleContainer helper methods"""
176+
177+
# Register the overload dispatcher
178+
# note: this currently overwrites the pyAMReX signature
179+
# add_particles(other: ParticleContainer, local: bool = False)
180+
libwarpx_so.WarpXParticleContainer.add_particles = add_particles

Python/pywarpx/particle_containers.py

Lines changed: 3 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#
77
# License: BSD-3-Clause-LBNL
88

9-
import numpy as np
109

1110
from ._libwarpx import libwarpx
1211
from .LoadThirdParty import load_cupy
@@ -57,7 +56,6 @@ def add_particles(
5756
5857
Parameters
5958
----------
60-
6159
species_name : str
6260
The type of species for which particles will be added
6361
@@ -79,127 +77,16 @@ def add_particles(
7977
Containing an entry for all the extra particle attribute arrays. If
8078
an attribute is not given it will be set to 0.
8179
"""
82-
83-
# --- Get length of arrays, set to one for scalars
84-
lenx = np.size(x)
85-
leny = np.size(y)
86-
lenz = np.size(z)
87-
lenux = np.size(ux)
88-
lenuy = np.size(uy)
89-
lenuz = np.size(uz)
90-
lenw = np.size(w)
91-
92-
# --- Find the max length of the parameters supplied
93-
maxlen = 0
94-
if x is not None:
95-
maxlen = max(maxlen, lenx)
96-
if y is not None:
97-
maxlen = max(maxlen, leny)
98-
if z is not None:
99-
maxlen = max(maxlen, lenz)
100-
if ux is not None:
101-
maxlen = max(maxlen, lenux)
102-
if uy is not None:
103-
maxlen = max(maxlen, lenuy)
104-
if uz is not None:
105-
maxlen = max(maxlen, lenuz)
106-
if w is not None:
107-
maxlen = max(maxlen, lenw)
108-
109-
# --- Make sure that the lengths of the input parameters are consistent
110-
assert x is None or lenx == maxlen or lenx == 1, (
111-
"Length of x doesn't match len of others"
112-
)
113-
assert y is None or leny == maxlen or leny == 1, (
114-
"Length of y doesn't match len of others"
115-
)
116-
assert z is None or lenz == maxlen or lenz == 1, (
117-
"Length of z doesn't match len of others"
118-
)
119-
assert ux is None or lenux == maxlen or lenux == 1, (
120-
"Length of ux doesn't match len of others"
121-
)
122-
assert uy is None or lenuy == maxlen or lenuy == 1, (
123-
"Length of uy doesn't match len of others"
124-
)
125-
assert uz is None or lenuz == maxlen or lenuz == 1, (
126-
"Length of uz doesn't match len of others"
127-
)
128-
assert w is None or lenw == maxlen or lenw == 1, (
129-
"Length of w doesn't match len of others"
130-
)
131-
for key, val in kwargs.items():
132-
assert np.size(val) == 1 or len(val) == maxlen, (
133-
f"Length of {key} doesn't match len of others"
134-
)
135-
136-
# --- Broadcast scalars into appropriate length arrays
137-
# --- If the parameter was not supplied, use the default value
138-
if lenx == 1:
139-
x = np.full(maxlen, (x or 0.0))
140-
if leny == 1:
141-
y = np.full(maxlen, (y or 0.0))
142-
if lenz == 1:
143-
z = np.full(maxlen, (z or 0.0))
144-
if lenux == 1:
145-
ux = np.full(maxlen, (ux or 0.0))
146-
if lenuy == 1:
147-
uy = np.full(maxlen, (uy or 0.0))
148-
if lenuz == 1:
149-
uz = np.full(maxlen, (uz or 0.0))
150-
if lenw == 1:
151-
w = np.full(maxlen, (w or 0.0))
152-
for key, val in kwargs.items():
153-
if np.size(val) == 1:
154-
kwargs[key] = np.full(maxlen, val)
155-
156-
# --- The number of built in attributes
157-
# --- The positions
158-
built_in_attrs = libwarpx.dim
159-
# --- The three velocities
160-
built_in_attrs += 3
161-
if libwarpx.geometry_dim == "rz":
162-
# --- With RZ, there is also theta
163-
built_in_attrs += 1
164-
165-
# --- The number of extra attributes (including the weight)
166-
nattr = self.particle_container.num_real_comps - built_in_attrs
167-
attr = np.zeros((maxlen, nattr))
168-
attr[:, 0] = w
169-
170-
# --- Note that the velocities are handled separately and not included in attr
171-
# --- (even though they are stored as attributes in the C++)
172-
for key, vals in kwargs.items():
173-
attr[
174-
:, self.particle_container.get_real_comp_index(key) - built_in_attrs
175-
] = vals
176-
177-
nattr_int = 0
178-
attr_int = np.empty([0], dtype=np.int32)
179-
180-
# TODO: expose ParticleReal through pyAMReX
181-
# and cast arrays to the correct types, before calling add_n_particles
182-
# x = x.astype(self._numpy_particlereal_dtype, copy=False)
183-
# y = y.astype(self._numpy_particlereal_dtype, copy=False)
184-
# z = z.astype(self._numpy_particlereal_dtype, copy=False)
185-
# ux = ux.astype(self._numpy_particlereal_dtype, copy=False)
186-
# uy = uy.astype(self._numpy_particlereal_dtype, copy=False)
187-
# uz = uz.astype(self._numpy_particlereal_dtype, copy=False)
188-
189-
self.particle_container.add_n_particles(
190-
0,
191-
x.size,
80+
self.particle_container.add_particles(
19281
x,
19382
y,
19483
z,
19584
ux,
19685
uy,
19786
uz,
198-
nattr,
199-
attr,
200-
nattr_int,
201-
attr_int,
87+
w,
20288
unique_particles,
89+
**kwargs,
20390
)
20491

20592
def get_particle_count(self, local=False):

0 commit comments

Comments
 (0)