Skip to content

Commit a064dbb

Browse files
authored
refactor Boris solvers (#452)
* refactor: remove temporary field arrays, reuse CPU buffers for saving, and use svector velocity update * feat: support time-dependent field interpolation in kernel. * Specialize the kernel solver in CPU version * refactor: use FieldInterpolator instead of function capture for GPU adapt * Cleanup imports and unused SphericalVectorFieldInterpolator * feat: support multithreading kernel Boris solver * refactor: make isoutofdomain and velocity_updater type stable and inline Field call * Remove redundant SVector wrapper
1 parent e337a8f commit a064dbb

File tree

10 files changed

+496
-472
lines changed

10 files changed

+496
-472
lines changed

src/TestParticle.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ import DiffResults
1616
using ChunkSplitters: index_chunks
1717
using PrecompileTools: @setup_workload, @compile_workload
1818
using MuladdMacro: @muladd
19+
using KernelAbstractions: @kernel, @index, @Const, synchronize, Backend, CPU
1920

21+
import KernelAbstractions as KA
22+
import Adapt
2023
import Tensors
2124
import Base: +, -, *, /, setindex!, getindex
2225
import LinearAlgebra: ×

src/boris.jl

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,39 @@ end
7676
@inline ODE_DEFAULT_ISOUTOFDOMAIN(u, p, t) = false
7777

7878
"""
79-
update_velocity(v, r, param, dt, t)
79+
boris_velocity_update(v, E, B, qdt_2m)
8080
8181
Update velocity using the Boris method, returning the new velocity as an SVector.
82+
This is the core logic shared between the standard solver and the kernel solver.
8283
"""
83-
@muladd function update_velocity(v, r, param, dt, t)
84-
q2m, _, Efunc, Bfunc, _ = param
85-
E = Efunc(r, t)
86-
B = Bfunc(r, t)
87-
88-
t_rotate = q2m * B * 0.5 * dt
84+
@inline @muladd function boris_velocity_update(v, E, B, qdt_2m)
85+
t_rotate = qdt_2m * B
8986
t_mag2 = sum(abs2, t_rotate)
9087
s_rotate = 2 * t_rotate / (1 + t_mag2)
9188

92-
v_minus = v + q2m * E * 0.5 * dt
89+
v_minus = v + qdt_2m * E
9390
v_prime = v_minus + (v_minus × t_rotate)
9491
v_plus = v_minus + (v_prime × s_rotate)
9592

96-
v_new = v_plus + q2m * E * 0.5 * dt
93+
v_new = v_plus + qdt_2m * E
9794

9895
return v_new
9996
end
10097

98+
"""
99+
update_velocity(v, r, param, dt, t)
100+
101+
Update velocity using the Boris method, returning the new velocity as an SVector.
102+
"""
103+
@inline @muladd function update_velocity(v, r, param::P, dt, t) where {P}
104+
q2m, _, Efunc, Bfunc, _ = param
105+
E = Efunc(r, t)
106+
B = Bfunc(r, t)
107+
qdt_2m = q2m * 0.5 * dt
108+
109+
return boris_velocity_update(v, E, B, qdt_2m)
110+
end
111+
101112
"""
102113
update_velocity!(xv, paramBoris, param, dt, t)
103114
@@ -184,22 +195,22 @@ Trace particles using the Boris method with specified `prob`.
184195
- `save_work::Bool`: save the work done by the electric field. Default is `false`.
185196
"""
186197
@inline function solve(
187-
prob::TraceProblem, ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial();
198+
prob::TraceProblem, ensemblealg::EA = EnsembleSerial();
188199
trajectories::Int = 1, savestepinterval::Int = 1, dt::AbstractFloat,
189-
isoutofdomain::Function = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1,
200+
isoutofdomain::F = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1,
190201
save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true,
191202
save_fields::Bool = false, save_work::Bool = false
192-
)
203+
) where {EA <: BasicEnsembleAlgorithm, F}
193204
return _solve(
194205
ensemblealg, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
195206
save_start, save_end, save_everystep, Val(save_fields), Val(save_work)
196207
)
197208
end
198209

199210
function _dispatch_boris!(
200-
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n,
211+
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n,
201212
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
202-
) where {SaveFields, SaveWork}
213+
) where {SaveFields, SaveWork, F}
203214
return if n == 1
204215
_boris!(
205216
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
@@ -213,10 +224,10 @@ function _dispatch_boris!(
213224
end
214225
end
215226

216-
function _solve(
217-
::EnsembleSerial, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
227+
@inline function _solve(
228+
::EnsembleSerial, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n,
218229
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
219-
) where {SaveFields, SaveWork}
230+
) where {SaveFields, SaveWork, F}
220231
sols, nt,
221232
nout = _prepare(
222233
prob, trajectories, dt, savestepinterval,
@@ -231,10 +242,10 @@ function _solve(
231242
return sols
232243
end
233244

234-
function _solve(
235-
::EnsembleThreads, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
245+
@inline function _solve(
246+
::EnsembleThreads, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n,
236247
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
237-
) where {SaveFields, SaveWork}
248+
) where {SaveFields, SaveWork, F}
238249
sols, nt,
239250
nout = _prepare(
240251
prob, trajectories, dt, savestepinterval,
@@ -351,13 +362,41 @@ end
351362
"""
352363
Apply Boris method for particles with index in `irange`.
353364
"""
354-
@muladd function _generic_boris!(
355-
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
365+
@inline function _boris_loop!(
366+
traj, tsave, iout, r, v, p, dt, nt, tspan,
367+
savestepinterval, save_everystep, isoutofdomain::F1, velocity_updater::F2,
368+
::Val{SaveFields}, ::Val{SaveWork}
369+
) where {F1, F2, SaveFields, SaveWork}
370+
it = 1
371+
while it <= nt
372+
v_prev = v
373+
t = (it - 0.5) * dt
374+
v = velocity_updater(v, r, p, dt, t)
375+
376+
if save_everystep && (it - 1) > 0 && (it - 1) % savestepinterval == 0
377+
iout += 1
378+
if iout <= length(traj)
379+
t_current = tspan[1] + (it - 1) * dt
380+
v_save = velocity_updater(v_prev, r, p, 0.5 * dt, t_current)
381+
data = vcat(r, v_save)
382+
traj[iout] = _prepare_saved_data(data, p, t_current, Val(SaveFields), Val(SaveWork))
383+
tsave[iout] = t_current
384+
end
385+
end
386+
387+
r += v * dt
388+
isoutofdomain(vcat(r, v), p, it * dt) && break
389+
it += 1
390+
end
391+
return it, iout, r, v
392+
end
393+
394+
@inline @muladd function _generic_boris!(
395+
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F1,
356396
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork},
357-
velocity_updater, alg_name
358-
) where {SaveFields, SaveWork}
397+
velocity_updater::F2, alg_name
398+
) where {SaveFields, SaveWork, F1, F2}
359399
(; tspan, p, u0) = prob
360-
q2m, m, Efunc, Bfunc, _ = p
361400
T = eltype(u0)
362401

363402
vars_dim = 6
@@ -368,57 +407,31 @@ Apply Boris method for particles with index in `irange`.
368407
vars_dim += 4
369408
end
370409

371-
@fastmath @inbounds for i in irange
410+
@inbounds for i in irange
372411
traj = Vector{SVector{vars_dim, T}}(undef, nout)
373412
tsave = Vector{typeof(tspan[1] + dt)}(undef, nout)
374413

375414
# set initial conditions for each trajectory i
376415
iout = 0
377416
new_prob = prob.prob_func(prob, i, false)
378-
# Load independent r and v SVector from u0
379417
u0_i = SVector{6, T}(new_prob.u0)
380418
r = u0_i[SVector(1, 2, 3)]
381419
v = u0_i[SVector(4, 5, 6)]
382420

383421
if save_start
384422
iout += 1
385-
traj[iout] = _prepare_saved_data(
386-
u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork)
387-
)
423+
traj[iout] = _prepare_saved_data(u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork))
388424
tsave[iout] = tspan[1]
389425
end
390426

391427
# push velocity back in time by 1/2 dt
392428
v = velocity_updater(v, r, p, -0.5 * dt, tspan[1])
393429

394-
it = 1
395-
while it <= nt
396-
v_prev = v
397-
t = (it - 0.5) * dt
398-
v = velocity_updater(v, r, p, dt, t)
399-
400-
if save_everystep && (it - 1) > 0 && (it - 1) % savestepinterval == 0
401-
iout += 1
402-
if iout <= nout
403-
t_current = tspan[1] + (it - 1) * dt
404-
# Approximate v_n from v_{n-1/2} (v_prev)
405-
v_save = velocity_updater(
406-
v_prev, r, p, 0.5 * dt,
407-
t_current
408-
)
409-
410-
data = vcat(r, v_save)
411-
traj[iout] = _prepare_saved_data(
412-
data, p, t_current, Val(SaveFields), Val(SaveWork)
413-
)
414-
tsave[iout] = t_current
415-
end
416-
end
417-
418-
r += v * dt
419-
isoutofdomain(vcat(r, v), p, it * dt) && break
420-
it += 1
421-
end
430+
it, iout, r, v = _boris_loop!(
431+
traj, tsave, iout, r, v, p, dt, nt, tspan,
432+
savestepinterval, save_everystep, isoutofdomain, velocity_updater,
433+
Val(SaveFields), Val(SaveWork)
434+
)
422435

423436
final_step = min(it, nt)
424437
should_save_final = false
@@ -464,10 +477,10 @@ end
464477
"""
465478
Apply Boris method for particles with index in `irange`.
466479
"""
467-
@muladd function _boris!(
468-
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
480+
@inline @muladd function _boris!(
481+
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F,
469482
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
470-
) where {SaveFields, SaveWork}
483+
) where {SaveFields, SaveWork, F}
471484

472485
_generic_boris!(
473486
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
@@ -541,10 +554,10 @@ Reference: [Zenitani & Kato 2025](https://arxiv.org/abs/2505.02270)
541554
return v_new
542555
end
543556

544-
@muladd function _multistep_boris!(
545-
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n_steps::Int,
557+
@inline @muladd function _multistep_boris!(
558+
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n_steps::Int,
546559
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
547-
) where {SaveFields, SaveWork}
560+
) where {SaveFields, SaveWork, F}
548561

549562
velocity_updater = (v, r, p, dt, t) ->
550563
update_velocity_multistep(v, r, p, dt, t, n_steps)

0 commit comments

Comments
 (0)