Skip to content

Commit d8b0de2

Browse files
authored
Implement sorting of neighbor lists for PrecomputedNeighborhoodSearch (#138)
* Implement sorting of neighbor lists for `PrecomputedNeighborhoodSearch` * Fix tests * Fix * Fix `freeze_neighborhood_search` * Add comments * Fix
1 parent beea9ae commit d8b0de2

File tree

6 files changed

+135
-56
lines changed

6 files changed

+135
-56
lines changed

src/PointNeighbors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using LinearAlgebra: dot
1111
using Polyester: Polyester
1212
@reexport using StaticArrays: SVector
1313

14-
include("vector_of_vectors.jl")
1514
include("util.jl")
15+
include("vector_of_vectors.jl")
1616
include("neighborhood_search.jl")
1717
include("nhs_trivial.jl")
1818
include("cell_lists/cell_lists.jl")

src/gpu.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ function Adapt.adapt_structure(to, nhs::PrecomputedNeighborhoodSearch)
3030
neighborhood_search = Adapt.adapt_structure(to, nhs.neighborhood_search)
3131

3232
return PrecomputedNeighborhoodSearch{ndims(nhs)}(neighbor_lists, search_radius,
33-
periodic_box, neighborhood_search)
33+
periodic_box, neighborhood_search,
34+
nhs.sort_neighbor_lists)
3435
end
3536

3637
function Adapt.adapt_structure(to, cell_list::SpatialHashingCellList{NDIMS}) where {NDIMS}

src/nhs_precomputed.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
periodic_box = nothing, update_strategy = nothing,
44
update_neighborhood_search = GridNeighborhoodSearch{NDIMS}(),
55
backend = DynamicVectorOfVectors{Int32},
6-
max_neighbors = max_neighbors(NDIMS))
6+
max_neighbors = max_neighbors(NDIMS),
7+
sort_neighbor_lists = true)
78
89
Neighborhood search with precomputed neighbor lists. A list of all neighbors is computed
910
for each point during initialization and update.
@@ -44,23 +45,29 @@ to strip the internal neighborhood search, which is not needed anymore.
4445
- `max_neighbors`: Maximum number of neighbors per particle. This will be used to
4546
allocate the `DynamicVectorOfVectors`. It is not used with
4647
other backends. The default is 64 in 2D and 324 in 3D.
48+
- `sort_neighbor_lists = true`: Whether to sort the neighbor lists after construction.
49+
This can improve cache hits on CPUs and improve coalesced
50+
memory access on GPUs.
4751
"""
4852
struct PrecomputedNeighborhoodSearch{NDIMS, NL, ELTYPE, PB, NHS} <:
4953
AbstractNeighborhoodSearch
5054
neighbor_lists :: NL
5155
search_radius :: ELTYPE
5256
periodic_box :: PB
5357
neighborhood_search :: NHS
58+
sort_neighbor_lists :: Bool
5459

5560
function PrecomputedNeighborhoodSearch{NDIMS}(neighbor_lists, search_radius,
5661
periodic_box,
57-
update_neighborhood_search) where {NDIMS}
62+
update_neighborhood_search,
63+
sort_neighbor_lists) where {NDIMS}
5864
return new{NDIMS, typeof(neighbor_lists),
5965
typeof(search_radius),
6066
typeof(periodic_box),
6167
typeof(update_neighborhood_search)}(neighbor_lists, search_radius,
6268
periodic_box,
63-
update_neighborhood_search)
69+
update_neighborhood_search,
70+
sort_neighbor_lists)
6471
end
6572
end
6673

@@ -73,11 +80,13 @@ function PrecomputedNeighborhoodSearch{NDIMS}(; search_radius = 0.0, n_points =
7380
periodic_box,
7481
update_strategy),
7582
backend = DynamicVectorOfVectors{Int32},
76-
max_neighbors = max_neighbors(NDIMS)) where {NDIMS}
83+
max_neighbors = max_neighbors(NDIMS),
84+
sort_neighbor_lists = true) where {NDIMS}
7785
neighbor_lists = construct_backend(backend, n_points, max_neighbors)
7886

7987
PrecomputedNeighborhoodSearch{NDIMS}(neighbor_lists, search_radius,
80-
periodic_box, update_neighborhood_search)
88+
periodic_box, update_neighborhood_search,
89+
sort_neighbor_lists)
8190
end
8291

8392
# Default values for maximum neighbor count
@@ -111,7 +120,7 @@ function initialize!(search::PrecomputedNeighborhoodSearch,
111120
initialize!(neighborhood_search, x, y; parallelization_backend)
112121

113122
initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
114-
parallelization_backend)
123+
parallelization_backend, search.sort_neighbor_lists)
115124

116125
return search
117126
end
@@ -132,14 +141,14 @@ function update!(search::PrecomputedNeighborhoodSearch,
132141
# Skip update if both point sets are static
133142
if any(points_moving)
134143
initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
135-
parallelization_backend)
144+
parallelization_backend, search.sort_neighbor_lists)
136145
end
137146

138147
return search
139148
end
140149

141150
function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
142-
parallelization_backend)
151+
parallelization_backend, sort_neighbor_lists)
143152
# Initialize neighbor lists
144153
empty!(neighbor_lists)
145154
resize!(neighbor_lists, size(x, 2))
@@ -155,7 +164,8 @@ function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
155164
end
156165

157166
function initialize_neighbor_lists!(neighbor_lists::DynamicVectorOfVectors,
158-
neighborhood_search, x, y, parallelization_backend)
167+
neighborhood_search, x, y, parallelization_backend,
168+
sort_neighbor_lists)
159169
resize!(neighbor_lists, size(x, 2))
160170

161171
# `Base.empty!.(neighbor_lists)`, but for all backends
@@ -168,6 +178,10 @@ function initialize_neighbor_lists!(neighbor_lists::DynamicVectorOfVectors,
168178
parallelization_backend) do point, neighbor, _, _
169179
pushat!(neighbor_lists, point, neighbor)
170180
end
181+
182+
if sort_neighbor_lists
183+
sorteach!(neighbor_lists)
184+
end
171185
end
172186

173187
@inline function foreach_neighbor(f, neighbor_system_coords,
@@ -225,5 +239,6 @@ end
225239
return PrecomputedNeighborhoodSearch{ndims(search)}(search.neighbor_lists,
226240
search.search_radius,
227241
search.periodic_box,
228-
nothing)
242+
nothing,
243+
search.sort_neighbor_lists)
229244
end

src/util.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ For GPU arrays, the respective `KernelAbstractions.Backend` is returned.
8080
"""
8181
@inline default_backend(::AbstractArray) = PolyesterBackend()
8282
@inline default_backend(x::AbstractGPUArray) = KernelAbstractions.get_backend(x)
83-
@inline default_backend(x::DynamicVectorOfVectors) = default_backend(x.backend)
83+
@inline default_backend(x::PermutedDimsArray) = default_backend(x.parent)
8484

8585
"""
8686
@threaded backend for ... end

src/vector_of_vectors.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ end
2323

2424
@inline Base.size(vov::DynamicVectorOfVectors) = (vov.length_[],)
2525

26+
@inline default_backend(x::DynamicVectorOfVectors) = default_backend(x.backend)
27+
2628
@inline function Base.getindex(vov::DynamicVectorOfVectors, i)
2729
(; backend, lengths) = vov
2830

@@ -162,6 +164,40 @@ end
162164
return vov
163165
end
164166

167+
# Sort each inner vector
168+
@inline function sorteach!(vov::DynamicVectorOfVectors)
169+
# TODO remove this check when Metal supports sorting
170+
if nameof(typeof(default_backend(vov.backend))) == :MetalBackend
171+
@warn "sorting neighbor lists is not supported on Metal. Skipping sort."
172+
return vov
173+
end
174+
175+
# Note that we cannot just do `sort!(vov[i])` on GPUs because that would call `sort!`
176+
# from within a GPU kernel, but this function is not GPU-compatible.
177+
# We might be able to use a sorting function from AcceleratedKernels.jl,
178+
# but for now the following workaround should be sufficient.
179+
180+
# Set all unused entries to `typemax` so that they are sorted to the end
181+
@threaded default_backend(vov.backend) for i in axes(vov.backend, 2)
182+
for j in (vov.lengths[i] + 1):size(vov.backend, 1)
183+
@inbounds vov.backend[j, i] = typemax(eltype(vov.backend))
184+
end
185+
end
186+
187+
# Now we can sort full columns.
188+
# Note that this will forward to highly optimized sorting functions on GPUs.
189+
# It currently does not work on Metal.
190+
sort!(vov.backend, dims = 1)
191+
192+
return vov
193+
end
194+
195+
@inline function sorteach!(vov::Vector{<:Vector{T}}) where {T}
196+
sort!.(vov)
197+
198+
return vov
199+
end
200+
165201
function max_inner_length(cells::DynamicVectorOfVectors, fallback)
166202
return size(cells.backend, 1)
167203
end

test/vector_of_vectors.jl

Lines changed: 70 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
vov = PointNeighbors.DynamicVectorOfVectors{ELTYPE}(max_outer_length = 20,
99
max_inner_length = 4)
1010

11-
# Test internal size
12-
@test size(vov.backend) == (4, 20)
13-
1411
function verify(vov, vov_ref)
1512
@test length(vov) == length(vov_ref)
1613
@test eachindex(vov) == eachindex(vov_ref)
@@ -24,62 +21,92 @@
2421
end
2522
end
2623

27-
# Initial check
28-
verify(vov, vov_ref)
24+
@testset "Initial state" begin
25+
# Test internal size
26+
@test size(vov.backend) == (4, 20)
27+
28+
# Initial check
29+
verify(vov, vov_ref)
30+
end
31+
32+
@testset "First push!" begin
33+
push!(vov_ref, type.([1, 3, 2]))
34+
push!(vov, type.([1, 3, 2]))
35+
36+
verify(vov, vov_ref)
37+
end
38+
39+
@testset "push! multiple items" begin
40+
push!(vov_ref, type.([4]), type.([5, 6, 7, 8]))
41+
push!(vov, type.([4]), type.([5, 6, 7, 8]))
42+
43+
verify(vov, vov_ref)
44+
end
45+
46+
@testset "push! to an inner vector" begin
47+
push!(vov_ref[1], type(12))
48+
PointNeighbors.pushat!(vov, 1, type(12))
2949

30-
# First `push!`
31-
push!(vov_ref, type.([1, 2, 3]))
32-
push!(vov, type.([1, 2, 3]))
50+
verify(vov, vov_ref)
51+
end
3352

34-
verify(vov, vov_ref)
53+
@testset "push! overflow" begin
54+
error_ = ErrorException("cell list is full. Use a larger `max_points_per_cell`.")
55+
@test_throws error_ PointNeighbors.pushat!(vov, 1, type(13))
3556

36-
# `push!` multiple items
37-
push!(vov_ref, type.([4]), type.([5, 6, 7, 8]))
38-
push!(vov, type.([4]), type.([5, 6, 7, 8]))
57+
verify(vov, vov_ref)
58+
end
3959

40-
verify(vov, vov_ref)
60+
@testset "deleteat!" begin
61+
# Delete entry of inner vector. Note that this changes the order of the elements.
62+
deleteat!(vov_ref[3], 2)
63+
PointNeighbors.deleteatat!(vov, 3, 2)
4164

42-
# `push!` to an inner vector
43-
push!(vov_ref[1], type(12))
44-
PointNeighbors.pushat!(vov, 1, type(12))
65+
@test vov_ref[3] == type.([5, 7, 8])
66+
@test vov[3] == type.([5, 8, 7])
4567

46-
# `push!` overflow
47-
error_ = ErrorException("cell list is full. Use a larger `max_points_per_cell`.")
48-
@test_throws error_ PointNeighbors.pushat!(vov, 1, type(13))
68+
# Delete second to last entry
69+
deleteat!(vov_ref[3], 2)
70+
PointNeighbors.deleteatat!(vov, 3, 2)
4971

50-
verify(vov, vov_ref)
72+
@test vov_ref[3] == type.([5, 8])
73+
@test vov[3] == type.([5, 7])
5174

52-
# Delete entry of inner vector. Note that this changes the order of the elements.
53-
deleteat!(vov_ref[3], 2)
54-
PointNeighbors.deleteatat!(vov, 3, 2)
75+
# Delete last entry
76+
deleteat!(vov_ref[3], 2)
77+
PointNeighbors.deleteatat!(vov, 3, 2)
5578

56-
@test vov_ref[3] == type.([5, 7, 8])
57-
@test vov[3] == type.([5, 8, 7])
79+
# Now they are identical again
80+
verify(vov, vov_ref)
5881

59-
# Delete second to last entry
60-
deleteat!(vov_ref[3], 2)
61-
PointNeighbors.deleteatat!(vov, 3, 2)
82+
# Delete the remaining entry of this vector
83+
deleteat!(vov_ref[3], 1)
84+
PointNeighbors.deleteatat!(vov, 3, 1)
6285

63-
@test vov_ref[3] == type.([5, 8])
64-
@test vov[3] == type.([5, 7])
86+
verify(vov, vov_ref)
87+
end
6588

66-
# Delete last entry
67-
deleteat!(vov_ref[3], 2)
68-
PointNeighbors.deleteatat!(vov, 3, 2)
89+
# Skip for Tuples
90+
if ELTYPE <: Number
91+
@testset "sorteach!" begin
92+
# Make sure that the first inner vector is unsorted.
93+
# If this fails, make sure the tests above don't yield a sorted vector.
94+
@test vov[1] != sort(vov[1])
6995

70-
# Now they are identical again
71-
verify(vov, vov_ref)
96+
PointNeighbors.sorteach!(vov)
97+
PointNeighbors.sorteach!(vov_ref)
7298

73-
# Delete the remaining entry of this vector
74-
deleteat!(vov_ref[3], 1)
75-
PointNeighbors.deleteatat!(vov, 3, 1)
99+
@test vov[1] == sort(vov[1])
76100

77-
verify(vov, vov_ref)
101+
verify(vov, vov_ref)
102+
end
103+
end
78104

79-
# `empty!`
80-
empty!(vov_ref)
81-
empty!(vov)
105+
@testset "empty!" begin
106+
empty!(vov_ref)
107+
empty!(vov)
82108

83-
verify(vov, vov_ref)
109+
verify(vov, vov_ref)
110+
end
84111
end
85112
end

0 commit comments

Comments
 (0)