@@ -210,6 +210,9 @@ def test_add_deposited_particle_field():
210210 else :
211211 assert_almost_equal (ret .sum (), ad ["io" , "particle_mass" ].sum ())
212212
213+ # Make sure the shapes are correct
214+ assert_equal (ret .ndim , 1 )
215+
213216 # Test "weighted_mean" method
214217 fn = base_ds .add_deposited_particle_field (
215218 ("io" , "particle_ones" ), "weighted_mean" , weight_field = "particle_ones"
@@ -218,6 +221,54 @@ def test_add_deposited_particle_field():
218221 ret = ad [fn ]
219222 # The sum should equal the number of cells that have particles
220223 assert_equal (ret .sum (), np .count_nonzero (ad ["deposit" , "io_count" ]))
224+ assert_equal (ret .ndim , 1 )
225+
226+
227+ def test_add_deposited_particle_vector_field ():
228+ ds = get_base_ds (1 )
229+ ad = ds .all_data ()
230+
231+ def vector_field (data ):
232+ return np .ones (data ["io" , "particle_mass" ].shape + (10 ,))
233+
234+ ds .add_field (
235+ ("io" , "vector_field" ),
236+ function = vector_field ,
237+ units = "" ,
238+ sampling_type = "particle" ,
239+ vector_field = True ,
240+ )
241+
242+ # We need to mention the vector_field=True flag here
243+ assert_raises (
244+ RuntimeError ,
245+ ds .add_deposited_particle_field ,
246+ ("io" , "vector_field" ),
247+ method = "nearest" ,
248+ )
249+
250+ Ncell = len (ad ["index" , "dx" ])
251+
252+ for method in ("sum" , "cic" , "nearest" ):
253+ fname = ds .add_deposited_particle_field (
254+ ("io" , "vector_field" ), method = method , vector_field = True
255+ )
256+ ad = ds .all_data ()
257+ ret = ad [fname ]
258+ assert_equal (ret .shape , (Ncell , 10 ))
259+
260+ if method == "sum" :
261+ ref = ds .add_deposited_particle_field (("io" , "particle_ones" ), method = "sum" )
262+ ref_data = ad [ref ]
263+ assert_equal (ret .sum (), ref_data .sum () * 10 )
264+
265+ fname = ds .add_deposited_particle_field (
266+ ("io" , "vector_field" ), method = "count" , vector_field = True
267+ )
268+ ref_fname = ds .add_deposited_particle_field (("io" , "particle_ones" ), method = "count" )
269+ # Note here: some particles may fall outside the domain,
270+ # so we compare wrt the non-vectorized result
271+ assert_equal (ad [fname ].sum (), ad [ref_fname ].sum ())
221272
222273
223274def test_add_gradient_fields ():
0 commit comments