Skip to content

Commit eaff332

Browse files
committed
add v3 that uses own walkers with fat node
1 parent 656d4c6 commit eaff332

3 files changed

Lines changed: 297 additions & 36 deletions

File tree

benchmark/benchmark_tree_walk.jl

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using BenchmarkTools
22
using NearestNeighbors
33
using StaticArrays
44
using AbstractTrees
5-
using NearestNeighbors: children, _treeindex, children2, preorder2, leaves2, _isleaf2
5+
using NearestNeighbors: children, _treeindex, children2, preorder2, leaves2, _isleaf2, children3, preorder3, leaves3, _isleaf3
66

77
# Create test trees of various sizes
88
function make_trees(n_points, ndim=3)
@@ -30,6 +30,14 @@ function count_nodes_v2(tree)
3030
return count
3131
end
3232

33+
function count_nodes_v3(tree)
34+
count = 0
35+
for _ in preorder3(tree)
36+
count += 1
37+
end
38+
return count
39+
end
40+
3341
# Benchmark 3: Sum of all leaf point norms
3442
function sum_leaf_norms_v1(tree)
3543
root = treeroot(tree)
@@ -52,6 +60,16 @@ function sum_leaf_norms_v2(tree)
5260
return total
5361
end
5462

63+
function sum_leaf_norms_v3(tree)
64+
total = 0.0
65+
for node in leaves3(tree)
66+
for pt in NearestNeighbors.leafpoints3(node)
67+
total += sum(abs2, pt)
68+
end
69+
end
70+
return total
71+
end
72+
5573
# Benchmark 4: Collect all node indices
5674
function collect_indices_v1(tree)
5775
root = treeroot(tree)
@@ -70,26 +88,12 @@ function collect_indices_v2(tree)
7088
return indices
7189
end
7290

73-
# Benchmark 5: Filter internal nodes
74-
function count_internal_nodes_v1(tree)
75-
root = treeroot(tree)
76-
count = 0
77-
for node in PreOrderDFS(root)
78-
if !isempty(children(node))
79-
count += 1
80-
end
81-
end
82-
return count
83-
end
84-
85-
function count_internal_nodes_v2(tree)
86-
count = 0
87-
for node in preorder2(tree)
88-
if !_isleaf2(tree, node)
89-
count += 1
90-
end
91+
function collect_indices_v3(tree)
92+
indices = Int[]
93+
for node in preorder3(tree)
94+
push!(indices, node.index)
9195
end
92-
return count
96+
return indices
9397
end
9498

9599
# Run benchmarks
@@ -98,7 +102,7 @@ function run_benchmarks()
98102

99103
println("="^80)
100104
println("Tree Walking Benchmark Comparison")
101-
println("AbstractTrees (v1) vs isbits nodes (v2)")
105+
println("AbstractTrees (v1) vs isbits nodes (v2) vs wrapped nodes (v3)")
102106
println("="^80)
103107

104108
for n in sizes
@@ -108,55 +112,56 @@ function run_benchmarks()
108112

109113
kdtree, balltree = make_trees(n)
110114

111-
# Verify both approaches give same results
112-
@assert count_nodes_v1(kdtree) == count_nodes_v2(kdtree)
113-
@assert sum_leaf_norms_v1(kdtree) sum_leaf_norms_v2(kdtree)
114-
@assert collect_indices_v1(kdtree) == collect_indices_v2(kdtree)
115-
@assert count_internal_nodes_v1(kdtree) == count_internal_nodes_v2(kdtree)
115+
# Verify all approaches give same results
116+
@assert count_nodes_v1(kdtree) == count_nodes_v2(kdtree) == count_nodes_v3(kdtree)
117+
@assert sum_leaf_norms_v1(kdtree) sum_leaf_norms_v2(kdtree) sum_leaf_norms_v3(kdtree)
118+
@assert collect_indices_v1(kdtree) == collect_indices_v2(kdtree) == collect_indices_v3(kdtree)
116119

117120
println("\n--- KDTree ---")
118121

119122
print("Count nodes (v1): ")
120123
@btime count_nodes_v1($kdtree)
121124
print("Count nodes (v2): ")
122125
@btime count_nodes_v2($kdtree)
126+
print("Count nodes (v3): ")
127+
@btime count_nodes_v3($kdtree)
123128

124129
print("\nSum leaf norms (v1): ")
125130
@btime sum_leaf_norms_v1($kdtree)
126131
print("Sum leaf norms (v2): ")
127132
@btime sum_leaf_norms_v2($kdtree)
133+
print("Sum leaf norms (v3): ")
134+
@btime sum_leaf_norms_v3($kdtree)
128135

129136
print("\nCollect indices (v1): ")
130137
@btime collect_indices_v1($kdtree)
131138
print("Collect indices (v2): ")
132139
@btime collect_indices_v2($kdtree)
133-
134-
print("\nCount internal (v1): ")
135-
@btime count_internal_nodes_v1($kdtree)
136-
print("Count internal (v2): ")
137-
@btime count_internal_nodes_v2($kdtree)
140+
print("Collect indices (v3): ")
141+
@btime collect_indices_v3($kdtree)
138142

139143
println("\n--- BallTree ---")
140144

141145
print("Count nodes (v1): ")
142146
@btime count_nodes_v1($balltree)
143147
print("Count nodes (v2): ")
144148
@btime count_nodes_v2($balltree)
149+
print("Count nodes (v3): ")
150+
@btime count_nodes_v3($balltree)
145151

146152
print("\nSum leaf norms (v1): ")
147153
@btime sum_leaf_norms_v1($balltree)
148154
print("Sum leaf norms (v2): ")
149155
@btime sum_leaf_norms_v2($balltree)
156+
print("Sum leaf norms (v3): ")
157+
@btime sum_leaf_norms_v3($balltree)
150158

151159
print("\nCollect indices (v1): ")
152160
@btime collect_indices_v1($balltree)
153161
print("Collect indices (v2): ")
154162
@btime collect_indices_v2($balltree)
155-
156-
print("\nCount internal (v1): ")
157-
@btime count_internal_nodes_v1($balltree)
158-
print("Count internal (v2): ")
159-
@btime count_internal_nodes_v2($balltree)
163+
print("Collect indices (v3): ")
164+
@btime collect_indices_v3($balltree)
160165
end
161166
end
162167

src/NearestNeighbors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ include("periodic_tree.jl")
6565
include("datafreetree.jl")
6666
include("tree_walk.jl")
6767
include("tree_walk_2.jl")
68+
include("tree_walk_3.jl")
6869
include("knn.jl")
6970
include("inrange.jl")
7071

0 commit comments

Comments
 (0)