Skip to content

Commit 3526408

Browse files
committed
Test vectorized particle deposition
1 parent 2884b76 commit 3526408

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

yt/fields/tests/test_fields.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ def test_add_deposited_particle_field():
210210
else:
211211
assert_almost_equal(ret.sum(), ad["io", "particle_mass"].sum())
212212

213+
# Make sure the shapes are correct
214+
assert_equal(ret.ndim, 1)
215+
213216
# Test "weighted_mean" method
214217
fn = base_ds.add_deposited_particle_field(
215218
("io", "particle_ones"), "weighted_mean", weight_field="particle_ones"
@@ -218,6 +221,54 @@ def test_add_deposited_particle_field():
218221
ret = ad[fn]
219222
# The sum should equal the number of cells that have particles
220223
assert_equal(ret.sum(), np.count_nonzero(ad["deposit", "io_count"]))
224+
assert_equal(ret.ndim, 1)
225+
226+
227+
def test_add_deposited_particle_vector_field():
228+
ds = get_base_ds(1)
229+
ad = ds.all_data()
230+
231+
def vector_field(data):
232+
return np.ones(data["io", "particle_mass"].shape + (10,))
233+
234+
ds.add_field(
235+
("io", "vector_field"),
236+
function=vector_field,
237+
units="",
238+
sampling_type="particle",
239+
vector_field=True,
240+
)
241+
242+
# We need to mention the vector_field=True flag here
243+
assert_raises(
244+
RuntimeError,
245+
ds.add_deposited_particle_field,
246+
("io", "vector_field"),
247+
method="nearest",
248+
)
249+
250+
Ncell = len(ad["index", "dx"])
251+
252+
for method in ("sum", "cic", "nearest"):
253+
fname = ds.add_deposited_particle_field(
254+
("io", "vector_field"), method=method, vector_field=True
255+
)
256+
ad = ds.all_data()
257+
ret = ad[fname]
258+
assert_equal(ret.shape, (Ncell, 10))
259+
260+
if method == "sum":
261+
ref = ds.add_deposited_particle_field(("io", "particle_ones"), method="sum")
262+
ref_data = ad[ref]
263+
assert_equal(ret.sum(), ref_data.sum() * 10)
264+
265+
fname = ds.add_deposited_particle_field(
266+
("io", "vector_field"), method="count", vector_field=True
267+
)
268+
ref_fname = ds.add_deposited_particle_field(("io", "particle_ones"), method="count")
269+
# Note here: some particles may fall outside the domain,
270+
# so we compare wrt the non-vectorized result
271+
assert_equal(ad[fname].sum(), ad[ref_fname].sum())
221272

222273

223274
def test_add_gradient_fields():

0 commit comments

Comments
 (0)