@@ -240,36 +240,42 @@ namespace impactx
240240 amrex::Gpu::DeviceVector<amrex::ParticleReal> const & px,
241241 amrex::Gpu::DeviceVector<amrex::ParticleReal> const & py,
242242 amrex::Gpu::DeviceVector<amrex::ParticleReal> const & pt,
243- amrex::Gpu::DeviceVector<amrex::ParticleReal> const & sx,
244- amrex::Gpu::DeviceVector<amrex::ParticleReal> const & sy,
245- amrex::Gpu::DeviceVector<amrex::ParticleReal> const & sz,
246243 amrex::ParticleReal qm,
247244 std::optional<amrex::ParticleReal> bunch_charge,
248- std::optional<amrex::Gpu::DeviceVector<amrex::ParticleReal>> w
245+ std::optional<amrex::Gpu::DeviceVector<amrex::ParticleReal>> w,
246+ std::optional<amrex::Gpu::DeviceVector<amrex::ParticleReal>> sx,
247+ std::optional<amrex::Gpu::DeviceVector<amrex::ParticleReal>> sy,
248+ std::optional<amrex::Gpu::DeviceVector<amrex::ParticleReal>> sz
249249 )
250250 {
251251 BL_PROFILE (" ImpactX::AddNParticles" );
252252
253253 using namespace amrex ::literals; // for _rt and _prt
254254
255+ // number of particles to add
256+ std::size_t const np_s = x.size ();
257+
258+ // input validation
255259 bool const has_w = w.has_value ();
256260 if (!(bunch_charge.has_value () ^ has_w))
257261 {
258262 throw std::runtime_error (" AddNParticles: Exactly one of bunch_charge or w must be provided!" );
259263 }
260264
261- AMREX_ALWAYS_ASSERT (x.size () == y.size ());
262- AMREX_ALWAYS_ASSERT (x.size () == t.size ());
263- AMREX_ALWAYS_ASSERT (x.size () == px.size ());
264- AMREX_ALWAYS_ASSERT (x.size () == py.size ());
265- AMREX_ALWAYS_ASSERT (x.size () == pt.size ());
266- AMREX_ALWAYS_ASSERT (x.size () == sx.size ());
267- AMREX_ALWAYS_ASSERT (x.size () == sy.size ());
268- AMREX_ALWAYS_ASSERT (x.size () == sz.size ());
269- if (has_w) { AMREX_ALWAYS_ASSERT (x.size () == w->size ()); }
270-
271- // number of particles to add
272- amrex::Long const np = x.size ();
265+ bool const has_spin = sx.has_value ();
266+
267+ AMREX_ALWAYS_ASSERT (np_s == y.size ());
268+ AMREX_ALWAYS_ASSERT (np_s == t.size ());
269+ AMREX_ALWAYS_ASSERT (np_s == px.size ());
270+ AMREX_ALWAYS_ASSERT (np_s == py.size ());
271+ AMREX_ALWAYS_ASSERT (np_s == pt.size ());
272+ if (has_spin) {
273+ AMREX_ALWAYS_ASSERT (sy.has_value ());
274+ AMREX_ALWAYS_ASSERT (sz.has_value ());
275+ AMREX_ALWAYS_ASSERT (np_s == sy.value ().size ());
276+ AMREX_ALWAYS_ASSERT (np_s == sz.value ().size ());
277+ }
278+ if (has_w) { AMREX_ALWAYS_ASSERT (np_s == w->size ()); }
273279
274280 // we add particles to lev 0, grid 0
275281 int lid = 0 , gid = 0 ;
@@ -294,7 +300,8 @@ namespace impactx
294300 DefineAndReturnParticleTile (lid, gid, ithr);
295301 }
296302
297- amrex::Long pid = ParticleType::NextID ();
303+ amrex::Long const pid = ParticleType::NextID ();
304+ amrex::Long const np = np_s;
298305 ParticleType::NextID (pid + np);
299306 AMREX_ALWAYS_ASSERT_WITH_MESSAGE (
300307 pid + np < amrex::LongParticleIds::LastParticleID,
@@ -337,9 +344,9 @@ namespace impactx
337344 amrex::ParticleReal * const AMREX_RESTRICT px_arr = soa[RealSoA::px].dataPtr ();
338345 amrex::ParticleReal * const AMREX_RESTRICT py_arr = soa[RealSoA::py].dataPtr ();
339346 amrex::ParticleReal * const AMREX_RESTRICT pt_arr = soa[RealSoA::pt].dataPtr ();
340- amrex::ParticleReal * const AMREX_RESTRICT sx_arr = soa[RealSoA::sx].dataPtr ();
341- amrex::ParticleReal * const AMREX_RESTRICT sy_arr = soa[RealSoA::sy].dataPtr ();
342- amrex::ParticleReal * const AMREX_RESTRICT sz_arr = soa[RealSoA::sz].dataPtr ();
347+ amrex::ParticleReal * const AMREX_RESTRICT sx_arr = has_spin ? soa[RealSoA::sx].dataPtr () : nullptr ;
348+ amrex::ParticleReal * const AMREX_RESTRICT sy_arr = has_spin ? soa[RealSoA::sy].dataPtr () : nullptr ;
349+ amrex::ParticleReal * const AMREX_RESTRICT sz_arr = has_spin ? soa[RealSoA::sz].dataPtr () : nullptr ;
343350 amrex::ParticleReal * const AMREX_RESTRICT qm_arr = soa[RealSoA::qm].dataPtr ();
344351 amrex::ParticleReal * const AMREX_RESTRICT w_arr = soa[RealSoA::w ].dataPtr ();
345352
@@ -351,9 +358,9 @@ namespace impactx
351358 amrex::ParticleReal const * const AMREX_RESTRICT px_ptr = px.data ();
352359 amrex::ParticleReal const * const AMREX_RESTRICT py_ptr = py.data ();
353360 amrex::ParticleReal const * const AMREX_RESTRICT pt_ptr = pt.data ();
354- amrex::ParticleReal const * const AMREX_RESTRICT sx_ptr = sx.data ();
355- amrex::ParticleReal const * const AMREX_RESTRICT sy_ptr = sy.data ();
356- amrex::ParticleReal const * const AMREX_RESTRICT sz_ptr = sz.data ();
361+ amrex::ParticleReal const * const AMREX_RESTRICT sx_ptr = has_spin ? sx.value (). data () : nullptr ;
362+ amrex::ParticleReal const * const AMREX_RESTRICT sy_ptr = has_spin ? sy.value (). data () : nullptr ;
363+ amrex::ParticleReal const * const AMREX_RESTRICT sz_ptr = has_spin ? sz.value (). data () : nullptr ;
357364 amrex::ParticleReal const * const AMREX_RESTRICT w_ptr = has_w ? w->data () : nullptr ;
358365 amrex::ParticleReal const bunch_charge_value = has_w ? 0_prt : bunch_charge.value ();
359366
@@ -370,9 +377,15 @@ namespace impactx
370377 py_arr[old_np+i] = py_ptr[my_offset+i];
371378 pt_arr[old_np+i] = pt_ptr[my_offset+i];
372379
373- sx_arr[old_np+i] = sx_ptr[my_offset+i];
374- sy_arr[old_np+i] = sy_ptr[my_offset+i];
375- sz_arr[old_np+i] = sz_ptr[my_offset+i];
380+ if (has_spin) {
381+ sx_arr[old_np+i] = sx_ptr[my_offset+i];
382+ sy_arr[old_np+i] = sy_ptr[my_offset+i];
383+ sz_arr[old_np+i] = sz_ptr[my_offset+i];
384+ } else {
385+ sx_arr[old_np+i] = 0_prt;
386+ sy_arr[old_np+i] = 0_prt;
387+ sz_arr[old_np+i] = 0_prt;
388+ }
376389
377390 qm_arr[old_np+i] = qm;
378391 w_arr[old_np+i] = has_w ? w_ptr[my_offset+i] : bunch_charge_value / ablastr::constant::SI::q_e/np;
0 commit comments