Skip to content

Commit b91014e

Browse files
authored
make inrangecount not copy input matrix (cf apply 965c14) (#225)
1 parent 05261ab commit b91014e

2 files changed

Lines changed: 22 additions & 9 deletions

File tree

src/inrange.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,21 @@ function inrangecount(tree::NNTree,
124124
return inrange_point!.(Ref(tree), points, radius, false, nothing, skip)
125125
end
126126

127-
function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number}
128-
check_for_nan_in_points(point)
129-
dim = size(point, 1)
130-
npoints = size(point, 2)
131-
if isbitstype(T)
132-
new_data = copy_svec(T, point, Val(dim))
133-
else
134-
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
127+
function inrangecount(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F}
128+
dim = size(points, 1)
129+
inrangecount_matrix(tree, points, radius, Val(dim), skip)
130+
end
131+
132+
function inrangecount_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, skip::F=Returns(false)) where {V, T <: Number, dim, F}
133+
check_input(tree, points)
134+
check_for_nan_in_points(points)
135+
check_radius(radius)
136+
n_points = size(points, 2)
137+
counts = Vector{Int}(undef, n_points)
138+
139+
for i in 1:n_points
140+
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
141+
counts[i] = inrange_point!(tree, point, radius, false, nothing, skip)
135142
end
136-
return inrangecount(tree, new_data, radius)
143+
return counts
137144
end

test/test_inrange.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,11 @@ end
100100
return inrange(b, point, 0.1)
101101
end
102102

103+
function foo2(data, point)
104+
b = KDTree(data)
105+
return inrangecount(b, point, 0.1)
106+
end
107+
103108
@inferred foo([1.0 3.4; 4.5 3.4], [4.5; 3.4])
109+
@inferred foo2([1.0 3.4; 4.5 3.4], [4.5; 3.4])
104110
end

0 commit comments

Comments
 (0)