7676@inline ODE_DEFAULT_ISOUTOFDOMAIN(u, p, t) = false
7777
7878"""
79- update_velocity (v, r, param, dt, t )
79+ boris_velocity_update (v, E, B, qdt_2m )
8080
8181Update velocity using the Boris method, returning the new velocity as an SVector.
82+ This is the core logic shared between the standard solver and the kernel solver.
8283"""
83- @muladd function update_velocity(v, r, param, dt, t)
84- q2m, _, Efunc, Bfunc, _ = param
85- E = Efunc(r, t)
86- B = Bfunc(r, t)
87-
88- t_rotate = q2m * B * 0.5 * dt
84+ @inline @muladd function boris_velocity_update(v, E, B, qdt_2m)
85+ t_rotate = qdt_2m * B
8986 t_mag2 = sum(abs2, t_rotate)
9087 s_rotate = 2 * t_rotate / (1 + t_mag2)
9188
92- v_minus = v + q2m * E * 0.5 * dt
89+ v_minus = v + qdt_2m * E
9390 v_prime = v_minus + (v_minus × t_rotate)
9491 v_plus = v_minus + (v_prime × s_rotate)
9592
96- v_new = v_plus + q2m * E * 0.5 * dt
93+ v_new = v_plus + qdt_2m * E
9794
9895 return v_new
9996end
10097
98+ """
99+ update_velocity(v, r, param, dt, t)
100+
101+ Update velocity using the Boris method, returning the new velocity as an SVector.
102+ """
103+ @inline @muladd function update_velocity(v, r, param:: P , dt, t) where {P}
104+ q2m, _, Efunc, Bfunc, _ = param
105+ E = Efunc(r, t)
106+ B = Bfunc(r, t)
107+ qdt_2m = q2m * 0.5 * dt
108+
109+ return boris_velocity_update(v, E, B, qdt_2m)
110+ end
111+
101112"""
102113 update_velocity!(xv, paramBoris, param, dt, t)
103114
@@ -184,22 +195,22 @@ Trace particles using the Boris method with specified `prob`.
184195 - `save_work::Bool`: save the work done by the electric field. Default is `false`.
185196"""
186197@inline function solve(
187- prob:: TraceProblem , ensemblealg:: BasicEnsembleAlgorithm = EnsembleSerial();
198+ prob:: TraceProblem , ensemblealg:: EA = EnsembleSerial();
188199 trajectories:: Int = 1 , savestepinterval:: Int = 1 , dt:: AbstractFloat ,
189- isoutofdomain:: Function = ODE_DEFAULT_ISOUTOFDOMAIN, n:: Int = 1 ,
200+ isoutofdomain:: F = ODE_DEFAULT_ISOUTOFDOMAIN, n:: Int = 1 ,
190201 save_start:: Bool = true , save_end:: Bool = true , save_everystep:: Bool = true ,
191202 save_fields:: Bool = false , save_work:: Bool = false
192- )
203+ ) where {EA <: BasicEnsembleAlgorithm , F}
193204 return _solve(
194205 ensemblealg, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
195206 save_start, save_end, save_everystep, Val(save_fields), Val(save_work)
196207 )
197208end
198209
199210function _dispatch_boris!(
200- sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n,
211+ sols, prob:: TraceProblem , irange, savestepinterval, dt, nt, nout, isoutofdomain:: F , n,
201212 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork}
202- ) where {SaveFields, SaveWork}
213+ ) where {SaveFields, SaveWork, F }
203214 return if n == 1
204215 _boris!(
205216 sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
@@ -213,10 +224,10 @@ function _dispatch_boris!(
213224 end
214225end
215226
216- function _solve(
217- :: EnsembleSerial , prob, trajectories, dt, savestepinterval, isoutofdomain, n,
227+ @inline function _solve(
228+ :: EnsembleSerial , prob:: TraceProblem , trajectories, dt, savestepinterval, isoutofdomain:: F , n,
218229 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork}
219- ) where {SaveFields, SaveWork}
230+ ) where {SaveFields, SaveWork, F }
220231 sols, nt,
221232 nout = _prepare(
222233 prob, trajectories, dt, savestepinterval,
@@ -231,10 +242,10 @@ function _solve(
231242 return sols
232243end
233244
234- function _solve(
235- :: EnsembleThreads , prob, trajectories, dt, savestepinterval, isoutofdomain, n,
245+ @inline function _solve(
246+ :: EnsembleThreads , prob:: TraceProblem , trajectories, dt, savestepinterval, isoutofdomain:: F , n,
236247 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork}
237- ) where {SaveFields, SaveWork}
248+ ) where {SaveFields, SaveWork, F }
238249 sols, nt,
239250 nout = _prepare(
240251 prob, trajectories, dt, savestepinterval,
@@ -351,13 +362,41 @@ end
351362"""
352363Apply Boris method for particles with index in `irange`.
353364"""
354- @muladd function _generic_boris!(
355- sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
365+ @inline function _boris_loop!(
366+ traj, tsave, iout, r, v, p, dt, nt, tspan,
367+ savestepinterval, save_everystep, isoutofdomain:: F1 , velocity_updater:: F2 ,
368+ :: Val{SaveFields} , :: Val{SaveWork}
369+ ) where {F1, F2, SaveFields, SaveWork}
370+ it = 1
371+ while it <= nt
372+ v_prev = v
373+ t = (it - 0.5 ) * dt
374+ v = velocity_updater(v, r, p, dt, t)
375+
376+ if save_everystep && (it - 1 ) > 0 && (it - 1 ) % savestepinterval == 0
377+ iout += 1
378+ if iout <= length(traj)
379+ t_current = tspan[1 ] + (it - 1 ) * dt
380+ v_save = velocity_updater(v_prev, r, p, 0.5 * dt, t_current)
381+ data = vcat(r, v_save)
382+ traj[iout] = _prepare_saved_data(data, p, t_current, Val(SaveFields), Val(SaveWork))
383+ tsave[iout] = t_current
384+ end
385+ end
386+
387+ r += v * dt
388+ isoutofdomain(vcat(r, v), p, it * dt) && break
389+ it += 1
390+ end
391+ return it, iout, r, v
392+ end
393+
394+ @inline @muladd function _generic_boris!(
395+ sols, prob:: TraceProblem , irange, savestepinterval, dt, nt, nout, isoutofdomain:: F1 ,
356396 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork} ,
357- velocity_updater, alg_name
358- ) where {SaveFields, SaveWork}
397+ velocity_updater:: F2 , alg_name
398+ ) where {SaveFields, SaveWork, F1, F2 }
359399 (; tspan, p, u0) = prob
360- q2m, m, Efunc, Bfunc, _ = p
361400 T = eltype(u0)
362401
363402 vars_dim = 6
@@ -368,57 +407,31 @@ Apply Boris method for particles with index in `irange`.
368407 vars_dim += 4
369408 end
370409
371- @fastmath @ inbounds for i in irange
410+ @inbounds for i in irange
372411 traj = Vector{SVector{vars_dim, T}}(undef, nout)
373412 tsave = Vector{typeof(tspan[1 ] + dt)}(undef, nout)
374413
375414 # set initial conditions for each trajectory i
376415 iout = 0
377416 new_prob = prob. prob_func(prob, i, false )
378- # Load independent r and v SVector from u0
379417 u0_i = SVector{6 , T}(new_prob. u0)
380418 r = u0_i[SVector(1 , 2 , 3 )]
381419 v = u0_i[SVector(4 , 5 , 6 )]
382420
383421 if save_start
384422 iout += 1
385- traj[iout] = _prepare_saved_data(
386- u0_i, p, tspan[1 ], Val(SaveFields), Val(SaveWork)
387- )
423+ traj[iout] = _prepare_saved_data(u0_i, p, tspan[1 ], Val(SaveFields), Val(SaveWork))
388424 tsave[iout] = tspan[1 ]
389425 end
390426
391427 # push velocity back in time by 1/2 dt
392428 v = velocity_updater(v, r, p, - 0.5 * dt, tspan[1 ])
393429
394- it = 1
395- while it <= nt
396- v_prev = v
397- t = (it - 0.5 ) * dt
398- v = velocity_updater(v, r, p, dt, t)
399-
400- if save_everystep && (it - 1 ) > 0 && (it - 1 ) % savestepinterval == 0
401- iout += 1
402- if iout <= nout
403- t_current = tspan[1 ] + (it - 1 ) * dt
404- # Approximate v_n from v_{n-1/2} (v_prev)
405- v_save = velocity_updater(
406- v_prev, r, p, 0.5 * dt,
407- t_current
408- )
409-
410- data = vcat(r, v_save)
411- traj[iout] = _prepare_saved_data(
412- data, p, t_current, Val(SaveFields), Val(SaveWork)
413- )
414- tsave[iout] = t_current
415- end
416- end
417-
418- r += v * dt
419- isoutofdomain(vcat(r, v), p, it * dt) && break
420- it += 1
421- end
430+ it, iout, r, v = _boris_loop!(
431+ traj, tsave, iout, r, v, p, dt, nt, tspan,
432+ savestepinterval, save_everystep, isoutofdomain, velocity_updater,
433+ Val(SaveFields), Val(SaveWork)
434+ )
422435
423436 final_step = min(it, nt)
424437 should_save_final = false
@@ -464,10 +477,10 @@ end
464477"""
465478Apply Boris method for particles with index in `irange`.
466479"""
467- @muladd function _boris!(
468- sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
480+ @inline @ muladd function _boris!(
481+ sols, prob:: TraceProblem , irange, savestepinterval, dt, nt, nout, isoutofdomain:: F ,
469482 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork}
470- ) where {SaveFields, SaveWork}
483+ ) where {SaveFields, SaveWork, F }
471484
472485 _generic_boris!(
473486 sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
@@ -541,10 +554,10 @@ Reference: [Zenitani & Kato 2025](https://arxiv.org/abs/2505.02270)
541554 return v_new
542555end
543556
544- @muladd function _multistep_boris!(
545- sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n_steps:: Int ,
557+ @inline @ muladd function _multistep_boris!(
558+ sols, prob:: TraceProblem , irange, savestepinterval, dt, nt, nout, isoutofdomain:: F , n_steps:: Int ,
546559 save_start, save_end, save_everystep, :: Val{SaveFields} , :: Val{SaveWork}
547- ) where {SaveFields, SaveWork}
560+ ) where {SaveFields, SaveWork, F }
548561
549562 velocity_updater = (v, r, p, dt, t) ->
550563 update_velocity_multistep(v, r, p, dt, t, n_steps)
0 commit comments