Skip to content

Commit 0bd448b

Browse files
authored
add allknn to find the knn of each point in a tree (#236)
1 parent c073298 commit 0bd448b

9 files changed

Lines changed: 166 additions & 33 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,20 @@ A kNN search finds the `k` nearest neighbors to a given point or points. This is
7272
```julia
7373
knn(tree, point[s], k [, skip=Returns(false)]) -> idxs, dists
7474
knn!(idxs, dists, tree, point, k [, skip=Returns(false)])
75+
allknn(tree, k [, skip=Returns(false)]) -> idxs, dists
7576
```
7677

7778
* `tree`: The tree instance.
7879
* `point[s]`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors.
7980
* `k`: Number of nearest neighbors to find.
8081
* `skip` (optional): A predicate function to skip certain points, e.g., points already visited.
81-
82+
* `allknn`: Finds the `k` nearest neighbors for every point stored in the tree itself, automatically excluding the point being queried.
8283

8384
For the single closest neighbor, you can use `nn`:
8485

8586
```julia
8687
nn(tree, point[s] [, skip=Returns(false)]) -> idx, dist
88+
allnn(tree [, skip=Returns(false)]) -> idxs, dists
8789
```
8890

8991
Examples:

src/NearestNeighbors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Base: setindex
77
using AbstractTrees: AbstractTrees
88

99
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree
10-
export knn, knn!, nn, inrange, inrange!, inrangecount, inrange_pairs # TODOs?, npairs
10+
export knn, knn!, nn, allnn, allknn, inrange, inrange!, inrangecount, inrange_pairs # TODOs?, npairs
1111
export injectdata
1212
export TreeNode, treeroot, leafpoints, leaf_point_indices, treeregion
1313
export preorder, postorder, leaves

src/ball_tree.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ function _knn(tree::BallTree,
167167
best_idxs::Union{Integer, AbstractVector{<:Integer}},
168168
best_dists::Union{Number, AbstractVector},
169169
::Union{Nothing, AbstractVector},
170-
skip::F) where {F}
171-
return knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, nothing)
170+
skip::F,
171+
self_idx::Int=0) where {F}
172+
return knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, nothing, self_idx)
172173
end
173174

174175

@@ -178,9 +179,10 @@ function knn_kernel!(tree::BallTree{V},
178179
best_idxs::Union{Integer, AbstractVector{<:Integer}},
179180
best_dists::Union{Number, AbstractVector},
180181
skip::F,
181-
dedup::MaybeBitSet) where {V, F}
182+
dedup::MaybeBitSet,
183+
self_idx::Int=0) where {V, F}
182184
if isleaf(tree.tree_data.n_internal_nodes, index)
183-
return add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, dedup)
185+
return add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, dedup, self_idx)
184186
end
185187

186188
left_sphere = tree.hyper_spheres[getleft(index)]
@@ -192,16 +194,16 @@ function knn_kernel!(tree::BallTree{V},
192194
best_dist_1 = first(best_dists)
193195
if left_dist <= best_dist_1 || right_dist <= best_dist_1
194196
if left_dist < right_dist
195-
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup)
197+
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup, self_idx)
196198
best_dist_1 = first(best_dists)
197199
if right_dist <= best_dist_1
198-
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup)
200+
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup, self_idx)
199201
end
200202
else
201-
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup)
203+
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup, self_idx)
202204
best_dist_1 = first(best_dists)
203205
if left_dist <= best_dist_1
204-
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup)
206+
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup, self_idx)
205207
end
206208
end
207209
end

src/brute_tree.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function BruteTree(data::AbstractVector{V}, metric::PreMetric = Euclidean();
3030
end
3131
end
3232

33-
BruteTree(storedata ? Vector(data) : Vector{V}(), metric, reorder)
33+
BruteTree(storedata ? Vector(data) : Vector{V}(), metric, false)
3434
end
3535

3636
function BruteTree(data::AbstractVecOrMat{T}, metric::PreMetric = Euclidean();
@@ -46,20 +46,22 @@ function _knn(tree::BruteTree{V},
4646
best_idxs::Union{Integer, AbstractVector{<:Integer}},
4747
best_dists::Union{Number, AbstractVector},
4848
::Union{Nothing, AbstractVector},
49-
skip::F) where {V, F}
49+
skip::F,
50+
self_idx::Int=0) where {V, F}
5051

51-
return knn_kernel!(tree, point, best_idxs, best_dists, skip, nothing)
52+
return knn_kernel!(tree, point, best_idxs, best_dists, skip, nothing, self_idx)
5253
end
5354

5455
function knn_kernel!(tree::BruteTree{V},
5556
point::AbstractVector,
5657
best_idxs::Union{Integer, AbstractVector{<:Integer}},
5758
best_dists::Union{Number, AbstractVector},
5859
skip::F,
59-
dedup::MaybeBitSet) where {V, F}
60+
dedup::MaybeBitSet,
61+
self_idx::Int=0) where {V, F}
6062
has_set = dedup !== nothing
6163
for i in 1:length(tree.data)
62-
if skip(i)
64+
if i == self_idx || skip(i)
6365
continue
6466
end
6567

src/kd_tree.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ function _knn(tree::KDTree,
160160
best_idxs::Union{Integer, AbstractVector{<:Integer}},
161161
best_dists::Union{Number, AbstractVector},
162162
best_dists_final::Union{Nothing, AbstractVector},
163-
skip::F) where {F}
163+
skip::F,
164+
self_idx::Int=0) where {F}
164165
init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point)
165-
best_idxs, best_dists = knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip, nothing)
166+
best_idxs, best_dists = knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip, nothing, self_idx)
166167
best_dists isa Number && return best_idxs, eval_end(tree.metric, best_dists)
167168
@simd for i in eachindex(best_dists)
168169
@inbounds best_dists_final[i] = eval_end(tree.metric, best_dists[i])
@@ -178,10 +179,11 @@ function knn_kernel!(tree::KDTree{V},
178179
min_dist,
179180
hyper_rec::HyperRectangle,
180181
skip::F,
181-
dedup::MaybeBitSet) where {V, F}
182+
dedup::MaybeBitSet,
183+
self_idx::Int=0) where {V, F}
182184
# At a leaf node. Go through all points in node and add those in range
183185
if isleaf(tree.tree_data.n_internal_nodes, index)
184-
return add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip, dedup)
186+
return add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip, dedup, self_idx)
185187
end
186188

187189
split_dim = tree.split_dims[index]
@@ -205,7 +207,7 @@ function knn_kernel!(tree::KDTree{V},
205207
hyper_rec_close = left_region
206208
end
207209
# Always call closer sub tree
208-
best_idxs, best_dists = knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip, dedup)
210+
best_idxs, best_dists = knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip, dedup, self_idx)
209211

210212
if M isa Chebyshev
211213
new_min = get_min_distance_no_end(M, hyper_rec_far, point)
@@ -215,7 +217,7 @@ function knn_kernel!(tree::KDTree{V},
215217

216218
best_dist_1 = first(best_dists)
217219
if new_min < best_dist_1
218-
best_idxs, best_dists = knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip, dedup)
220+
best_idxs, best_dists = knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip, dedup, self_idx)
219221
end
220222
return best_idxs, best_dists
221223
end

src/knn.jl

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,58 @@ function check_k(tree, k)
44
end
55
end
66

7+
"""
8+
allnn(tree::NNTree [, skip=always_false]) -> indices, distances
9+
10+
Compute the nearest neighbor for every point stored in `tree`, excluding each
11+
point itself. Returns two vectors of length `npoints` containing the neighbor
12+
index and distance for each point.
13+
"""
14+
function allnn(tree::NNTree{V}, skip::F=Returns(false)) where {V, F<:Function}
15+
inner_tree = get_tree(tree)
16+
n_points = length(inner_tree.data)
17+
n_points == 0 && return Vector{Int}(), Vector{get_T(eltype(V))}()
18+
n_points == 1 && throw(ArgumentError("allnn requires at least 2 points"))
19+
20+
idxs = Vector{Int}(undef, n_points)
21+
dists = Vector{get_T(eltype(V))}(undef, n_points)
22+
23+
for i in 1:n_points
24+
orig_idx = inner_tree.reordered ? inner_tree.indices[i] : i
25+
best_idx, best_dist = _knn(tree, inner_tree.data[i], -1, dist_typemax(inner_tree), nothing, skip, orig_idx)
26+
idxs[orig_idx] = inner_tree.reordered ? inner_tree.indices[best_idx] : best_idx
27+
dists[orig_idx] = best_dist
28+
end
29+
30+
return idxs, dists
31+
end
32+
33+
"""
34+
allknn(tree::NNTree, k [, sortres=false, skip=always_false]) -> indices, distances
35+
36+
Compute the `k` nearest neighbors for every point stored in `tree`, excluding
37+
each point itself. Returns two vectors of length `npoints`, each containing a
38+
length-`k` vector of neighbor indices and distances, respectively. Set
39+
`sortres=true` to order neighbors by distance.
40+
"""
41+
function allknn(tree::NNTree{V}, k::Int, sortres=false, skip::F=Returns(false)) where {V, F<:Function}
42+
inner_tree = get_tree(tree)
43+
n_points = length(inner_tree.data)
44+
n_points == 0 && return Vector{Vector{Int}}(), Vector{Vector{get_T(eltype(V))}}()
45+
k < 0 && throw(ArgumentError("k < 0"))
46+
k <= n_points - 1 || throw(ArgumentError("k must be <= number of points - 1 for allknn"))
47+
48+
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
49+
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
50+
51+
for i in 1:n_points
52+
orig_idx = inner_tree.reordered ? inner_tree.indices[i] : i
53+
knn_point!(tree, inner_tree.data[i], sortres, dists[orig_idx], idxs[orig_idx], skip, orig_idx)
54+
end
55+
56+
return idxs, dists
57+
end
58+
759
"""
860
knn(tree::NNTree, points, k [, skip=always_false]) -> indices, distances
961
@@ -35,10 +87,10 @@ function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false,
3587
return idxs, dists
3688
end
3789

38-
knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} =
39-
_knn_point!(tree, point, sortres, dist, idx, skip)
90+
knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F, self_idx::Int=0) where {V, T <: Number, F} =
91+
_knn_point!(tree, point, sortres, dist, idx, skip, self_idx)
4092

41-
function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist_final, idx, skip::F) where {V, T <: Number, F}
93+
function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist_final, idx, skip::F, self_idx::Int) where {V, T <: Number, F}
4294
fill!(idx, -1)
4395
inner_tree = get_tree(tree)
4496

@@ -51,7 +103,7 @@ function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist_fi
51103
end
52104
fill!(dist_internal, dist_typemax(inner_tree))
53105

54-
_knn(tree, point, idx, dist_internal, dist_final, skip)
106+
_knn(tree, point, idx, dist_internal, dist_final, skip, self_idx)
55107

56108
if skip !== Returns(false)
57109
skipped_idxs = findall(==(-1), idx)
@@ -138,7 +190,7 @@ See also: `knn`.
138190
function nn(tree::NNTree{V}, point::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function}
139191
check_for_nan_in_points(point)
140192
check_k(tree, 1)
141-
best_idx, best_dist = _knn(tree, point, -1, dist_typemax(get_tree(tree)), nothing, skip)
193+
best_idx, best_dist = _knn(tree, point, -1, dist_typemax(get_tree(tree)), nothing, skip, 0)
142194
inner_tree = get_tree(tree)
143195
final_idx = inner_tree.reordered ? inner_tree.indices[best_idx] : best_idx
144196
return final_idx, best_dist

src/periodic_tree.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ function _knn(tree::PeriodicTree{V,M},
180180
best_idxs::Union{Integer, AbstractVector{<:Integer}},
181181
best_dists::Union{Number, AbstractVector},
182182
best_dists_final::Union{Nothing, AbstractVector},
183-
skip::F) where {V, M, F}
183+
skip::F,
184+
self_idx::Int=0) where {V, M, F}
184185

185186
dedup_state = empty!(tree.dedup_set)
186187
# Search all periodic mirror boxes
@@ -207,12 +208,12 @@ function _knn(tree::PeriodicTree{V,M},
207208

208209
# Search the underlying tree with the shifted query point
209210
if tree.tree isa KDTree
210-
best_idxs, best_dists = knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, min_dist_to_canonical, tree.tree.hyper_rec, skip, dedup_state)
211+
best_idxs, best_dists = knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, min_dist_to_canonical, tree.tree.hyper_rec, skip, dedup_state, self_idx)
211212
elseif tree.tree isa BallTree
212-
best_idxs, best_dists = knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, skip, dedup_state)
213+
best_idxs, best_dists = knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, skip, dedup_state, self_idx)
213214
else
214215
@assert tree.tree isa BruteTree
215-
best_idxs, best_dists = knn_kernel!(tree.tree, point_shifted, best_idxs, best_dists, skip, dedup_state)
216+
best_idxs, best_dists = knn_kernel!(tree.tree, point_shifted, best_idxs, best_dists, skip, dedup_state, self_idx)
216217
end
217218
end
218219

src/tree_ops.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,14 @@ end
121121
best_idxs::Union{Integer, AbstractVector{<:Integer}},
122122
tree::NNTree, index::Int, point::AbstractVector,
123123
do_end::Bool, skip::F,
124-
dedup::MaybeBitSet) where {F}
124+
dedup::MaybeBitSet, self_idx::Int=0) where {F}
125125
has_set = dedup !== nothing
126126
for z in get_leaf_range(tree.tree_data, index)
127-
if skip(tree.indices[z])
127+
orig_idx = tree.indices[z]
128+
if orig_idx == self_idx || skip(orig_idx)
128129
continue
129130
end
130-
idx = tree.reordered ? z : tree.indices[z]
131+
idx = tree.reordered ? z : orig_idx
131132
dist_d = evaluate_maybe_end(tree.metric, tree.data[idx], point, do_end)
132133
update_existing_neighbor!(dedup, idx, dist_d, best_idxs, best_dists) && continue
133134
best_dist_1 = first(best_dists)

test/test_allknn.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
module TestAllKNN
2+
isdefined(Main, :TestSetup) || @eval Main include(joinpath(@__DIR__, "TestSetup.jl"))
3+
4+
using NearestNeighbors
5+
using ..Main.TestSetup: trees_with_brute
6+
using StableRNGs
7+
using StaticArrays
8+
using Test
9+
10+
@testset "allknn" begin
11+
rng = StableRNG(42)
12+
data = rand(rng, 3, 40)
13+
k = 3
14+
bounds_min = fill(0.0, 3)
15+
bounds_max = fill(1.0, 3)
16+
17+
for reorder in (false, true)
18+
for Tree in trees_with_brute
19+
# BruteTree doesn't support reorder (no indices field)
20+
tree = Tree === BruteTree ? Tree(data) : Tree(data; leafsize=8, reorder)
21+
for treelike in (tree, PeriodicTree(tree, bounds_min, bounds_max))
22+
idxs, dists = allknn(treelike, k, true)
23+
@test length(idxs) == size(data, 2)
24+
@test length(dists) == size(data, 2)
25+
@test all(length(v) == k for v in idxs)
26+
@test all(length(v) == k for v in dists)
27+
28+
# Matches per-point knn with an explicit self-skip (original indices)
29+
for i in 1:size(data, 2)
30+
point = SVector{3,Float64}(data[:, i])
31+
skip_self = x -> x == i
32+
exp_idxs, exp_dists = knn(treelike, point, k, true, skip_self)
33+
@test idxs[i] == exp_idxs
34+
@test dists[i] == exp_dists
35+
@test !(i in idxs[i])
36+
end
37+
end
38+
end
39+
end
40+
end
41+
42+
@testset "allnn" begin
43+
rng = StableRNG(123)
44+
data = rand(rng, 3, 30)
45+
bounds_min = fill(0.0, 3)
46+
bounds_max = fill(1.0, 3)
47+
48+
for reorder in (false, true)
49+
for Tree in trees_with_brute
50+
# BruteTree doesn't support reorder (no indices field)
51+
tree = Tree === BruteTree ? Tree(data) : Tree(data; leafsize=8, reorder)
52+
for treelike in (tree, PeriodicTree(tree, bounds_min, bounds_max))
53+
idxs, dists = allnn(treelike)
54+
@test length(idxs) == size(data, 2)
55+
@test length(dists) == size(data, 2)
56+
57+
# Matches per-point nn with an explicit self-skip
58+
for i in 1:size(data, 2)
59+
point = SVector{3,Float64}(data[:, i])
60+
skip_self = x -> x == i
61+
exp_idx, exp_dist = nn(treelike, point, skip_self)
62+
@test idxs[i] == exp_idx
63+
@test dists[i] == exp_dist
64+
@test idxs[i] != i
65+
end
66+
end
67+
end
68+
end
69+
end
70+
71+
end # module

0 commit comments

Comments
 (0)