@@ -2,7 +2,7 @@ using BenchmarkTools
22using NearestNeighbors
33using StaticArrays
44using AbstractTrees
5- using NearestNeighbors: children, _treeindex, children2, preorder2, leaves2, _isleaf2
5+ using NearestNeighbors: children, _treeindex, children2, preorder2, leaves2, _isleaf2, children3, preorder3, leaves3, _isleaf3
66
77# Create test trees of various sizes
88function make_trees (n_points, ndim= 3 )
@@ -30,6 +30,14 @@ function count_nodes_v2(tree)
3030 return count
3131end
3232
33+ function count_nodes_v3 (tree)
34+ count = 0
35+ for _ in preorder3 (tree)
36+ count += 1
37+ end
38+ return count
39+ end
40+
3341# Benchmark 3: Sum of all leaf point norms
3442function sum_leaf_norms_v1 (tree)
3543 root = treeroot (tree)
@@ -52,6 +60,16 @@ function sum_leaf_norms_v2(tree)
5260 return total
5361end
5462
63+ function sum_leaf_norms_v3 (tree)
64+ total = 0.0
65+ for node in leaves3 (tree)
66+ for pt in NearestNeighbors. leafpoints3 (node)
67+ total += sum (abs2, pt)
68+ end
69+ end
70+ return total
71+ end
72+
5573# Benchmark 4: Collect all node indices
5674function collect_indices_v1 (tree)
5775 root = treeroot (tree)
@@ -70,26 +88,12 @@ function collect_indices_v2(tree)
7088 return indices
7189end
7290
73- # Benchmark 5: Filter internal nodes
74- function count_internal_nodes_v1 (tree)
75- root = treeroot (tree)
76- count = 0
77- for node in PreOrderDFS (root)
78- if ! isempty (children (node))
79- count += 1
80- end
81- end
82- return count
83- end
84-
85- function count_internal_nodes_v2 (tree)
86- count = 0
87- for node in preorder2 (tree)
88- if ! _isleaf2 (tree, node)
89- count += 1
90- end
91+ function collect_indices_v3 (tree)
92+ indices = Int[]
93+ for node in preorder3 (tree)
94+ push! (indices, node. index)
9195 end
92- return count
96+ return indices
9397end
9498
9599# Run benchmarks
@@ -98,7 +102,7 @@ function run_benchmarks()
98102
99103 println (" =" ^ 80 )
100104 println (" Tree Walking Benchmark Comparison" )
101- println (" AbstractTrees (v1) vs isbits nodes (v2)" )
105+ println (" AbstractTrees (v1) vs isbits nodes (v2) vs wrapped nodes (v3) " )
102106 println (" =" ^ 80 )
103107
104108 for n in sizes
@@ -108,55 +112,56 @@ function run_benchmarks()
108112
109113 kdtree, balltree = make_trees (n)
110114
111- # Verify both approaches give same results
112- @assert count_nodes_v1 (kdtree) == count_nodes_v2 (kdtree)
113- @assert sum_leaf_norms_v1 (kdtree) ≈ sum_leaf_norms_v2 (kdtree)
114- @assert collect_indices_v1 (kdtree) == collect_indices_v2 (kdtree)
115- @assert count_internal_nodes_v1 (kdtree) == count_internal_nodes_v2 (kdtree)
115+ # Verify all approaches give same results
116+ @assert count_nodes_v1 (kdtree) == count_nodes_v2 (kdtree) == count_nodes_v3 (kdtree)
117+ @assert sum_leaf_norms_v1 (kdtree) ≈ sum_leaf_norms_v2 (kdtree) ≈ sum_leaf_norms_v3 (kdtree)
118+ @assert collect_indices_v1 (kdtree) == collect_indices_v2 (kdtree) == collect_indices_v3 (kdtree)
116119
117120 println (" \n --- KDTree ---" )
118121
119122 print (" Count nodes (v1): " )
120123 @btime count_nodes_v1 ($ kdtree)
121124 print (" Count nodes (v2): " )
122125 @btime count_nodes_v2 ($ kdtree)
126+ print (" Count nodes (v3): " )
127+ @btime count_nodes_v3 ($ kdtree)
123128
124129 print (" \n Sum leaf norms (v1): " )
125130 @btime sum_leaf_norms_v1 ($ kdtree)
126131 print (" Sum leaf norms (v2): " )
127132 @btime sum_leaf_norms_v2 ($ kdtree)
133+ print (" Sum leaf norms (v3): " )
134+ @btime sum_leaf_norms_v3 ($ kdtree)
128135
129136 print (" \n Collect indices (v1): " )
130137 @btime collect_indices_v1 ($ kdtree)
131138 print (" Collect indices (v2): " )
132139 @btime collect_indices_v2 ($ kdtree)
133-
134- print (" \n Count internal (v1): " )
135- @btime count_internal_nodes_v1 ($ kdtree)
136- print (" Count internal (v2): " )
137- @btime count_internal_nodes_v2 ($ kdtree)
140+ print (" Collect indices (v3): " )
141+ @btime collect_indices_v3 ($ kdtree)
138142
139143 println (" \n --- BallTree ---" )
140144
141145 print (" Count nodes (v1): " )
142146 @btime count_nodes_v1 ($ balltree)
143147 print (" Count nodes (v2): " )
144148 @btime count_nodes_v2 ($ balltree)
149+ print (" Count nodes (v3): " )
150+ @btime count_nodes_v3 ($ balltree)
145151
146152 print (" \n Sum leaf norms (v1): " )
147153 @btime sum_leaf_norms_v1 ($ balltree)
148154 print (" Sum leaf norms (v2): " )
149155 @btime sum_leaf_norms_v2 ($ balltree)
156+ print (" Sum leaf norms (v3): " )
157+ @btime sum_leaf_norms_v3 ($ balltree)
150158
151159 print (" \n Collect indices (v1): " )
152160 @btime collect_indices_v1 ($ balltree)
153161 print (" Collect indices (v2): " )
154162 @btime collect_indices_v2 ($ balltree)
155-
156- print (" \n Count internal (v1): " )
157- @btime count_internal_nodes_v1 ($ balltree)
158- print (" Count internal (v2): " )
159- @btime count_internal_nodes_v2 ($ balltree)
163+ print (" Collect indices (v3): " )
164+ @btime collect_indices_v3 ($ balltree)
160165 end
161166end
162167
0 commit comments