Skip to content

Commit 05261ab

Browse files
williamjsdavis and KristofferC authored
Add error checks for constructing trees and doing queries
Co-authored-by: KristofferC <kristoffer.carlsson@juliacomputing.com>
1 parent 8598bfe commit 05261ab

8 files changed

Lines changed: 108 additions & 0 deletions

File tree

src/ball_tree.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ function BallTree(data::AbstractVector{V},
3838
parallel::Bool = Threads.nthreads() > 1) where {V <: AbstractArray}
3939
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)
4040

41+
# Reject data containing NaNs early to avoid undefined behaviour later on.
42+
check_for_nan(data)
43+
4144
tree_data = TreeData(data, leafsize)
4245
n_p = length(data)
4346

src/brute_tree.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Note: `leafsize` and `reorder` parameters are ignored for BruteTree.
2121
"""
2222
function BruteTree(data::AbstractVector{V}, metric::PreMetric = Euclidean();
2323
reorder::Bool=false, leafsize::Int=0, storedata::Bool=true) where {V <: AbstractVector}
24+
check_for_nan(data)
2425
if metric isa Distances.UnionMetrics
2526
p = parameters(metric)
2627
if p !== nothing && length(p) != length(V)
@@ -34,6 +35,7 @@ end
3435

3536
function BruteTree(data::AbstractVecOrMat{T}, metric::PreMetric = Euclidean();
3637
reorder::Bool=false, leafsize::Int=0, storedata::Bool=true) where {T}
38+
check_for_nan(data)
3739
dim = size(data, 1)
3840
BruteTree(copy_svec(T, data, Val(dim)),
3941
metric, reorder = reorder, leafsize = leafsize, storedata = storedata)

src/inrange.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function inrange(tree::NNTree,
2121
sortres=false,
2222
skip::F = Returns(false)) where {T <: AbstractVector, F}
2323
check_input(tree, points)
24+
check_for_nan_in_points(points)
2425
check_radius(radius)
2526

2627
idxs = [Vector{Int}() for _ in 1:length(points)]
@@ -63,6 +64,7 @@ See also: `inrange`, `inrangecount`.
6364
"""
6465
function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false)) where {V, T <: Number}
6566
check_input(tree, point)
67+
check_for_nan_in_points(point)
6668
check_radius(radius)
6769
length(idxs) == 0 || throw(ArgumentError("idxs must be empty"))
6870
inrange_point!(tree, point, radius, sortres, idxs, skip)
@@ -81,6 +83,7 @@ end
8183
function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false)) where {V, T <: Number, dim, F}
8284
# TODO: DRY with inrange for AbstractVector
8385
check_input(tree, points)
86+
check_for_nan_in_points(points)
8487
check_radius(radius)
8588
n_points = size(points, 2)
8689
idxs = [Vector{Int}() for _ in 1:n_points]
@@ -107,6 +110,7 @@ Count all the points in the tree which are closer than `radius` to `points`.
107110
"""
108111
function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F}
    # Validate every argument up front, in this order, before searching.
    check_input(tree, point)
    check_for_nan_in_points(point)
    check_radius(radius)

    # Fourth argument is the `sortres` flag (off); fifth is the index buffer.
    # Passing `nothing` as the buffer means no indices are collected here.
    n_in_range = inrange_point!(tree, point, radius, false, nothing, skip)
    return n_in_range
end
@@ -115,11 +119,13 @@ function inrangecount(tree::NNTree,
115119
points::AbstractVector{T},
116120
radius::Number, skip::F=Returns(false)) where {T <: AbstractVector, F}
117121
check_input(tree, points)
122+
check_for_nan_in_points(points)
118123
check_radius(radius)
119124
return inrange_point!.(Ref(tree), points, radius, false, nothing, skip)
120125
end
121126

122127
function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number}
128+
check_for_nan_in_points(point)
123129
dim = size(point, 1)
124130
npoints = size(point, 2)
125131
if isbitstype(T)

src/kd_tree.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ function KDTree(data::AbstractVector{V},
3535
parallel::Bool = Threads.nthreads() > 1) where {V <: AbstractArray, M <: MinkowskiMetric}
3636
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)
3737

38+
# Reject data containing NaNs early to avoid undefined behaviour later on.
39+
check_for_nan(data)
40+
3841
tree_data = TreeData(data, leafsize)
3942
n_p = length(data)
4043

src/knn.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ See also: `knn!`, `nn`.
2424
"""
2525
function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: AbstractVector, F<:Function}
2626
check_input(tree, points)
27+
check_for_nan_in_points(points)
2728
check_k(tree, k)
2829
n_points = length(points)
2930
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
@@ -73,6 +74,7 @@ Useful to avoid allocations or specify the element type of the output vectors.
7374
See also: `knn`, `nn`.
7475
"""
7576
function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function}
77+
check_for_nan_in_points(point)
7678
check_k(tree, k)
7779
length(idxs) == k || throw(ArgumentError("idxs must be of length k"))
7880
length(dists) == k || throw(ArgumentError("dists must be of length k"))
@@ -91,6 +93,7 @@ function knn(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, sortres=false,
9193

9294
# TODO: DRY with knn for AbstractVector
9395
check_input(tree, points)
96+
check_for_nan_in_points(points)
9497
check_k(tree, k)
9598
n_points = size(points, 2)
9699
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]

src/utilities.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,28 @@ end
9191
# Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors
9292
function copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim}
    # Each SVector is filled from `dim` consecutive linear indices of `data`,
    # starting just after `offset`.
    offsets = 0:dim:(length(data) - 1)
    svecs = [SVector{dim,T}(ntuple(i -> data[offset + i], Val(dim))) for offset in offsets]
    return svecs::Vector{SVector{dim,T}}
end
94+
95+
# Check for NaN values in data; throw if any are present
96+
# Validate the data a tree is about to be built from.
#
# `data` is iterated element by element; each element may itself be a
# collection (a point) or a plain number (when a matrix is passed, iteration
# yields its scalar entries, and `any(isnan, x)` also works on a single number).
#
# Returns `nothing` when no NaN is found.
# Throws `ArgumentError` as soon as the first NaN is encountered.
function check_for_nan(data)
    # Deliberately no `@inbounds`: this runs on arbitrary caller-supplied
    # containers at the API boundary, and plain iteration gains nothing from
    # disabling bounds checks.
    for p in data
        any(isnan, p) &&
            throw(ArgumentError("Tree cannot be constructed from data containing NaN values"))
    end
    return
end
104+
105+
# Check for NaN values in input points; throw if any are present
106+
# Query validation for a single point (vector of numbers) or a matrix of
# points: any NaN entry is rejected with an `ArgumentError`.
function check_for_nan_in_points(points::Union{AbstractVector, AbstractMatrix})
    any(isnan, points) &&
        throw(ArgumentError("Tree cannot be queried with points containing NaN values"))
    return
end

# Query validation for a collection of points: each point is validated in
# turn by dispatching back into `check_for_nan_in_points`.
function check_for_nan_in_points(points::AbstractVector{<:AbstractVector})
    foreach(check_for_nan_in_points, points)
    return
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ include("test_monkey.jl")
3535
include("test_datafreetree.jl")
3636
include("test_tree_data.jl")
3737
include("test_periodic.jl")
38+
include("test_tree_nan.jl")
3839

3940
@testset "views of SVector" begin
4041
x = [rand(SVector{3}) for i in 1:20]

test/test_tree_nan.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Tests for KDTree, BallTree, BruteTree that reject data containing NaNs
2+
3+
@testset "Trees reject NaNs" begin
    # A single NaN coordinate is enough to make each input invalid.
    nan_points = [SVector{2,Float64}(NaN, 0.0), SVector{2,Float64}(1.0, 1.0)]
    nan_matrix = [NaN 0.0; 1.0 1.0]

    # Every tree type must refuse both the vector-of-points and matrix forms.
    for TreeType in (KDTree, BallTree, BruteTree), input in (nan_points, nan_matrix)
        @test_throws ArgumentError TreeType(input)
    end
end
12+
13+
@testset "knn rejects NaNs" begin
    for TreeType in (KDTree, BallTree, BruteTree)
        # Build each tree from fresh, NaN-free data.
        tree = TreeType([SVector{2,Float64}(0.0, 0.0), SVector{2,Float64}(1.0, 1.0)])

        nan_queries = (
            [NaN, 0.0],                      # single query point
            [SVector{2,Float64}(NaN, 0.0)],  # vector of query points
            [NaN 0.0; 0.0 1.0],              # matrix of query points
        )
        for query in nan_queries
            @test_throws ArgumentError knn(tree, query, 1)
        end
    end
end
30+
31+
@testset "inrange rejects NaNs" begin
    for TreeType in (KDTree, BallTree, BruteTree)
        # Build each tree from fresh, NaN-free data.
        tree = TreeType([SVector{2,Float64}(0.0, 0.0), SVector{2,Float64}(1.0, 1.0)])

        nan_queries = (
            [NaN, 0.0],                      # single query point
            [SVector{2,Float64}(NaN, 0.0)],  # vector of query points
            [NaN 0.0; 0.0 1.0],              # matrix of query points
        )
        for query in nan_queries
            @test_throws ArgumentError inrange(tree, query, 1.0)
        end
    end
end
48+
49+
@testset "inrangecount rejects NaNs" begin
    for TreeType in (KDTree, BallTree, BruteTree)
        # Build each tree from fresh, NaN-free data.
        tree = TreeType([SVector{2,Float64}(0.0, 0.0), SVector{2,Float64}(1.0, 1.0)])

        nan_queries = (
            [NaN, 0.0],                      # single query point
            [SVector{2,Float64}(NaN, 0.0)],  # vector of query points
            [NaN 0.0; 0.0 1.0],              # matrix of query points
        )
        for query in nan_queries
            @test_throws ArgumentError inrangecount(tree, query, 1.0)
        end
    end
end

0 commit comments

Comments
 (0)