Skip to content

Commit 2317e22

Browse files
committed
add skip_self for inrange and knn
1 parent c073298 commit 2317e22

10 files changed

Lines changed: 209 additions & 92 deletions

File tree

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ knn!(idxs, dists, tree, point, k [, skip=Returns(false)])
7878
* `point[s]`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors.
7979
* `k`: Number of nearest neighbors to find.
8080
* `skip` (optional): A predicate function to skip certain points, e.g., points already visited.
81+
* `skip_self` (optional, batched queries only): Skip the point with the same index as the current query when the query set is identical to the tree data, e.g. `knn(tree, data, 1; skip_self=true)`.
8182

8283

8384
For the single closest neighbor, you can use `nn`:
@@ -145,6 +146,12 @@ dists
145146
# 0.04556078331418939
146147
# 0.049967238112417205
147148

149+
# Self-query the same dataset without returning each point as its own neighbor
150+
idxs, dists = knn(kdtree, data, 1; skip_self=true)
151+
152+
# Retrieve just the nearest neighbor per point
153+
nn_idx, nn_dist = nn(kdtree, data; skip_self=true)
154+
148155
# Preallocating input results
149156
idxs, dists = zeros(Int32, k), zeros(Float32, k)
150157
knn!(idxs, dists, kdtree, v, k)
@@ -162,6 +169,8 @@ inrange!(idxs, tree, point, radius)
162169
* `tree`: The tree instance.
163170
* `point[s]`: A vector or matrix of points to find neighbors for.
164171
* `radius`: Search radius.
172+
* `skip` (optional): Predicate to skip certain points.
173+
* `skip_self` (optional, batched queries only): When querying the same dataset, skip the point whose index matches the query.
165174

166175
Note: Distances are not returned, only indices.
167176

@@ -188,6 +197,10 @@ inrange!(idxs, balltree, point, r)
188197

189198
# counts points without allocating index arrays
190199
neighborscount = inrangecount(balltree, point, r)
200+
201+
# Self-query without returning each point itself
202+
idxs_self = inrange(balltree, data, r; skip_self=true)
203+
counts_self = inrangecount(balltree, data, r; skip_self=true)
191204
```
192205

193206
### Self-Pair Searches

src/ball_tree.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ function _knn(tree::BallTree,
167167
best_idxs::Union{Integer, AbstractVector{<:Integer}},
168168
best_dists::Union{Number, AbstractVector},
169169
::Union{Nothing, AbstractVector},
170-
skip::F) where {F}
171-
return knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, nothing)
170+
skip::F,
171+
self_idx::Int) where {F}
172+
return knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, nothing, self_idx)
172173
end
173174

174175

@@ -178,9 +179,10 @@ function knn_kernel!(tree::BallTree{V},
178179
best_idxs::Union{Integer, AbstractVector{<:Integer}},
179180
best_dists::Union{Number, AbstractVector},
180181
skip::F,
181-
dedup::MaybeBitSet) where {V, F}
182+
dedup::MaybeBitSet,
183+
self_idx::Int) where {V, F}
182184
if isleaf(tree.tree_data.n_internal_nodes, index)
183-
return add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, dedup)
185+
return add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, dedup, self_idx)
184186
end
185187

186188
left_sphere = tree.hyper_spheres[getleft(index)]
@@ -192,16 +194,16 @@ function knn_kernel!(tree::BallTree{V},
192194
best_dist_1 = first(best_dists)
193195
if left_dist <= best_dist_1 || right_dist <= best_dist_1
194196
if left_dist < right_dist
195-
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup)
197+
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup, self_idx)
196198
best_dist_1 = first(best_dists)
197199
if right_dist <= best_dist_1
198-
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup)
200+
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup, self_idx)
199201
end
200202
else
201-
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup)
203+
best_idxs, best_dists = knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, dedup, self_idx)
202204
best_dist_1 = first(best_dists)
203205
if left_dist <= best_dist_1
204-
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup)
206+
best_idxs, best_dists = knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, dedup, self_idx)
205207
end
206208
end
207209
end
@@ -212,9 +214,10 @@ function _inrange(tree::BallTree{V},
212214
point::AbstractVector,
213215
radius::Number,
214216
idx_in_ball::Union{Nothing, Vector{<:Integer}},
215-
skip::F) where {V, F}
217+
skip::F,
218+
self_idx::Int) where {V, F}
216219
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
217-
return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, nothing) # Call the recursive range finder
220+
return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, nothing, self_idx) # Call the recursive range finder
218221
end
219222

220223
function inrange_kernel!(tree::BallTree,
@@ -223,7 +226,8 @@ function inrange_kernel!(tree::BallTree,
223226
query_ball::HyperSphere,
224227
idx_in_ball::Union{Nothing, Vector{<:Integer}},
225228
skip::F,
226-
dedup::MaybeBitSet) where {F}
229+
dedup::MaybeBitSet,
230+
self_idx::Int) where {F}
227231

228232
if index > length(tree.hyper_spheres)
229233
return 0
@@ -241,17 +245,17 @@ function inrange_kernel!(tree::BallTree,
241245
# At a leaf node, check all points in the leaf node
242246
if isleaf(tree.tree_data.n_internal_nodes, index)
243247
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r
244-
return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, dedup)
248+
return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, dedup, self_idx)
245249
end
246250

247251
# The query ball encloses the sub tree bounding sphere. Add all points in the
248252
# sub tree without checking the distance function.
249253
if encloses_fast(dist, tree.metric, sphere, query_ball)
250-
return addall(tree, index, idx_in_ball, skip, dedup)
254+
return addall(tree, index, idx_in_ball, skip, dedup, self_idx)
251255
else
252256
# Recursively call the left and right sub tree.
253-
return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, dedup) +
254-
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, dedup)
257+
return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, dedup, self_idx) +
258+
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, dedup, self_idx)
255259
end
256260
end
257261

src/brute_tree.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,22 @@ function _knn(tree::BruteTree{V},
4646
best_idxs::Union{Integer, AbstractVector{<:Integer}},
4747
best_dists::Union{Number, AbstractVector},
4848
::Union{Nothing, AbstractVector},
49-
skip::F) where {V, F}
49+
skip::F,
50+
self_idx::Int) where {V, F}
5051

51-
return knn_kernel!(tree, point, best_idxs, best_dists, skip, nothing)
52+
return knn_kernel!(tree, point, best_idxs, best_dists, skip, nothing, self_idx)
5253
end
5354

5455
function knn_kernel!(tree::BruteTree{V},
5556
point::AbstractVector,
5657
best_idxs::Union{Integer, AbstractVector{<:Integer}},
5758
best_dists::Union{Number, AbstractVector},
5859
skip::F,
59-
dedup::MaybeBitSet) where {V, F}
60+
dedup::MaybeBitSet,
61+
self_idx::Int) where {V, F}
6062
has_set = dedup !== nothing
6163
for i in 1:length(tree.data)
62-
if skip(i)
64+
if (i == self_idx) || skip(i)
6365
continue
6466
end
6567

@@ -80,8 +82,9 @@ function _inrange(tree::BruteTree,
8082
point::AbstractVector,
8183
radius::Number,
8284
idx_in_ball::Union{Nothing, Vector{<:Integer}},
83-
skip::F,) where {F}
84-
return inrange_kernel!(tree, point, radius, idx_in_ball, skip, nothing)
85+
skip::F,
86+
self_idx::Int) where {F}
87+
return inrange_kernel!(tree, point, radius, idx_in_ball, skip, nothing, self_idx)
8588
end
8689

8790

@@ -90,11 +93,12 @@ function inrange_kernel!(tree::BruteTree,
9093
r::Number,
9194
idx_in_ball::Union{Nothing, Vector{<:Integer}},
9295
skip::Function,
93-
dedup::MaybeBitSet)
96+
dedup::MaybeBitSet,
97+
self_idx::Int)
9498
count = 0
9599
has_set = dedup !== nothing
96100
for i in 1:length(tree.data)
97-
if skip(i)
101+
if (i == self_idx) || skip(i)
98102
continue
99103
end
100104
d = evaluate(tree.metric, tree.data[i], point)

src/inrange.jl

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Find all the points in the tree which are closer than `radius` to `points`.
99
- `tree`: The tree instance
1010
- `points`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors
1111
- `radius`: Search radius
12+
- `skip` (optional): Predicate to skip certain points.
13+
- `skip_self` (optional, batched queries only): When querying the same dataset, skip the point whose index matches the query index.
1214
1315
# Returns
1416
- `indices`: Vector of indices of points within the radius
@@ -19,23 +21,25 @@ function inrange(tree::NNTree,
1921
points::AbstractVector{T},
2022
radius::Number,
2123
sortres=false,
22-
skip::F = Returns(false)) where {T <: AbstractVector, F}
24+
skip::F = Returns(false);
25+
skip_self::Bool=false) where {T <: AbstractVector, F}
2326
check_input(tree, points)
2427
check_for_nan_in_points(points)
2528
check_radius(radius)
2629

2730
idxs = [Vector{Int}() for _ in 1:length(points)]
2831

2932
for i in 1:length(points)
30-
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip)
33+
self_idx = skip_self ? i : 0
34+
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip, self_idx)
3135
end
3236
return idxs
3337
end
3438

35-
inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} = _inrange_point!(tree, point, radius, sortres, idx, skip)
39+
inrange_point!(tree, point, radius, sortres, idx, skip::F, self_idx::Int) where {F} = _inrange_point!(tree, point, radius, sortres, idx, skip, self_idx)
3640

37-
function _inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F}
38-
count = _inrange(tree, point, radius, idx, skip)
41+
function _inrange_point!(tree, point, radius, sortres, idx, skip::F, self_idx::Int) where {F}
42+
count = _inrange(tree, point, radius, idx, skip, self_idx)
3943
if idx !== nothing
4044
inner_tree = get_tree(tree)
4145
if inner_tree.reordered
@@ -62,25 +66,38 @@ Useful to avoid allocations or specify the element type of the output vector.
6266
6367
See also: `inrange`, `inrangecount`.
6468
"""
65-
function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false)) where {V, T <: Number}
69+
function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false); skip_self::Bool=false) where {V, T <: Number}
70+
skip_self && throw(ArgumentError("skip_self is only supported for batched queries; pass a skip predicate instead for single points"))
6671
check_input(tree, point)
6772
check_for_nan_in_points(point)
6873
check_radius(radius)
6974
length(idxs) == 0 || throw(ArgumentError("idxs must be empty"))
70-
inrange_point!(tree, point, radius, sortres, idxs, skip)
75+
inrange_point!(tree, point, radius, sortres, idxs, skip, 0)
7176
return idxs
7277
end
7378

74-
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number}
79+
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false; skip_self::Bool=false) where {V, T <: Number}
80+
skip_self && throw(ArgumentError("skip_self is only supported for batched queries; pass a skip predicate instead for single points"))
7581
return inrange!(Int[], tree, point, radius, sortres)
7682
end
7783

78-
function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false) where {V, T <: Number}
84+
# Single-point variant with an explicit skip predicate
85+
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres::Bool, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, F}
86+
skip_self && throw(ArgumentError("skip_self is only supported for batched queries; pass a skip predicate instead for single points"))
87+
return inrange!(Int[], tree, point, radius, sortres, skip)
88+
end
89+
90+
function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false; skip_self::Bool=false) where {V, T <: Number}
91+
dim = size(points, 1)
92+
inrange_matrix(tree, points, radius, Val(dim), sortres; skip_self)
93+
end
94+
95+
function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres::Bool, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, F}
7996
dim = size(points, 1)
80-
inrange_matrix(tree, points, radius, Val(dim), sortres)
97+
inrange_matrix(tree, points, radius, Val(dim), sortres, skip; skip_self)
8198
end
8299

83-
function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false)) where {V, T <: Number, dim, F}
100+
function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, dim, F}
84101
# TODO: DRY with inrange for AbstractVector
85102
check_input(tree, points)
86103
check_for_nan_in_points(points)
@@ -90,7 +107,8 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb
90107

91108
for i in 1:n_points
92109
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
93-
inrange_point!(tree, point, radius, sortres, idxs[i], skip)
110+
self_idx = skip_self ? i : 0
111+
inrange_point!(tree, point, radius, sortres, idxs[i], skip, self_idx)
94112
end
95113
return idxs
96114
end
@@ -104,32 +122,40 @@ Count all the points in the tree which are closer than `radius` to `points`.
104122
- `tree`: The tree instance
105123
- `points`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors
106124
- `radius`: Search radius
125+
- `skip` (optional): Predicate to skip certain points.
126+
- `skip_self` (optional, batched queries only): When querying the same dataset, skip the point whose index matches the query index.
107127
108128
# Returns
109129
- `count`: Number of points within the radius (integer for single point, vector for multiple points)
110130
"""
111-
function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F}
131+
function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, F}
132+
skip_self && throw(ArgumentError("skip_self is only supported for batched queries; pass a skip predicate instead for single points"))
112133
check_input(tree, point)
113134
check_for_nan_in_points(point)
114135
check_radius(radius)
115-
return inrange_point!(tree, point, radius, false, nothing, skip)
136+
return inrange_point!(tree, point, radius, false, nothing, skip, 0)
116137
end
117138

118139
function inrangecount(tree::NNTree,
119140
points::AbstractVector{T},
120-
radius::Number, skip::F=Returns(false)) where {T <: AbstractVector, F}
141+
radius::Number, skip::F=Returns(false); skip_self::Bool=false) where {T <: AbstractVector, F}
121142
check_input(tree, points)
122143
check_for_nan_in_points(points)
123144
check_radius(radius)
124-
return inrange_point!.(Ref(tree), points, radius, false, nothing, skip)
145+
counts = Vector{Int}(undef, length(points))
146+
for i in 1:length(points)
147+
self_idx = skip_self ? i : 0
148+
counts[i] = inrange_point!(tree, points[i], radius, false, nothing, skip, self_idx)
149+
end
150+
return counts
125151
end
126152

127-
function inrangecount(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F}
153+
function inrangecount(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, F}
128154
dim = size(points, 1)
129-
inrangecount_matrix(tree, points, radius, Val(dim), skip)
155+
inrangecount_matrix(tree, points, radius, Val(dim), skip; skip_self)
130156
end
131157

132-
function inrangecount_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, skip::F=Returns(false)) where {V, T <: Number, dim, F}
158+
function inrangecount_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, skip::F=Returns(false); skip_self::Bool=false) where {V, T <: Number, dim, F}
133159
check_input(tree, points)
134160
check_for_nan_in_points(points)
135161
check_radius(radius)
@@ -138,7 +164,8 @@ function inrangecount_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius:
138164

139165
for i in 1:n_points
140166
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
141-
counts[i] = inrange_point!(tree, point, radius, false, nothing, skip)
167+
self_idx = skip_self ? i : 0
168+
counts[i] = inrange_point!(tree, point, radius, false, nothing, skip, self_idx)
142169
end
143170
return counts
144171
end

0 commit comments

Comments
 (0)