|
98 | 98 | @test distance ≈ [0.02239688629947563, 0.13440059522389006] |
99 | 99 |
|
100 | 100 | @testset "skip_self keyword" begin |
101 | | - data = rand(3, 25) |
102 | | - tree = KDTree(data) |
103 | | - idxs_skip, dists_skip = knn(tree, data, 1, true; skip_self=true) |
104 | | - @test all(i -> idxs_skip[i][1] != i, eachindex(idxs_skip)) |
105 | | - expected_idxs = Vector{Int}(undef, size(data, 2)) |
106 | | - expected_dists = Vector{Float64}(undef, size(data, 2)) |
107 | | - for i in 1:size(data, 2) |
108 | | - single_idx, single_dist = knn(tree, data[:, i], 1, true, j -> j == i) |
109 | | - expected_idxs[i] = single_idx[1] |
110 | | - expected_dists[i] = single_dist[1] |
| 101 | + for tree_builder in (data -> KDTree(data), |
| 102 | + data -> BallTree(data), |
| 103 | + data -> BruteTree(data), |
| 104 | + data -> PeriodicTree(KDTree(data), zeros(3), ones(3))) |
| 105 | + data = rand(3, 25) |
| 106 | + tree = tree_builder(data) |
| 107 | + idxs_skip, dists_skip = knn(tree, data, 1, true; skip_self=true) |
| 108 | + @test all(i -> idxs_skip[i][1] != i, eachindex(idxs_skip)) |
| 109 | + expected_idxs = Vector{Int}(undef, size(data, 2)) |
| 110 | + expected_dists = Vector{Float64}(undef, size(data, 2)) |
| 111 | + for i in 1:size(data, 2) |
| 112 | + single_idx, single_dist = knn(tree, data[:, i], 1, true, j -> j == i) |
| 113 | + expected_idxs[i] = single_idx[1] |
| 114 | + expected_dists[i] = single_dist[1] |
| 115 | + end |
| 116 | + @test [idxs_skip[i][1] for i in eachindex(idxs_skip)] == expected_idxs |
| 117 | + @test [dists_skip[i][1] for i in eachindex(dists_skip)] ≈ expected_dists |
| 118 | + |
| 119 | + # Works for nn with batched queries too |
| 120 | + nn_idx, nn_dist = nn(tree, data; skip_self=true) |
| 121 | + @test nn_idx == expected_idxs |
| 122 | + @test nn_dist ≈ expected_dists |
| 123 | + |
| 124 | + # skip_self combines with a custom skip predicate |
| 125 | + block_idx = 5 |
| 126 | + idxs_blocked, _ = knn(tree, data, 1, true, j -> j == block_idx; skip_self=true) |
| 127 | + @test all(i -> idxs_blocked[i][1] != i && idxs_blocked[i][1] != block_idx, eachindex(idxs_blocked)) |
111 | 128 | end |
112 | | - @test [idxs_skip[i][1] for i in eachindex(idxs_skip)] == expected_idxs |
113 | | - @test [dists_skip[i][1] for i in eachindex(dists_skip)] ≈ expected_dists |
114 | | - |
115 | | - # Works for nn with batched queries too |
116 | | - nn_idx, nn_dist = nn(tree, data; skip_self=true) |
117 | | - @test nn_idx == expected_idxs |
118 | | - @test nn_dist ≈ expected_dists |
119 | | - |
120 | | - # skip_self combines with a custom skip predicate |
121 | | - block_idx = 5 |
122 | | - idxs_blocked, _ = knn(tree, data, 1, true, j -> j == block_idx; skip_self=true) |
123 | | - @test all(i -> idxs_blocked[i][1] != i && idxs_blocked[i][1] != block_idx, eachindex(idxs_blocked)) |
124 | 129 | end |
125 | 130 | end |
126 | 131 |
|
|
0 commit comments