diff --git a/src/neighborhood_search.jl b/src/neighborhood_search.jl index b3ba12ca..d080bce8 100644 --- a/src/neighborhood_search.jl +++ b/src/neighborhood_search.jl @@ -1,6 +1,7 @@ abstract type AbstractNeighborhoodSearch end @inline search_radius(search::AbstractNeighborhoodSearch) = search.search_radius +@inline Base.eltype(search::AbstractNeighborhoodSearch) = eltype(search_radius(search)) """ requires_update(search::AbstractNeighborhoodSearch) @@ -220,7 +221,7 @@ end neighbor_coords = extract_svector(neighbor_system_coords, Val(ndims(neighborhood_search)), neighbor) - pos_diff = point_coords - neighbor_coords + pos_diff = convert.(eltype(neighborhood_search), point_coords - neighbor_coords) distance2 = dot(pos_diff, pos_diff) pos_diff, diff --git a/src/nhs_grid.jl b/src/nhs_grid.jl index ef55284c..b1ee4faf 100644 --- a/src/nhs_grid.jl +++ b/src/nhs_grid.jl @@ -523,7 +523,7 @@ end neighbor_coords = extract_svector(neighbor_system_coords, Val(ndims(neighborhood_search)), neighbor) - pos_diff = point_coords - neighbor_coords + pos_diff = convert.(eltype(neighborhood_search), point_coords - neighbor_coords) distance2 = dot(pos_diff, pos_diff) pos_diff, diff --git a/src/nhs_precomputed.jl b/src/nhs_precomputed.jl index 69749f1c..60d35a4e 100644 --- a/src/nhs_precomputed.jl +++ b/src/nhs_precomputed.jl @@ -110,12 +110,15 @@ end for neighbor_ in eachindex(neighbors) neighbor = @inbounds neighbors[neighbor_] - # Making the following `@inbounds` yields a ~2% speedup on an NVIDIA H100. - # But we don't know if `neighbor` (extracted from the cell list) is in bounds. - neighbor_coords = extract_svector(neighbor_system_coords, - Val(ndims(neighborhood_search)), neighbor) - - pos_diff = point_coords - neighbor_coords + # Making this `@inbounds` is not perfectly safe because + # `neighbor` (extracted from the neighbor list) is only guaranteed to be in bounds + # if the neighbor lists were constructed correctly and have not been corrupted. + # However, adding this `@inbounds` yields a ~20% speedup for TLSPH on GPUs (A4500). + neighbor_coords = @inbounds extract_svector(neighbor_system_coords, + Val(ndims(neighborhood_search)), + neighbor) + + pos_diff = convert.(eltype(neighborhood_search), point_coords - neighbor_coords) distance2 = dot(pos_diff, pos_diff) pos_diff,