Skip to content

Commit 314f8b6

Browse files
committed
VariableNHSHandler and GridNHSHandler are running correctly with the 2d dam break example.
1 parent a4dcd48 commit 314f8b6

File tree

2 files changed

+106
-48
lines changed

2 files changed

+106
-48
lines changed

src/general/neighborhood_search.jl

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -114,85 +114,135 @@ function create_neighborhood_search_handler(::Type{<:PairsNHSHandler},
114114
neighborhood_search, systems)
115115
# Create a tuple of n neighborhood searches for each of the n systems.
116116
# We will need one neighborhood search for each pair of systems.
117-
return Tuple(Tuple(create_neighborhood_search(neighborhood_search,
118-
system, neighbor)
117+
searches = Tuple(Tuple(create_neighborhood_search(neighborhood_search,
118+
system, neighbor)
119119
for neighbor in systems)
120120
for system in systems)
121+
122+
return PairsNHSHandler(searches)
121123
end
122124

123125
function get_neighborhood_search(handler::PairsNHSHandler, system_index, neighbor_index, _)
124126
return handler.neighborhood_searches[system_index][neighbor_index]
125127
end
126128

127-
struct VariableRadiusNHS{SR, NHS} <: PointNeighbors.AbstractNeighborhoodSearch
129+
struct VariableNHSHandler{NHS} <: AbstractNHSHandler
130+
neighborhood_searches::NHS
131+
end
132+
133+
function create_neighborhood_search_handler(::Type{<:VariableNHSHandler}, nhs, systems)
134+
# Find a list of search radii that will be requested for each system
135+
search_radii = Tuple([compact_support(system, neighbor) for neighbor in systems]
136+
for system in systems)
137+
search_radii = Tuple.(unique.(search_radii))
138+
139+
# For each system, create a neighborhood search for each unique search radius
140+
searches = Tuple(VariableSearchRadiusNHS(nhs,
141+
search_radii[system_indices(system,
142+
systems)],
143+
nparticles(system))
144+
for system in systems)
145+
146+
# For each system, create a neighborhood search for each unique search radius
147+
return VariableNHSHandler(searches)
148+
end
149+
150+
function get_neighborhood_search(handler::VariableNHSHandler, system_index, neighbor_index,
151+
search_radius)
152+
handler.neighborhood_searches[neighbor_index]
153+
end
154+
155+
struct VariableSearchRadiusNHS{SR, NHS} <: PointNeighbors.AbstractNeighborhoodSearch
128156
search_radii :: SR
129157
neighborhood_searches :: NHS
130158
end
131159

132160
function VariableSearchRadiusNHS(nhs_implementation, search_radii, n_particles)
133-
searches = Tuple(copy_neighborhood_search(nhs_implementation, search_radius, n_particles)
161+
searches = Tuple(PointNeighbors.copy_neighborhood_search(nhs_implementation,
162+
search_radius, n_particles)
134163
for search_radius in search_radii)
135-
return VariableRadiusNHS(search_radii, searches)
164+
return VariableSearchRadiusNHS(search_radii, searches)
136165
end
137166

138-
@inline Base.ndims(::TrivialNeighborhoodSearch{NDIMS}) where {NDIMS} = NDIMS
167+
@inline Base.ndims(search::VariableSearchRadiusNHS) = ndims(first(search.neighborhood_searches))
139168

140-
@inline requires_update(::TrivialNeighborhoodSearch) = (false, true)
169+
@inline requires_update(::VariableSearchRadiusNHS) = (false, true)
141170

142-
@inline function initialize!(search::TrivialNeighborhoodSearch, x, y;
143-
parallelization_backend = default_backend(x),
144-
eachindex_y = axes(y, 2))
171+
# TODO
172+
@inline function initialize!(search::VariableSearchRadiusNHS, x, y;
173+
parallelization_backend=default_backend(x),
174+
eachindex_y=axes(y, 2))
145175
return search
146176
end
147177

148-
@inline function update!(search::TrivialNeighborhoodSearch, x, y;
149-
points_moving = (true, true),
150-
parallelization_backend = default_backend(x),
151-
eachindex_y = axes(y, 2))
178+
# TODO
179+
@inline function update!(search::VariableSearchRadiusNHS, x, y;
180+
points_moving=(true, true),
181+
parallelization_backend=default_backend(x),
182+
eachindex_y=axes(y, 2))
152183
return search
153184
end
154185

155186
# Create a copy of a neighborhood search but with a different search radius
156-
function copy_neighborhood_search(nhs::TrivialNeighborhoodSearch, search_radius, x, y)
157-
return nhs
187+
function PointNeighbors.copy_neighborhood_search(nhs::VariableSearchRadiusNHS,
188+
search_radius, x, y)
189+
search_radii = Tuple(PointNeighbors.search_radius(search)
190+
for search in nhs.neighborhood_searches)
191+
searches = Tuple(PointNeighbors.copy_neighborhood_search(search)
192+
for search in nhs.neighborhood_searches)
193+
194+
return VariableSearchRadiusNHS(search_radii, searches)
158195
end
159196

160-
function copy_neighborhood_search(nhs::TrivialNeighborhoodSearch,
161-
search_radius, n_points; eachpoint = 1:n_points)
197+
function PointNeighbors.copy_neighborhood_search(nhs::VariableSearchRadiusNHS,
198+
search_radius, n_points;
199+
eachpoint=1:n_points)
162200
return TrivialNeighborhoodSearch{ndims(nhs)}(; search_radius, eachpoint,
163-
periodic_box = nhs.periodic_box)
201+
periodic_box=nhs.periodic_box)
164202
end
165203

166204
@inline function foreach_neighbor(f, neighbor_system_coords,
167-
neighborhood_search::VariableRadiusNHS,
205+
neighborhood_search::VariableSearchRadiusNHS,
168206
point, point_coords, search_radius)
169207
# Find the index of the search radius in the list of search radii for this system
170208
search_radii = neighborhood_search.search_radii
171-
idx = searchsortedfirst(search_radii, search_radius - eps(search_radius))
209+
idx = searchsortedfirst(SVector(search_radii), search_radius - eps(search_radius))
172210

173211
nhs = neighborhood_search.neighborhood_searches[idx]
174212

175-
PointNeighbors.foreach_neighbor(f, neighbor_system_coords, nhs, point, point_coords, search_radius)
213+
PointNeighbors.foreach_neighbor(f, neighbor_system_coords, nhs, point, point_coords,
214+
search_radius)
176215
end
177216

178-
struct VariableNHSHandler{NHS} <: AbstractNHSHandler
217+
struct GridNHSHandler{SI, NHS} <: AbstractNHSHandler
218+
search_radii::SI
179219
neighborhood_searches::NHS
180220
end
181221

182-
function create_neighborhood_search_handler(::Type{<:VariableNHSHandler}, nhs, systems)
222+
function create_neighborhood_search_handler(::Type{<:GridNHSHandler}, nhs, systems)
183223
# Find a list of search radii that will be requested for each system
184-
search_radii = Tuple([compact_support(system, neighbor) for neighbor in systems] for system in systems)
224+
search_radii = Tuple([compact_support(system, neighbor) for neighbor in systems]
225+
for system in systems)
185226
search_radii = Tuple.(unique.(search_radii))
186227

187228
# For each system, create a neighborhood search for each unique search radius
188-
return Tuple(VariableSearchRadiusNHS(nhs, search_radii[system_indices(system, systems)], nparticles(system))
189-
for system in systems)
229+
searches = Tuple(Tuple(copy_neighborhood_search(nhs, search_radius, nparticles(system))
230+
for search_radius in
231+
search_radii[system_indices(system, systems)])
232+
for system in systems)
233+
234+
return GridNHSHandler(search_radii, searches)
190235
end
191236

192-
function get_neighborhood_search(handler::VariableNHSHandler, system_index, neighbor_index,
237+
function get_neighborhood_search(handler::GridNHSHandler, system_index, neighbor_index,
193238
search_radius)
194-
handler.neighborhood_searches[neighbor_index]
239+
# Find the index of the search radius in the list of search radii for this system
240+
search_radii = handler.search_radii[neighbor_index]
241+
idx = searchsortedfirst(SVector(search_radii), search_radius - eps(search_radius))
242+
# Return the neighborhood search at that index
243+
return handler.neighborhood_searches[neighbor_index][idx]
195244
end
245+
# ==================================
196246

197247
@inline function get_neighborhood_search(system, semi)
198248
return get_neighborhood_search(system, system, semi)
@@ -205,20 +255,21 @@ end
205255
end
206256

207257
@inline function get_neighborhood_search(system, neighbor_system, semi)
208-
(; neighborhood_searches) = semi
258+
(; neighborhood_search_handler) = semi
209259

210260
system_index = system_indices(system, semi)
211261
neighbor_index = system_indices(neighbor_system, semi)
212262

213263
search_radius = compact_support(system, neighbor_system)
214264

215-
return get_neighborhood_search(neighborhood_searches, system_index, neighbor_index,
265+
return get_neighborhood_search(neighborhood_search_handler, system_index,
266+
neighbor_index,
216267
search_radius)
217268
end
218269

219270
@inline function get_neighborhood_search(system::TotalLagrangianSPHSystem,
220271
neighbor_system::TotalLagrangianSPHSystem, semi)
221-
(; neighborhood_searches) = semi
272+
(; neighborhood_search_handler) = semi
222273

223274
system_index = system_indices(system, semi)
224275
neighbor_index = system_indices(neighbor_system, semi)
@@ -231,7 +282,8 @@ end
231282

232283
search_radius = compact_support(system, neighbor_system)
233284

234-
return get_neighborhood_search(neighborhood_searches, system_index, neighbor_index,
285+
return get_neighborhood_search(neighborhood_search_handler, system_index,
286+
neighbor_index,
235287
search_radius)
236288
end
237289

src/general/semidiscretization.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,34 @@ semi = Semidiscretization(fluid_system, boundary_system,
4949
└──────────────────────────────────────────────────────────────────────────────────────────────────┘
5050
```
5151
"""
52-
struct Semidiscretization{BACKEND, S, RU, RV, NS, UCU, IT}
53-
systems :: S
54-
ranges_u :: RU
55-
ranges_v :: RV
56-
neighborhood_searches :: NS
57-
parallelization_backend :: BACKEND
58-
update_callback_used :: UCU
59-
integrate_tlsph :: IT # `false` if TLSPH integration is decoupled
52+
struct Semidiscretization{BACKEND, S, RU, RV, NSH, UCU, IT}
53+
systems :: S
54+
ranges_u :: RU
55+
ranges_v :: RV
56+
neighborhood_search_handler :: NSH
57+
parallelization_backend :: BACKEND
58+
update_callback_used :: UCU
59+
integrate_tlsph :: IT # `false` if TLSPH integration is decoupled
6060

6161
# Dispatch at `systems` to distinguish this constructor from the one below when
6262
# 4 systems are passed.
6363
# This is an internal constructor only used in `test/count_allocations.jl`
6464
# and by Adapt.jl.
65-
function Semidiscretization(systems::Tuple, ranges_u, ranges_v, neighborhood_searches,
65+
function Semidiscretization(systems::Tuple, ranges_u, ranges_v,
66+
neighborhood_search_handler,
6667
parallelization_backend::PointNeighbors.ParallelizationBackend,
6768
update_callback_used, integrate_tlsph)
6869
new{typeof(parallelization_backend), typeof(systems), typeof(ranges_u),
69-
typeof(ranges_v), typeof(neighborhood_searches),
70+
typeof(ranges_v), typeof(neighborhood_search_handler),
7071
typeof(update_callback_used),
7172
typeof(integrate_tlsph)}(systems, ranges_u, ranges_v,
72-
neighborhood_searches, parallelization_backend,
73+
neighborhood_search_handler, parallelization_backend,
7374
update_callback_used, integrate_tlsph)
7475
end
7576
end
7677

7778
function Semidiscretization(systems::Union{AbstractSystem, Nothing}...;
79+
nhs_handler=GridNHSHandler,
7880
neighborhood_search=GridNeighborhoodSearch{ndims(first(systems))}(),
7981
parallelization_backend=PolyesterBackend())
8082
systems = filter(system -> !isnothing(system), systems)
@@ -98,7 +100,9 @@ function Semidiscretization(systems::Union{AbstractSystem, Nothing}...;
98100
ranges_v = Tuple((sum(sizes_v[1:(i - 1)]) + 1):sum(sizes_v[1:i])
99101
for i in eachindex(sizes_v))
100102

101-
searches = create_neighborhood_search_handler(PairsNHSHandler, neighborhood_search, systems)
103+
neighborhood_search_handler = create_neighborhood_search_handler(nhs_handler,
104+
neighborhood_search,
105+
systems)
102106

103107
# These will be set to true inside the `UpdateCallback`.
104108
# Some techniques require the use of this callback, and this flag can be used
@@ -110,7 +114,8 @@ function Semidiscretization(systems::Union{AbstractSystem, Nothing}...;
110114
# with this set to false.
111115
integrate_tlsph = Ref(true)
112116

113-
return Semidiscretization(systems, ranges_u, ranges_v, searches,
117+
return Semidiscretization(systems, ranges_u, ranges_v,
118+
neighborhood_search_handler,
114119
parallelization_backend, update_callback_used,
115120
integrate_tlsph)
116121
end
@@ -139,15 +144,16 @@ function Base.show(io::IO, ::MIME"text/plain", semi::Semidiscretization)
139144
summary_line(io, "#spatial dimensions", ndims(semi.systems[1]))
140145
summary_line(io, "#systems", length(semi.systems))
141146
summary_line(io, "neighborhood search",
142-
semi.neighborhood_searches |> eltype |> eltype |> nameof)
147+
semi.neighborhood_search_handler |> eltype |> eltype |> nameof)
143148
summary_line(io, "total #particles", sum(nparticles.(semi.systems)))
144149
summary_line(io, "eltype", eltype(semi.systems[1]))
145150
summary_line(io, "coordinates eltype", coordinates_eltype(semi.systems[1]))
146151
summary_footer(io)
147152
end
148153
end
149154

150-
@inline system_indices(system, semi::Semidiscretization) = system_indices(system, semi.systems)
155+
@inline system_indices(system,
156+
semi::Semidiscretization) = system_indices(system, semi.systems)
151157

152158
@inline function system_indices(system, systems)
153159
# Note that this takes only about 5 ns, while mapping systems to indices with a `Dict`

0 commit comments

Comments
 (0)