Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/PointNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export foreach_point_neighbor, foreach_neighbor
export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch
export DictionaryCellList, FullGridCellList
export ParallelUpdate, SemiParallelUpdate, SerialUpdate
export initialize!, update!, initialize_grid!, update_grid!
export requires_update, initialize!, update!, initialize_grid!, update_grid!
export PolyesterBackend, ThreadsDynamicBackend, ThreadsStaticBackend
export PeriodicBox, copy_neighborhood_search

Expand Down
37 changes: 33 additions & 4 deletions src/neighborhood_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@ abstract type AbstractNeighborhoodSearch end

@inline search_radius(search::AbstractNeighborhoodSearch) = search.search_radius

"""
    requires_update(search::AbstractNeighborhoodSearch)

Return a tuple `(x_changed, y_changed)` of `Bool`s. The first entry indicates whether this
type of neighborhood search requires an update when the coordinates of the points in `x`
change, the second entry indicates the same for the coordinates of the points in `y`.
"""
requires_update(::AbstractNeighborhoodSearch) =
    error("`requires_update` not implemented for this neighborhood search.")

"""
initialize!(search::AbstractNeighborhoodSearch, x, y)

Expand Down Expand Up @@ -206,13 +216,32 @@ end
return nothing
end

@inline function foreach_neighbor(f, system_coords, neighbor_system_coords,
neighborhood_search, point;
search_radius = search_radius(neighborhood_search))
@propagate_inbounds function foreach_neighbor(
        f, system_coords, neighbor_system_coords,
        neighborhood_search::AbstractNeighborhoodSearch, point;
        search_radius = search_radius(neighborhood_search))
    # This method has a kwarg, so a caller cannot strip a `@boundscheck` simply by
    # wrapping the call in `@inbounds`
    # (see https://github.com/JuliaLang/julia/issues/30411).
    # `@propagate_inbounds` is used instead, but it would also drop the bounds checks
    # inside the neighbor loop, where they are needed. The loop therefore lives in
    # a separate positional-argument method, which acts as a function barrier and
    # switches the propagated `@inbounds` off again.
    dims = Val(ndims(neighborhood_search))
    point_coords = extract_svector(system_coords, dims, point)

    return foreach_neighbor(f, neighbor_system_coords, neighborhood_search,
                            point, point_coords, search_radius)
end

# This is the generic function that is called for `TrivialNeighborhoodSearch`.
# For `GridNeighborhoodSearch`, a specialized function is used for slightly better
# performance. `PrecomputedNeighborhoodSearch` can skip the distance check altogether.
@inline function foreach_neighbor(f, neighbor_system_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point, point_coords, search_radius)
(; periodic_box) = neighborhood_search

point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
for neighbor in eachneighbor(point_coords, neighborhood_search)
# Making the following `@inbounds` yields a ~2% speedup on an NVIDIA H100.
# But we don't know if `neighbor` (extracted from the cell list) is in bounds.
neighbor_coords = extract_svector(neighbor_system_coords,
Val(ndims(neighborhood_search)), neighbor)

Expand Down
32 changes: 10 additions & 22 deletions src/nhs_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ function GridNeighborhoodSearch{NDIMS}(; search_radius = 0.0, n_points = 0,
cell_size, update_buffer, update_strategy)
end

# The number of dimensions is read from the `NDIMS` type parameter.
@inline function Base.ndims(::GridNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end

# In the `(x_changed, y_changed)` convention of `requires_update`: no update is required
# when the coordinates in `x` change, but one is required when those in `y` change.
@inline function requires_update(::GridNeighborhoodSearch)
    return (false, true)
end

"""
ParallelUpdate()

Expand Down Expand Up @@ -158,8 +162,6 @@ end
push!(update_buffer, index_type(cell_list)[])
end

# Number of dimensions of the search, taken from the `NDIMS` type parameter.
@inline Base.ndims(::GridNeighborhoodSearch{NDIMS}) where {NDIMS} = NDIMS

function initialize!(neighborhood_search::GridNeighborhoodSearch,
x::AbstractMatrix, y::AbstractMatrix)
initialize_grid!(neighborhood_search, y)
Expand Down Expand Up @@ -355,24 +357,11 @@ function update_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any, Paralle
return neighborhood_search
end

# Keyword-argument entry point of the neighbor loop for `GridNeighborhoodSearch`.
# Extracts the coordinates of `point` from `system_coords` and forwards everything to
# `__foreach_neighbor`. `search_radius` defaults to the radius of `neighborhood_search`.
@propagate_inbounds function foreach_neighbor(f, system_coords, neighbor_system_coords,
neighborhood_search::GridNeighborhoodSearch,
point;
search_radius = search_radius(neighborhood_search))
# Due to https://github.com/JuliaLang/julia/issues/30411, we cannot just remove
# a `@boundscheck` by calling this function with `@inbounds` because it has a kwarg.
# We have to use `@propagate_inbounds`, which will also remove boundschecks
# in the neighbor loop, which is not safe (see comment below).
# To avoid this, we have to use a function barrier to disable the `@inbounds` again.
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)

# Function barrier: the actual neighbor loop runs in `__foreach_neighbor`.
__foreach_neighbor(f, system_coords, neighbor_system_coords, neighborhood_search,
point, point_coords, search_radius)
end

@inline function __foreach_neighbor(f, system_coords, neighbor_system_coords,
neighborhood_search::GridNeighborhoodSearch,
point, point_coords, search_radius)
# Specialized version of the function in `neighborhood_search.jl`, which is faster
# than looping over `eachneighbor`.
@inline function foreach_neighbor(f, neighbor_system_coords,
neighborhood_search::GridNeighborhoodSearch,
point, point_coords, search_radius)
(; periodic_box) = neighborhood_search

cell = cell_coords(point_coords, neighborhood_search)
Expand All @@ -393,8 +382,7 @@ end
distance2 = dot(pos_diff, pos_diff)

pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2,
search_radius,
periodic_box)
search_radius, periodic_box)

if distance2 <= search_radius^2
distance = sqrt(distance2)
Expand Down
15 changes: 10 additions & 5 deletions src/nhs_precomputed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ end

# The number of dimensions is read from the `NDIMS` type parameter.
@inline function Base.ndims(::PrecomputedNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end

# In the `(x_changed, y_changed)` convention of `requires_update`: an update is required
# when either the coordinates in `x` or those in `y` change.
@inline function requires_update(::PrecomputedNeighborhoodSearch)
    return (true, true)
end

@inline function search_radius(search::PrecomputedNeighborhoodSearch)
    # The radius is not read from this search directly; it is taken from the
    # neighborhood search stored in the `neighborhood_search` field.
    inner_search = search.neighborhood_search

    return search_radius(inner_search)
end
Expand Down Expand Up @@ -92,14 +94,17 @@ function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y)
end
end

@inline function foreach_neighbor(f, system_coords, neighbor_system_coords,
@inline function foreach_neighbor(f, neighbor_system_coords,
neighborhood_search::PrecomputedNeighborhoodSearch,
point; search_radius = nothing)
point, point_coords, search_radius)
(; periodic_box, neighbor_lists) = neighborhood_search
(; search_radius) = neighborhood_search.neighborhood_search

point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
for neighbor in neighbor_lists[point]
neighbors = @inbounds neighbor_lists[point]
for neighbor_ in eachindex(neighbors)
neighbor = @inbounds neighbors[neighbor_]

# Making the following `@inbounds` yields a ~2% speedup on an NVIDIA H100.
# But we don't know if `neighbor` (read from the precomputed neighbor list) is in bounds.
neighbor_coords = extract_svector(neighbor_system_coords,
Val(ndims(neighborhood_search)), neighbor)

Expand Down
2 changes: 2 additions & 0 deletions src/nhs_trivial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ end

# The number of dimensions is read from the `NDIMS` type parameter.
@inline function Base.ndims(::TrivialNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end

# In the `(x_changed, y_changed)` convention of `requires_update`: no update is required,
# neither when the coordinates in `x` nor when those in `y` change.
@inline function requires_update(::TrivialNeighborhoodSearch)
    return (false, false)
end

# Nothing is set up here: `x` and `y` are ignored and `search` is returned unchanged.
@inline function initialize!(search::TrivialNeighborhoodSearch, x, y)
    return search
end

@inline function update!(search::TrivialNeighborhoodSearch, x, y;
Expand Down
Loading