Skip to content

Commit c4c1ff2

Browse files
author
KristofferC
committed
add knninrange to uniformly sample up to k points within a given radius
1 parent 0824413 commit c4c1ff2

11 files changed

Lines changed: 179 additions & 11 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.4.24"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
7+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
89

910
[compat]

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,15 @@ A range search finds all neighbors within the range `r` of given point(s). This
157157
```julia
158158
inrange(tree, point[s], radius) -> idxs
159159
inrange!(idxs, tree, point, radius)
160+
knninrange(tree, point[s], radius, k) -> idxs
161+
knninrange!(idxs, tree, point, radius, k)
162+
inrangecount(tree, point[s], radius) -> count
160163
```
161164

162165
* `tree`: The tree instance.
163166
* `point[s]`: A vector or matrix of points to find neighbors for.
164167
* `radius`: Search radius.
168+
* `k`: Maximum number of neighbors to return when sampling with `knninrange`. Every point inside the radius has equal probability of being selected.
165169

166170
Note: Distances are not returned, only indices.
167171

@@ -186,6 +190,11 @@ idxs = inrange(balltree, point, r)
186190
idxs = Int32[]
187191
inrange!(idxs, balltree, point, r)
188192

193+
# Sample up to `k` neighbors uniformly at random without allocating new buffers
194+
buf = zeros(Int, 5)
195+
nsampled = knninrange!(buf, balltree, point, r, 5)
196+
random_subset = buf[1:nsampled]
197+
189198
# counts points without allocating index arrays
190199
neighborscount = inrangecount(balltree, point, r)
191200
```

benchmark/Manifest.toml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/NearestNeighbors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ module NearestNeighbors
33
using Distances: Distances, PreMetric, Metric, UnionMinkowskiMetric, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters, Euclidean, Cityblock, Minkowski, Chebyshev, Hamming, Mahalanobis, WeightedEuclidean, WeightedCityblock, WeightedMinkowski
44

55
using StaticArrays: StaticArrays, MVector, SVector
6+
using Random
67
using Base: setindex
78

89
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree
9-
export knn, knn!, nn, inrange, inrange!, inrangecount, inrange_pairs # TODOs?, npairs
10+
export knn, knn!, nn, inrange, inrange!, inrangecount, inrange_pairs, knninrange, knninrange! # TODOs?, npairs
1011
export injectdata
1112

1213
export Euclidean,

src/ball_tree.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ end
210210
function _inrange(tree::BallTree{V},
211211
point::AbstractVector,
212212
radius::Number,
213-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
213+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
214214
skip::F) where {V, F}
215215
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
216216
return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, nothing) # Call the recursive range finder
@@ -220,7 +220,7 @@ function inrange_kernel!(tree::BallTree,
220220
index::Int,
221221
point::AbstractVector,
222222
query_ball::HyperSphere,
223-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
223+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
224224
skip::F,
225225
dedup::MaybeBitSet) where {F}
226226

src/brute_tree.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878
function _inrange(tree::BruteTree,
7979
point::AbstractVector,
8080
radius::Number,
81-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
81+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
8282
skip::F,) where {F}
8383
return inrange_kernel!(tree, point, radius, idx_in_ball, skip, nothing)
8484
end
@@ -87,7 +87,7 @@ end
8787
function inrange_kernel!(tree::BruteTree,
8888
point::AbstractVector,
8989
r::Number,
90-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
90+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
9191
skip::Function,
9292
dedup::MaybeBitSet)
9393
count = 0

src/inrange.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,46 @@
11
# Validate a query radius: throws an `ArgumentError` for a negative `r`; otherwise the
# result is `false` (the value of the failed `r < 0` test).
function check_radius(r)
    if r < 0
        throw(ArgumentError("the query radius r must be ≧ 0"))
    end
    return false
end
22

3+
"""
    ReservoirSampler(storage::AbstractVector{<:Integer}, capacity::Integer, rng::AbstractRNG)

Fixed-capacity sample buffer that presents itself as an `AbstractVector`.

# Fields
- `storage`: caller-provided backing vector; only slots `1:capacity` are ever written.
- `capacity`: maximum number of values kept.
- `len`: number of slots currently holding a value.
- `seen`: number of values offered so far.
- `rng`: random number generator used to pick replacement slots.

# Throws
- `ArgumentError` if `capacity` is negative, if `storage` is shorter than `capacity`,
  or if `storage` does not use 1-based indexing.
"""
mutable struct ReservoirSampler{T<:Integer,RNG<:AbstractRNG,V<:AbstractVector{T}} <: AbstractVector{T}
    storage::V
    capacity::Int
    len::Int
    seen::Int
    rng::RNG
end

function ReservoirSampler(storage::AbstractVector{T}, capacity::Integer, rng::AbstractRNG) where {T<:Integer}
    capacity < 0 && throw(ArgumentError("k must be ≥ 0"))
    capacity <= length(storage) || throw(ArgumentError("storage length must be ≥ k"))
    # The sampler addresses slots 1:capacity directly, so storage with a different
    # first index would be written at the wrong positions (or out of bounds).
    firstindex(storage) == 1 || throw(ArgumentError("storage must use 1-based indexing"))
    return ReservoirSampler{T, typeof(rng), typeof(storage)}(storage, capacity, 0, 0, rng)
end
16+
17+
# AbstractVector interface for ReservoirSampler. The logical length is `len` (the number
# of filled slots), not the length of the backing storage.
Base.IndexStyle(::Type{<:ReservoirSampler}) = IndexLinear()
Base.size(rs::ReservoirSampler) = (rs.len,)
Base.length(rs::ReservoirSampler) = rs.len

# Read slot `i`. Bounds are checked against the logical length so that slots of the
# backing storage that were never filled cannot be read by accident.
function Base.getindex(rs::ReservoirSampler, i::Int)
    @boundscheck checkbounds(rs, i)
    return rs.storage[i]
end

# Overwrite slot `i` (must already be filled). Returns the sampler itself, following the
# usual `setindex!` convention of returning the collection.
function Base.setindex!(rs::ReservoirSampler, value, i::Int)
    @boundscheck checkbounds(rs, i)
    rs.storage[i] = value
    return rs
end
22+
23+
# Offer `value` to the sampler and return the sampler.
#
# Every call increments `seen`. With zero capacity nothing is stored. While at most
# `capacity` values have been offered, the value is stored in the next free slot and `len`
# grows. After that, a slot number is drawn uniformly from `1:seen` and the value
# overwrites that slot only when the draw falls within `1:capacity`.
function Base.push!(rs::ReservoirSampler{T}, value) where {T}
    rs.seen += 1
    cap = rs.capacity
    if cap == 0
        return rs
    end
    n_offered = rs.seen
    if n_offered > cap
        slot = rand(rs.rng, 1:n_offered)
        slot <= cap && (rs.storage[slot] = value)
    else
        rs.len = n_offered
        rs.storage[n_offered] = value
    end
    return rs
end
37+
38+
# Sort the filled part of the sampler in place and return the sampler. Keyword arguments
# are forwarded to `sort!`; slots of the backing storage past `len` are not touched.
function Base.sort!(rs::ReservoirSampler; kwargs...)
    if rs.len > 1
        filled = view(rs.storage, 1:rs.len)
        sort!(filled; kwargs...)
    end
    return rs
end
43+
344
"""
445
inrange(tree::NNTree, points, radius) -> indices
546
@@ -147,3 +188,69 @@ function inrange_pairs(tree::NNTree, radius::Number, sortres=false, skip::F=Retu
147188
check_radius(radius)
148189
return _inrange_pairs(tree, radius, sortres, skip)
149190
end
191+
192+
"""
    knninrange(tree::NNTree, point, radius, k; rng=Random.default_rng(), sortres=false, skip=Returns(false))

Return at most `k` indices sampled uniformly at random (without replacement) from the
points that lie within `radius` of `point`.

Works like `inrange`, except that the result never holds more than `k` neighbors.

# Keywords
- `rng`: random number generator used for the sampling.
- `sortres`, `skip`: forwarded unchanged to `knninrange!`.

# Throws
- `ArgumentError` if `k` is negative.

See also: `knninrange!`.
"""
function knninrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, point)
    check_radius(radius)
    if k < 0
        throw(ArgumentError("k must be ≥ 0"))
    elseif k == 0
        return Int[]
    end
    # Allocate room for the full sample, then trim to the number of slots actually filled.
    sampled = Vector{Int}(undef, k)
    n_valid = knninrange!(sampled, tree, point, radius, k; rng, sortres, skip)
    return resize!(sampled, n_valid)
end
213+
214+
# Batched form for a vector of points: returns one sampled index vector per query point,
# in the order of `points`. All queries share the same `rng`.
function knninrange(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: AbstractVector, F}
    check_input(tree, points)
    check_radius(radius)
    # Iterate the elements directly instead of `1:length(points)`, which is only valid
    # for 1-based containers.
    return [knninrange(tree, point, radius, k; rng=rng, sortres=sortres, skip=skip)
            for point in points]
end
220+
221+
# Batched form for a matrix of points: each column of `points` is one query point.
# Returns one sampled index vector per column. All queries share the same `rng`.
function knninrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, points)
    check_radius(radius)
    # The row count is only known at run time. Lift it into the type domain once and
    # hand over to a kernel, so the per-column loop is compiled for a concrete SVector
    # type instead of dispatching dynamically on every iteration.
    return _knninrange_matrix(tree, points, radius, k, Val(size(points, 1)),
                              rng, sortres, skip)
end

# Kernel for the matrix method: `dim` (number of rows) is a compile-time constant here.
function _knninrange_matrix(tree::NNTree, points::AbstractMatrix{T}, radius::Number,
                            k::Integer, ::Val{dim}, rng::AbstractRNG, sortres,
                            skip::F) where {T <: Number, dim, F}
    n_points = size(points, 2)
    idxs = Vector{Vector{Int}}(undef, n_points)
    for i in 1:n_points
        point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
        idxs[i] = knninrange(tree, point, radius, k; rng=rng, sortres=sortres, skip=skip)
    end
    return idxs
end
234+
235+
"""
    knninrange!(idxs, tree, point, radius, k=length(idxs); rng=Random.default_rng(), sortres=false, skip=Returns(false))

In-place variant of `knninrange`. The leading `k` slots of `idxs` serve as the storage of
the reservoir sampler and hold the sampled indices once the call returns. The return value
is the number of valid samples written (i.e. `min(k, number_in_range)`).

`idxs` must have at least `k` elements. Entries past the returned sample count are left
as they were.

# Throws
- `ArgumentError` if `k` is negative or larger than `length(idxs)`.
"""
function knninrange!(idxs::AbstractVector{<:Integer}, tree::NNTree{V},
                     point::AbstractVector{T}, radius::Number, k::Integer=length(idxs);
                     rng::AbstractRNG=Random.default_rng(), sortres=false,
                     skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, point)
    check_radius(radius)
    if k < 0
        throw(ArgumentError("k must be ≥ 0"))
    elseif k == 0
        return 0
    elseif k > length(idxs)
        throw(ArgumentError("idxs must have length ≥ k"))
    end
    # The range search pushes every hit into the reservoir, which keeps at most `k`.
    reservoir = ReservoirSampler(idxs, k, rng)
    _inrange_point!(tree, point, radius, sortres, reservoir, skip)
    return length(reservoir)
end

src/kd_tree.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function _inrange(
222222
tree::KDTree,
223223
point::AbstractVector,
224224
radius::Number,
225-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
225+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
226226
skip::F) where {F}
227227
init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point)
228228
init_max_contribs = get_max_distance_contributions(tree.metric, tree.hyper_rec, point)
@@ -239,7 +239,7 @@ function inrange_kernel!(
239239
index::Int,
240240
point::AbstractVector,
241241
r::Number,
242-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
242+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
243243
hyper_rec::HyperRectangle,
244244
min_dist,
245245
max_dist_contribs::SVector,

src/periodic_tree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ end
233233
function _inrange(tree::PeriodicTree{V},
234234
point::AbstractVector,
235235
radius::Number,
236-
idx_in_ball::Union{Nothing, Vector{<:Integer}},
236+
idx_in_ball::Union{Nothing, AbstractVector{<:Integer}},
237237
skip::F) where {V, F}
238238

239239
dedup_state = empty!(tree.dedup_set)

src/tree_ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ end
180180

181181
# Add all points in this subtree since we have determined
182182
# they are all within the desired range
183-
function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}, skip::Function,
183+
function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, skip::Function,
184184
dedup::MaybeBitSet)
185185
tree_data = tree.tree_data
186186
if isleaf(tree_data.n_internal_nodes, index)

0 commit comments

Comments
 (0)