|
# Systems whose particles are reordered by the `SortingCallback`.
# Only fluid systems are listed: boundary particles always stay fixed relative to each
# other, and TLSPH computes in the initial configuration.
# TODO: The `DEMSystem` should be added here in the future.
const RequiresSortingSystem = AbstractFluidSystem
| 5 | + |
# State of the sorting callback. An instance is used as both the `condition` and the
# `affect!` of the `DiscreteCallback` created by the keyword constructor.
mutable struct SortingCallback{I}
    # Sorting interval: the keyword constructor stores either a number of time steps
    # or a `Float64` amount of simulation time (when `dt` is used) here.
    interval::I
    # Reference time for the time-based condition `(t - last_t) > interval`.
    # The keyword constructor initializes it to `0.0`.
    last_t::Float64
end
| 10 | + |
"""
    SortingCallback(; interval=-1, dt=0.0, initial_sort=true)

Reorders particles according to neighborhood-search cells for performance optimization.

When particles become very unordered throughout a long-running simulation, performance
degrades due to increased cache-misses (on CPUs) and lack of block structure (on GPUs).
On GPUs, a fully shuffled particle ordering causes a 3-4x slowdown compared to a sorted
configuration.
On CPUs the performance penalty grows linearly with the problem size and can reach up to a
10x slowdown for very large problems (65M particles).
See [#1044](https://github.com/trixi-framework/TrixiParticles.jl/pull/1044) for more details.

If neither `interval` nor `dt` is set, particles are sorted after every time step.
Setting both `interval` and `dt` throws an `ArgumentError`.

# Keywords
- `interval`: Sort particles at the end of every `interval` time steps.
- `dt`: Sort particles in regular intervals of `dt` in terms of integration time.
        This callback does not add extra time steps / `tstops`; instead, sorting is
        triggered at the first solver step after each `dt` interval has elapsed.
- `initial_sort=true`: When enabled, particles are sorted at the beginning of the simulation.
                       When the initial configuration is a perfect grid of particles,
                       sorting at the beginning is not necessary and might even slightly
                       slow down the first time steps, since a perfect grid is even better
                       than sorting by NHS cell index.
"""
function SortingCallback(; interval::Integer=-1, dt=0.0, initial_sort::Bool=true)
    # Compare by value (`!=`), not by identity (`!==`), so that a default passed
    # as another integer type (e.g. `Int32(-1)`) is still recognized as "not set".
    if dt > 0 && interval != -1
        throw(ArgumentError("Setting both interval and dt is not supported!"))
    end

    # `dt > 0`: sort in intervals in terms of simulation time.
    # Otherwise: sort every `interval` time steps (every time step by default).
    # The step-based `condition` is only defined for `SortingCallback{Int}`,
    # so any other integer type is converted to `Int` here.
    sorting_interval = dt > 0 ? Float64(dt) : (interval == -1 ? 1 : Int(interval))

    sorting_callback! = SortingCallback(sorting_interval, 0.0)

    # `initialize`: only sort at the beginning of the simulation when requested
    function initialize_sorting!(cb, u, t, integrator)
        if initial_sort
            initial_sort!(cb, u, t, integrator)
        else
            # Measure the first `dt` interval from the initial time
            sorting_callback!.last_t = t

            # Nothing was changed, same as the default `initialize` of a callback
            u_modified!(integrator, false)
        end

        return nothing
    end

    # The first one is the `condition`, the second the `affect!`
    return DiscreteCallback(sorting_callback!, sorting_callback!,
                            initialize=initialize_sorting!,
                            save_positions=(false, false))
end
| 54 | + |
# `initialize`
# Generic entry point: unwrap one callback layer and dispatch again.
# With a `DiscreteCallback`, the `SortingCallback` is `cb.affect!`; with a
# `PeriodicCallback`, it is `cb.affect!.affect!`. Recursing on `cb.affect!`
# covers both cases.
initial_sort!(cb, u, t, integrator) = initial_sort!(cb.affect!, u, t, integrator)
| 63 | + |
# End of the recursion: run the `affect!` of the `SortingCallback` once
initial_sort!(cb::SortingCallback, u, t, integrator) = cb(integrator)
| 67 | + |
# `condition` with `interval`
# Triggers when the step-based interval condition holds,
# but never once the integrator has finished.
function (sorting_callback!::SortingCallback{Int})(u, t, integrator)
    isfinished(integrator) && return false

    return condition_integrator_interval(integrator, sorting_callback!.interval)
end
| 74 | + |
# `condition` with `dt`
# Triggers once more than `interval` units of simulation time lie between
# the current time `t` and the stored reference time `last_t`.
function (sorting_callback!::SortingCallback)(u, t, integrator)
    elapsed = t - sorting_callback!.last_t

    return elapsed > sorting_callback!.interval
end
| 81 | + |
# `affect!`
# Sort the particles of all systems and remember the time of this sort.
function (sorting_callback!::SortingCallback)(integrator)
    semi = integrator.p
    v_ode, u_ode = integrator.u.x

    @trixi_timeit timer() "sorting callback" begin
        foreach_system(semi) do system
            v = wrap_v(v_ode, system, semi)
            u = wrap_u(u_ode, system, semi)

            # No-op for systems that do not require sorting
            sort_particles!(system, v, u, semi)
        end
    end

    # Remember when the particles were sorted last. Without this update, the
    # time-based condition `(t - last_t) > interval` would hold in every step
    # once the first `dt` interval has elapsed, so sorting would happen every step.
    sorting_callback!.last_t = integrator.t

    # Tell OrdinaryDiffEq that `integrator.u` has been modified
    u_modified!(integrator, true)

    return integrator
end
| 101 | + |
# Fallback: systems that are not a `RequiresSortingSystem` are returned unchanged
function sort_particles!(system, v, u, semi)
    return system
end
| 103 | + |
# Look up the neighborhood search of `system` and forward to the method
# that is specialized on the cell list type.
# Throws an `ArgumentError` if the neighborhood search is not a `GridNeighborhoodSearch`.
function sort_particles!(system::RequiresSortingSystem, v, u, semi)
    neighborhood_search = get_neighborhood_search(system, semi)

    neighborhood_search isa GridNeighborhoodSearch ||
        throw(ArgumentError("`SortingCallback` can only be used with a `GridNeighborhoodSearch`"))

    return sort_particles!(system, v, u, neighborhood_search,
                           neighborhood_search.cell_list, semi)
end
| 113 | + |
# Sort the particles of `system` by the neighborhood-search cell they are located in.
# TODO: Sort also masses and particle spacings for variable smoothing lengths.
function sort_particles!(system::RequiresSortingSystem, v, u, nhs,
                         cell_list::FullGridCellList, semi)
    # One entry of cell coordinates per particle
    particle_cells = allocate(semi.parallelization_backend,
                              SVector{ndims(system), Int}, nparticles(system))

    @threaded semi for particle in each_active_particle(system)
        coords = current_coords(u, system, particle)
        particle_cells[particle] = PointNeighbors.cell_coords(coords, nhs)
    end

    # Permutation that orders the particles by their cell coordinates.
    # TODO `sortperm` works on CUDA but not (yet) on Metal
    permutation = sortperm(transfer2cpu(particle_cells))

    sort_system!(system, v, u, permutation, system.buffer)

    return system
end
| 131 | + |
# Apply the permutation `perm` in place to the per-particle data of a system
# without a buffer: coordinates, velocity, pressure and density.
function sort_system!(system, v, u, perm, buffer::Nothing)
    coords = current_coordinates(u, system)
    velocity = current_velocity(v, system)
    density = current_density(v, system)
    pressure = current_pressure(v, system)

    # Indexing with `perm` creates a permuted copy, which is then written back
    # into the original array. Coordinates and velocity are permuted along
    # their second dimension.
    coords .= coords[:, perm]
    velocity .= velocity[:, perm]
    pressure .= pressure[perm]
    density .= density[perm]

    return system
end
| 145 | + |
# Single-line representation of the step-based sorting callback
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SortingCallback{Int}})
    @nospecialize cb # reduce precompilation time

    sorting_cb = cb.affect!
    print(io, "SortingCallback(interval=", sorting_cb.interval, ")")
end
| 150 | + |
# Single-line representation of the time-based (`dt`) sorting callback
function Base.show(io::IO,
                   cb::DiscreteCallback{<:Any, <:SortingCallback})
    @nospecialize cb # reduce precompilation time

    # `cb.affect!` already is the `SortingCallback` (it is not wrapped in a
    # `PeriodicCallback`), so the interval is read from it directly.
    print(io, "SortingCallback(dt=", cb.affect!.interval, ")")
end
| 156 | + |
# Multi-line (REPL) representation of the step-based sorting callback
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SortingCallback{Int}})
    @nospecialize cb # reduce precompilation time

    # Compact contexts reuse the single-line representation
    get(io, :compact, false) && return show(io, cb)

    setup = ["interval" => cb.affect!.interval]

    return summary_box(io, "SortingCallback", setup)
end
| 171 | + |
# Multi-line (REPL) representation of the time-based (`dt`) sorting callback
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SortingCallback})
    @nospecialize cb # reduce precompilation time

    if get(io, :compact, false)
        show(io, cb)
    else
        # `cb.affect!` already is the `SortingCallback` (it is not wrapped in a
        # `PeriodicCallback`), so there is no nested `affect!` to unwrap.
        sorting_cb = cb.affect!
        setup = [
            "dt" => sorting_cb.interval
        ]
        summary_box(io, "SortingCallback", setup)
    end
end
0 commit comments