|
# Validate a query radius: throws `ArgumentError` for negative values and
# evaluates to `false` otherwise.
function check_radius(r)
    return r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))
end
2 | 2 |
|
"""
    ReservoirSampler(storage, capacity, rng)

Fixed-capacity uniform sampler (reservoir sampling, "Algorithm R") that writes into a
caller-supplied buffer and presents the currently held samples as an `AbstractVector`.

# Fields
- `storage`: backing buffer; only `storage[1:len]` holds valid samples.
- `capacity`: maximum number of samples to keep.
- `len`: number of samples currently held (`min(seen, capacity)`).
- `seen`: number of values offered through `push!` so far.
- `rng`: random number generator used to decide replacements.

# Throws
- `ArgumentError` if `capacity` is negative, if `storage` is shorter than `capacity`,
  or if `storage` does not use 1-based indexing.
"""
mutable struct ReservoirSampler{T<:Integer,RNG<:AbstractRNG,V<:AbstractVector{T}} <: AbstractVector{T}
    storage::V
    capacity::Int
    len::Int
    seen::Int
    rng::RNG
end

function ReservoirSampler(storage::AbstractVector{T}, capacity::Integer, rng::AbstractRNG) where {T<:Integer}
    # The sampler addresses `storage` with indices 1:capacity, so offset axes are rejected
    # up front instead of failing (or writing to unexpected slots) during sampling.
    Base.require_one_based_indexing(storage)
    capacity < 0 && throw(ArgumentError("k must be ≥ 0"))
    capacity <= length(storage) || throw(ArgumentError("storage length must be ≥ k"))
    return ReservoirSampler{T,typeof(rng),typeof(storage)}(storage, capacity, 0, 0, rng)
end

Base.IndexStyle(::Type{<:ReservoirSampler}) = IndexLinear()
Base.size(rs::ReservoirSampler) = (rs.len,)
Base.length(rs::ReservoirSampler) = rs.len

# Indexing is checked against the number of held samples (`len`), not against the length
# of the backing buffer, so slots that were never filled cannot be read or overwritten.
# `@propagate_inbounds` lets callers that use `@inbounds` skip both checks.
Base.@propagate_inbounds function Base.getindex(rs::ReservoirSampler, i::Int)
    @boundscheck checkbounds(rs, i)
    return rs.storage[i]
end

Base.@propagate_inbounds function Base.setindex!(rs::ReservoirSampler, value, i::Int)
    @boundscheck checkbounds(rs, i)
    return setindex!(rs.storage, value, i)
end

"""
    push!(rs::ReservoirSampler, value)

Offer `value` to the sampler and return `rs`. The first `capacity` values are stored
directly; each later value replaces a uniformly chosen slot with probability
`capacity / seen`. With `capacity == 0` the value is only counted.
"""
function Base.push!(rs::ReservoirSampler, value)
    rs.seen += 1
    rs.capacity == 0 && return rs
    if rs.seen <= rs.capacity
        # Still filling the reservoir: append.
        rs.len = rs.seen
        rs.storage[rs.len] = value
    else
        # Reservoir is full: keep the new value only if the draw lands on a slot.
        j = rand(rs.rng, 1:rs.seen)
        if j <= rs.capacity
            rs.storage[j] = value
        end
    end
    return rs
end

"""
    sort!(rs::ReservoirSampler; kwargs...)

Sort the held samples in place (only `storage[1:len]`), forwarding `kwargs` to `sort!`.
Returns `rs`.
"""
function Base.sort!(rs::ReservoirSampler; kwargs...)
    rs.len <= 1 && return rs
    sort!(view(rs.storage, 1:rs.len); kwargs...)
    return rs
end
| 43 | + |
3 | 44 | """ |
4 | 45 | inrange(tree::NNTree, points, radius) -> indices |
5 | 46 |
|
@@ -147,3 +188,69 @@ function inrange_pairs(tree::NNTree, radius::Number, sortres=false, skip::F=Retu |
147 | 188 | check_radius(radius) |
148 | 189 | return _inrange_pairs(tree, radius, sortres, skip) |
149 | 190 | end |
| 191 | + |
"""
    knninrange(tree::NNTree, point, radius, k; rng=Random.default_rng(), sortres=false, skip=Returns(false))

Return at most `k` indices, sampled uniformly at random without replacement, of the
points that lie within `radius` of `point`.

The result is comparable to that of `inrange`, except that no more than `k` neighbors are
ever returned. `rng`, `sortres` and `skip` are forwarded to `knninrange!`.

Throws an `ArgumentError` if `k` is negative; `k == 0` yields an empty `Vector{Int}`.

See also: `knninrange!`.
"""
function knninrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, point)
    check_radius(radius)
    k >= 0 || throw(ArgumentError("k must be ≥ 0"))
    if k == 0
        return Int[]
    end
    # Allocate room for the full reservoir, then trim to the number actually sampled.
    samples = Vector{Int}(undef, k)
    count = knninrange!(samples, tree, point, radius, k; rng=rng, sortres=sortres, skip=skip)
    resize!(samples, count)
    return samples
end
| 213 | + |
# Batch method: run `knninrange` for every point in `points` and collect the per-point
# index vectors in iteration order. All queries share the same `rng`.
#
# `k` is validated once up front so that an invalid `k` is rejected even when `points`
# is empty. The result is always a `Vector{Vector{Int}}`, independent of the axes of
# `points` and of what inference can prove about the comprehension element type.
function knninrange(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: AbstractVector, F}
    check_input(tree, points)
    check_radius(radius)
    k < 0 && throw(ArgumentError("k must be ≥ 0"))
    idxs = Vector{Vector{Int}}(undef, length(points))
    # `enumerate` gives a 1-based output position regardless of how `points` is indexed.
    for (i, point) in enumerate(points)
        idxs[i] = knninrange(tree, point, radius, k; rng=rng, sortres=sortres, skip=skip)
    end
    return idxs
end
| 220 | + |
# Batch method for points stored as the columns of a matrix: column `i` of `points` is
# query `i`, and element `i` of the result holds its sampled indices.
#
# The column height is only known at run time, so it is lifted to the type domain once
# and the loop runs behind a function barrier; building `Val(dim)` inside the loop would
# force a dynamic dispatch for every column.
function knninrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, k::Integer;
                    rng::AbstractRNG=Random.default_rng(), sortres=false,
                    skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, points)
    check_radius(radius)
    # Validate once here so an invalid `k` is rejected even for a matrix with no columns.
    k < 0 && throw(ArgumentError("k must be ≥ 0"))
    return _knninrange_matrix(tree, points, radius, k, Val(size(points, 1)), rng, sortres, skip)
end

# Kernel for the matrix method, specialised on the static point dimension `dim`.
function _knninrange_matrix(tree::NNTree, points::AbstractMatrix{T}, radius::Number, k::Integer,
                            ::Val{dim}, rng::AbstractRNG, sortres, skip::F) where {T, dim, F}
    idxs = Vector{Vector{Int}}(undef, size(points, 2))
    row_offset = firstindex(points, 1) - 1
    for (i, col) in enumerate(axes(points, 2))
        point = SVector{dim,T}(ntuple(j -> points[row_offset + j, col], Val(dim)))
        idxs[i] = knninrange(tree, point, radius, k; rng=rng, sortres=sortres, skip=skip)
    end
    return idxs
end
| 234 | + |
"""
    knninrange!(idxs, tree, point, radius, k; rng=Random.default_rng(), sortres=false, skip=Returns(false))

In-place variant of `knninrange`. The first `k` entries of `idxs` serve as the buffer of a
reservoir sampler and hold the sampled indices once the call returns. The return value is
the number of valid samples written.

`idxs` must have at least `k` elements (`k` defaults to `length(idxs)`); entries past the
returned count are not modified. `sortres` and `skip` are forwarded to the range search.

Throws an `ArgumentError` if `k` is negative or larger than `length(idxs)`.
"""
function knninrange!(idxs::AbstractVector{<:Integer}, tree::NNTree{V}, point::AbstractVector{T},
                     radius::Number, k::Integer=length(idxs);
                     rng::AbstractRNG=Random.default_rng(), sortres=false,
                     skip::F=Returns(false)) where {V, T <: Number, F}
    check_input(tree, point)
    check_radius(radius)
    k >= 0 || throw(ArgumentError("k must be ≥ 0"))
    if k == 0
        return 0
    end
    length(idxs) >= k || throw(ArgumentError("idxs must have length ≥ k"))
    # The range search pushes every hit into the reservoir, which keeps at most `k`.
    reservoir = ReservoirSampler(idxs, k, rng)
    _inrange_point!(tree, point, radius, sortres, reservoir, skip)
    return length(reservoir)
end
0 commit comments