Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,36 @@ inrange!(idxs, balltree, point, r)
neighborscount = inrangecount(balltree, point, r)
```

### Self-Pair Searches

Find all pairs of points within a tree that are within a given radius of each other:

```julia
inrange_pairs(tree, radius, sortres) -> pairs
```

* `tree`: The tree instance (KDTree, BallTree, or BruteTree).
* `radius`: Search radius.
* `sortres` (optional): Sort the result pairs (default: false).


Returns a vector of tuples `(i, j)` where `i < j` representing pairs of point indices within the radius.

Example:

```julia
using NearestNeighbors
data = rand(3, 100)
kdtree = KDTree(data)

# Find all pairs within radius 0.1
pairs = inrange_pairs(kdtree, 0.1)

# pairs might look like:
# [(1, 5), (2, 47), (3, 89), ...]
# Each tuple (i,j) means points i and j are within distance 0.1
```

## Periodic Boundary Conditions

The `PeriodicTree` provides nearest neighbor searches with periodic boundary conditions. It reuses an internal deduplication buffer, so the same `PeriodicTree` instance should not be queried concurrently from multiple threads without external synchronization.
Expand Down
11 changes: 7 additions & 4 deletions src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using StaticArrays: StaticArrays, MVector, SVector
using Base: setindex

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree
export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs
export knn, knn!, nn, inrange, inrange!, inrangecount, inrange_pairs # TODOs?, npairs
export injectdata

export Euclidean,
Expand Down Expand Up @@ -64,9 +64,12 @@ include("knn.jl")
include("inrange.jl")

for dim in (2, 3)
tree = KDTree(rand(dim, 10))
knn(tree, rand(dim), 5)
inrange(tree, rand(dim), 0.5)
for Tree in (KDTree, BallTree)
tree = Tree(rand(dim, 10))
knn(tree, rand(dim), 5)
inrange(tree, rand(dim), 0.5)
inrange_pairs(tree, 0.5)
end
end

end # module
160 changes: 160 additions & 0 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,163 @@ function inrange_kernel!(tree::BallTree,
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, dedup)
end
end

# Add every pair from two subtrees without distance checks once their bounds are fully inside the radius
function _addall_balltree_self!(results::Vector{NTuple{2,Int}}, tree::BallTree, idx::Int, other_idx::Int, skip)
leaf_here = isleaf(tree.tree_data.n_internal_nodes, idx)
leaf_other = isleaf(tree.tree_data.n_internal_nodes, other_idx)
if leaf_here
if leaf_other
_addall_balltree_self_leaf_pairs!(results, tree, idx, other_idx, skip)
else
_addall_balltree_self!(results, tree, idx, getleft(other_idx), skip)
_addall_balltree_self!(results, tree, idx, getright(other_idx), skip)
end
else
if leaf_other
_addall_balltree_self!(results, tree, getleft(idx), other_idx, skip)
_addall_balltree_self!(results, tree, getright(idx), other_idx, skip)
else
_addall_balltree_self!(results, tree, getleft(idx), getleft(other_idx), skip)
_addall_balltree_self!(results, tree, getleft(idx), getright(other_idx), skip)
if idx == other_idx
_addall_balltree_self!(results, tree, getright(idx), getright(other_idx), skip)
else
_addall_balltree_self!(results, tree, getright(idx), getleft(other_idx), skip)
_addall_balltree_self!(results, tree, getright(idx), getright(other_idx), skip)
end
end
end
return
end

# Add all pairs between two leaf nodes when every combination is known to be within the radius
function _addall_balltree_self_leaf_pairs!(results::Vector{NTuple{2,Int}}, tree::BallTree, leaf_idx::Int, other_leaf_idx::Int, skip)
point_range = get_leaf_range(tree.tree_data, leaf_idx)
if leaf_idx == other_leaf_idx
@inbounds for i in point_range
idx_i = tree.indices[i]
skip(idx_i) && continue
for j in (i + 1):last(point_range)
idx_j = tree.indices[j]
if skip(idx_j)
continue
end
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
push!(results, (a, b))
end
end
else
query_range = get_leaf_range(tree.tree_data, other_leaf_idx)
@inbounds for i in point_range
idx_i = tree.indices[i]
skip(idx_i) && continue
for j in query_range
idx_j = tree.indices[j]
if skip(idx_j)
continue
end
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
push!(results, (a, b))
end
end
end
return
end

# Add only the leaf pairs that satisfy the radius when bounds overlap but are not fully enclosed
function _add_balltree_self_leaf_pairs!(results::Vector{NTuple{2,Int}}, tree::BallTree, leaf_idx::Int, other_leaf_idx::Int, r::Number, skip)
point_range = get_leaf_range(tree.tree_data, leaf_idx)
is_minkowski = tree.metric isa MinkowskiMetric
if leaf_idx == other_leaf_idx
@inbounds for i in point_range
idx_i = tree.indices[i]
skip(idx_i) && continue
point_i = tree.data[tree.reordered ? i : tree.indices[i]]
for j in (i + 1):last(point_range)
idx_j = tree.indices[j]
if skip(idx_j)
continue
end
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
push!(results, (a, b))
end
end
end
else
query_range = get_leaf_range(tree.tree_data, other_leaf_idx)
@inbounds for i in point_range
idx_i = tree.indices[i]
skip(idx_i) && continue
point_i = tree.data[tree.reordered ? i : tree.indices[i]]
for j in query_range
idx_j = tree.indices[j]
if skip(idx_j)
continue
end
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
push!(results, (a, b))
end
end
end
end
return
end

function _inrange_balltree_self!(results::Vector{NTuple{2,Int}},
tree::BallTree,
idx::Int,
other_idx::Int,
r::Number,
skip::F) where {F}
if idx > other_idx
idx, other_idx = other_idx, idx
end

sphere = tree.hyper_spheres[idx]
other_sphere = tree.hyper_spheres[other_idx]
min_d, max_d = get_min_max_distance(tree.metric, sphere, other_sphere)
if min_d > r
return
elseif max_d < r
_addall_balltree_self!(results, tree, idx, other_idx, skip)
return
end

leaf_here = isleaf(tree.tree_data.n_internal_nodes, idx)
leaf_other = isleaf(tree.tree_data.n_internal_nodes, other_idx)
if leaf_here
if leaf_other
_add_balltree_self_leaf_pairs!(results, tree, idx, other_idx, r, skip)
else
_inrange_balltree_self!(results, tree, idx, getleft(other_idx), r, skip)
_inrange_balltree_self!(results, tree, idx, getright(other_idx), r, skip)
end
else
if leaf_other
_inrange_balltree_self!(results, tree, getleft(idx), other_idx, r, skip)
_inrange_balltree_self!(results, tree, getright(idx), other_idx, r, skip)
else
_inrange_balltree_self!(results, tree, getleft(idx), getleft(other_idx), r, skip)
_inrange_balltree_self!(results, tree, getleft(idx), getright(other_idx), r, skip)
if idx == other_idx
_inrange_balltree_self!(results, tree, getright(idx), getright(other_idx), r, skip)
else
_inrange_balltree_self!(results, tree, getright(idx), getleft(other_idx), r, skip)
_inrange_balltree_self!(results, tree, getright(idx), getright(other_idx), r, skip)
end
end
end
return
end

function _inrange_pairs(tree::BallTree{V}, radius::Number, sortres, skip::F) where {V, F}
isempty(tree.data) && return NTuple{2,Int}[]
pairs = NTuple{2,Int}[]
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, radius) : radius
_inrange_balltree_self!(pairs, tree, 1, 1, r, skip)
sortres && sort!(pairs)
return pairs
end
16 changes: 16 additions & 0 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,19 @@ function inrange_kernel!(tree::BruteTree,
end
return count
end

function _inrange_pairs(tree::BruteTree{V}, radius::Number, sortres, skip::F) where {V, F}
pairs = NTuple{2,Int}[]
for i in 1:length(tree.data)
skip(i) && continue
for j in (i + 1):length(tree.data)
skip(j) && continue
d = evaluate(tree.metric, tree.data[i], tree.data[j])
if d <= radius
push!(pairs, (i, j))
end
end
end
sortres && sort!(pairs)
return pairs
end
25 changes: 25 additions & 0 deletions src/hyperrectangles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ get_min_distance_no_end(m, rec, point) =
return old_min + diff_tot
end

# Compute min and max possible distances between two hyper rectangles
function get_min_max_distance(m::Metric, r1::HyperRectangle{V}, r2::HyperRectangle{V}) where {V}
p = Distances.parameters(m)
T = eltype(V)
min_acc = zero(T)
max_acc = zero(T)
@inbounds for dim in eachindex(r1.mins)
lo1 = r1.mins[dim]; hi1 = r1.maxes[dim]
lo2 = r2.mins[dim]; hi2 = r2.maxes[dim]
min_raw = if hi1 < lo2
lo2 - hi1
elseif hi2 < lo1
lo1 - hi2
else
zero(T)
end
max_raw = max(abs(hi1 - lo2), abs(hi2 - lo1))
min_op = p === nothing ? eval_op(m, min_raw, zero(T)) : eval_op(m, min_raw, zero(T), p[dim])
max_op = p === nothing ? eval_op(m, max_raw, zero(T)) : eval_op(m, max_raw, zero(T), p[dim])
min_acc = eval_reduce(m, min_acc, min_op)
max_acc = eval_reduce(m, max_acc, max_op)
end
return min_acc, max_acc
end

# Compute per-dimension contributions for max distance
function get_max_distance_contributions(m::Metric, rec::HyperRectangle{V}, point::AbstractVector{T}) where {V,T}
p = Distances.parameters(m)
Expand Down
7 changes: 7 additions & 0 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,10 @@ function distance_to_sphere(metric::Metric, point, sphere::HyperSphere)
dist = evaluate(metric, point, sphere.center) - sphere.r
return max(zero(eltype(dist)), dist)
end

@inline function get_min_max_distance(m::Metric, s1::HyperSphere, s2::HyperSphere)
dist = evaluate(m, s1.center, s2.center)
min_d = max(zero(dist), dist - (s1.r + s2.r))
max_d = dist + s1.r + s2.r
return min_d, max_d
end
5 changes: 5 additions & 0 deletions src/inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,8 @@ function inrangecount_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius:
end
return counts
end

function inrange_pairs(tree::NNTree, radius::Number, sortres=false, skip::F=Returns(false)) where {F}
check_radius(radius)
return _inrange_pairs(tree, radius, sortres, skip)
end
Loading