Skip to content

Commit bbc064e

Browse files
committed
Introduce more index-based functions to enable more flexibility in handling system pair specific interactions
1 parent 9b9efbe commit bbc064e

File tree

3 files changed

+117
-47
lines changed

3 files changed

+117
-47
lines changed

src/general/neighborhood_search.jl

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,42 +102,43 @@ end
102102
end
103103

104104
@inline function get_neighborhood_search(system, semi)
105-
(; neighborhood_searches) = semi
105+
system_index = system_indices(system, semi)
106+
107+
return get_neighborhood_search(system, semi, system_index)
108+
end
106109

110+
@inline function get_neighborhood_search(system, neighbor_system, semi)
107111
system_index = system_indices(system, semi)
112+
neighbor_index = system_indices(neighbor_system, semi)
113+
114+
return get_neighborhood_search(system, neighbor_system, semi, system_index, neighbor_index)
115+
end
108116

109-
return neighborhood_searches[system_index][system_index]
117+
@inline function get_neighborhood_search(system, semi, system_index::Integer)
118+
return semi.neighborhood_searches[system_index][system_index]
110119
end
111120

112-
@inline function get_neighborhood_search(system::TotalLagrangianSPHSystem, semi)
121+
@inline function get_neighborhood_search(system::TotalLagrangianSPHSystem, semi,
122+
system_index::Integer)
113123
# For TLSPH, use the specialized self-interaction neighborhood search
114124
# for finding neighbors in the initial configuration.
115125
return system.self_interaction_nhs
116126
end
117127

128+
@inline function get_neighborhood_search(system, neighbor_system, semi,
129+
system_index::Integer, neighbor_index::Integer)
130+
return semi.neighborhood_searches[system_index][neighbor_index]
131+
end
118132
@inline function get_neighborhood_search(system::TotalLagrangianSPHSystem,
119-
neighbor_system::TotalLagrangianSPHSystem, semi)
120-
(; neighborhood_searches) = semi
121-
122-
system_index = system_indices(system, semi)
123-
neighbor_index = system_indices(neighbor_system, semi)
124-
133+
neighbor_system::TotalLagrangianSPHSystem, semi,
134+
system_index::Integer, neighbor_index::Integer)
125135
if system_index == neighbor_index
126136
# For TLSPH, use the specialized self-interaction neighborhood search
127137
# for finding neighbors in the initial configuration.
128138
return system.self_interaction_nhs
129139
end
130140

131-
return neighborhood_searches[system_index][neighbor_index]
132-
end
133-
134-
@inline function get_neighborhood_search(system, neighbor_system, semi)
135-
(; neighborhood_searches) = semi
136-
137-
system_index = system_indices(system, semi)
138-
neighbor_index = system_indices(neighbor_system, semi)
139-
140-
return neighborhood_searches[system_index][neighbor_index]
141+
return semi.neighborhood_searches[system_index][neighbor_index]
141142
end
142143

143144
function initialize_neighborhood_searches!(semi)
@@ -150,11 +151,13 @@ function initialize_neighborhood_searches!(semi)
150151
return semi
151152
end
152153

153-
function initialize_neighborhood_search!(semi, system, neighbor)
154+
function initialize_neighborhood_search!(semi, system, neighbor,
155+
system_index::Integer, neighbor_index::Integer)
154156
# TODO Initialize after adapting to the GPU.
155157
# Currently, this cannot use `semi.parallelization_backend`
156158
# because data is still on the CPU.
157-
PointNeighbors.initialize!(get_neighborhood_search(system, neighbor, semi),
159+
PointNeighbors.initialize!(get_neighborhood_search(system, neighbor, semi,
160+
system_index, neighbor_index),
158161
initial_coordinates(system),
159162
initial_coordinates(neighbor),
160163
eachindex_y=each_active_particle(neighbor),
@@ -164,19 +167,30 @@ function initialize_neighborhood_search!(semi, system, neighbor)
164167
end
165168

166169
function initialize_neighborhood_search!(semi, system::TotalLagrangianSPHSystem,
167-
neighbor::TotalLagrangianSPHSystem)
170+
neighbor::TotalLagrangianSPHSystem,
171+
system_index::Integer, neighbor_index::Integer)
168172
# For TLSPH, the self-interaction NHS is already initialized in the system constructor
169173
return semi
170174
end
171175

176+
function initialize_neighborhood_search!(semi, system, neighbor)
177+
system_index = system_indices(system, semi)
178+
neighbor_index = system_indices(neighbor, semi)
179+
180+
return initialize_neighborhood_search!(semi, system, neighbor,
181+
system_index, neighbor_index)
182+
end
183+
184+
# === Neighborhood search updates (per-system) ===
172185
function update_nhs!(semi, u_ode)
173186
# Update NHS for each pair of systems
174-
foreach_system(semi) do system
175-
u_system = wrap_u(u_ode, system, semi)
187+
foreach_system_indexed(semi) do system_index, system
188+
u_system = wrap_u(u_ode, system, semi, system_index)
176189

177-
foreach_system(semi) do neighbor
178-
u_neighbor = wrap_u(u_ode, neighbor, semi)
179-
neighborhood_search = get_neighborhood_search(system, neighbor, semi)
190+
foreach_system_indexed(semi) do neighbor_index, neighbor
191+
u_neighbor = wrap_u(u_ode, neighbor, semi, neighbor_index)
192+
neighborhood_search = get_neighborhood_search(system, neighbor, semi,
193+
system_index, neighbor_index)
180194

181195
update_nhs!(neighborhood_search, system, neighbor, u_system, u_neighbor, semi)
182196
end

src/general/semidiscretization.jl

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,38 @@ end
171171

172172
@inline foreach_system(f, systems) = foreach_noalloc(f, systems)
173173

174+
# This is just for readability to loop over all systems with indices without allocations.
175+
@inline function foreach_system_indexed(f, semi::Union{NamedTuple, Semidiscretization})
176+
return foreach_system_indexed(f, semi.systems)
177+
end
178+
179+
@inline function foreach_system_indexed(f, systems::Tuple)
180+
indices = ntuple(identity, Val(length(systems)))
181+
182+
return foreach_noalloc(indices, systems) do (index, system)
183+
f(index, system)
184+
end
185+
end
186+
187+
# This is just for readability to loop over all systems with wrapped arrays.
188+
@inline function foreach_system_wrapped(f, semi::Union{NamedTuple, Semidiscretization},
189+
v_ode, u_ode)
190+
return foreach_system_indexed(semi) do system_index, system
191+
f(system,
192+
wrap_v(v_ode, system, semi, system_index),
193+
wrap_u(u_ode, system, semi, system_index))
194+
end
195+
end
196+
197+
@inline function foreach_system_wrapped(f, semi::Union{NamedTuple, Semidiscretization},
198+
dv_ode, v_ode, u_ode)
199+
return foreach_system_indexed(semi) do system_index, system
200+
f(system,
201+
wrap_v(dv_ode, system, semi, system_index),
202+
wrap_v(v_ode, system, semi, system_index),
203+
wrap_u(u_ode, system, semi, system_index))
204+
end
205+
end
174206
"""
175207
semidiscretize(semi, tspan; reset_threads=true)
176208
@@ -334,24 +366,40 @@ end
334366
# We have to pass `system` here for type stability,
335367
# since the type of `system` determines the return type.
336368
@inline function wrap_v(v_ode, system, semi)
337-
(; ranges_v) = semi
369+
return wrap_v(v_ode, system, semi, system_indices(system, semi))
370+
end
338371

339-
range = ranges_v[system_indices(system, semi)]
372+
@inline function wrap_v(v_ode, system, semi, system_index::Integer)
373+
return wrap_v(v_ode, system, semi.ranges_v[system_index])
374+
end
340375

341-
@boundscheck @assert length(range) ==
342-
v_nvariables(system) * n_integrated_particles(system)
376+
@inline function wrap_v(v_ode, system, range::AbstractUnitRange)
377+
@boundscheck begin
378+
expected = v_nvariables(system) * n_integrated_particles(system)
379+
range_length = length(range)
380+
range_length == expected ||
381+
throw(DimensionMismatch("v range length $range_length does not match expected $expected"))
382+
end
343383

344384
return wrap_array(v_ode, range,
345385
(StaticInt(v_nvariables(system)), n_integrated_particles(system)))
346386
end
347387

348388
@inline function wrap_u(u_ode, system, semi)
349-
(; ranges_u) = semi
389+
return wrap_u(u_ode, system, semi, system_indices(system, semi))
390+
end
350391

351-
range = ranges_u[system_indices(system, semi)]
392+
@inline function wrap_u(u_ode, system, semi, system_index::Integer)
393+
return wrap_u(u_ode, system, semi.ranges_u[system_index])
394+
end
352395

353-
@boundscheck @assert length(range) ==
354-
u_nvariables(system) * n_integrated_particles(system)
396+
@inline function wrap_u(u_ode, system, range::AbstractUnitRange)
397+
@boundscheck begin
398+
expected = u_nvariables(system) * n_integrated_particles(system)
399+
range_length = length(range)
400+
range_length == expected ||
401+
throw(DimensionMismatch("u range length $range_length does not match expected $expected"))
402+
end
355403

356404
return wrap_array(u_ode, range,
357405
(StaticInt(u_nvariables(system)), n_integrated_particles(system)))
@@ -639,19 +687,18 @@ end
639687

640688
function system_interaction!(dv_ode, v_ode, u_ode, semi)
641689
# Call `interact!` for each pair of systems
642-
foreach_system(semi) do system
643-
foreach_system(semi) do neighbor
690+
foreach_system_indexed(semi) do system_index, system
691+
foreach_system_indexed(semi) do neighbor_index, neighbor
644692
# Construct string for the interactions timer.
645693
# Avoid allocations from string construction when no timers are used.
646694
if timeit_debug_enabled()
647-
system_index = system_indices(system, semi)
648-
neighbor_index = system_indices(neighbor, semi)
649695
timer_str = "$(timer_name(system))$system_index-$(timer_name(neighbor))$neighbor_index"
650696
else
651697
timer_str = ""
652698
end
653699

654-
interact!(dv_ode, v_ode, u_ode, system, neighbor, semi, timer_str=timer_str)
700+
interact!(dv_ode, v_ode, u_ode, system, neighbor, semi,
701+
system_index, neighbor_index; timer_str=timer_str)
655702
end
656703
end
657704

@@ -663,12 +710,21 @@ end
663710
# dv_ode, du_ode = copy(sol.u[end]).x; v_ode, u_ode = copy(sol.u[end]).x;
664711
# @btime TrixiParticles.interact!($dv_ode, $v_ode, $u_ode, $fluid_system, $fluid_system, $semi);
665712
@inline function interact!(dv_ode, v_ode, u_ode, system, neighbor, semi; timer_str="")
666-
dv = wrap_v(dv_ode, system, semi)
667-
v_system = wrap_v(v_ode, system, semi)
668-
u_system = wrap_u(u_ode, system, semi)
713+
system_index = system_indices(system, semi)
714+
neighbor_index = system_indices(neighbor, semi)
715+
716+
return interact!(dv_ode, v_ode, u_ode, system, neighbor, semi,
717+
system_index, neighbor_index; timer_str=timer_str)
718+
end
719+
720+
@inline function interact!(dv_ode, v_ode, u_ode, system, neighbor, semi,
721+
system_index::Integer, neighbor_index::Integer; timer_str="")
722+
dv = wrap_v(dv_ode, system, semi, system_index)
723+
v_system = wrap_v(v_ode, system, semi, system_index)
724+
u_system = wrap_u(u_ode, system, semi, system_index)
669725

670-
v_neighbor = wrap_v(v_ode, neighbor, semi)
671-
u_neighbor = wrap_u(u_ode, neighbor, semi)
726+
v_neighbor = wrap_v(v_ode, neighbor, semi, neighbor_index)
727+
u_neighbor = wrap_u(u_ode, neighbor, semi, neighbor_index)
672728

673729
@trixi_timeit timer() timer_str begin
674730
interact!(dv, v_system, u_system, v_neighbor, u_neighbor, system, neighbor, semi)

src/io/io.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ function create_meta_data_dict(callback, integrator)
3636

3737
systems = Dict{String, Any}()
3838
foreach_system(semi) do system
39-
idx = system_indices(system, semi)
40-
name = add_underscore_to_optional_prefix(prefix) * names[idx]
39+
system_index = system_indices(system, semi)
40+
name = add_underscore_to_optional_prefix(prefix) * names[system_index]
4141

4242
system_data = Dict{String, Any}()
4343
add_system_data!(system_data, system)

0 commit comments

Comments
 (0)