11import genesis as gs
22import taichi as ti
33from genesis .repr_base import RBC
4+ import numpy as np
45
56
67@ti .data_oriented
@@ -157,19 +158,27 @@ class Node:
157158 # Nodes of the BVH, first n_aabbs - 1 are internal nodes, last n_aabbs are leaf nodes
158159 self .nodes = self .Node .field (shape = (self .n_batches , self .n_aabbs * 2 - 1 ))
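159160 # Per-internal-node flags for the layer-by-layer bound computation: `active` nodes are recomputed in the current pass, `ready` nodes are promoted to `active` for the next pass; `updated` records whether the last pass processed any node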
160- self .internal_node_visited = ti .field (ti .u8 , shape = (self .n_batches , self .n_aabbs - 1 ))
161+ self .internal_node_active = ti .field (ti .u1 , shape = (self .n_batches , self .n_aabbs - 1 ))
162+ self .internal_node_ready = ti .field (ti .u1 , shape = (self .n_batches , self .n_aabbs - 1 ))
163+ self .updated = ti .field (ti .u1 , shape = ())
161164
162165 # Query results, vec3 of batch id, self id, query id
163166 self .query_result = ti .field (gs .ti_ivec3 , shape = (self .max_n_query_results ))
164167 # Count of query results
165168 self .query_result_count = ti .field (ti .i32 , shape = ())
166169
def build(self):
    """
    Build the BVH from the axis-aligned bounding boxes (AABBs).

    This is a plain Python driver: it invokes each construction stage in
    turn. The stages are, in order: AABB centers and scales, Morton codes,
    radix sort of the Morton codes, radix tree construction, and node bounds.
    """
    stages = (
        self.compute_aabb_centers_and_scales,
        self.compute_morton_codes,
        self.radix_sort_morton_codes,
        self.build_radix_tree,
        self.compute_bounds,
    )
    # Stages run strictly in the order listed above.
    for run_stage in stages:
        run_stage()
172179
180+ @ti .kernel
181+ def compute_aabb_centers_and_scales (self ):
173182 for i_b , i_a in ti .ndrange (self .n_batches , self .n_aabbs ):
174183 self .aabb_centers [i_b , i_a ] = (self .aabbs [i_b , i_a ].min + self .aabbs [i_b , i_a ].max ) / 2
175184
@@ -184,14 +193,9 @@ def build(self):
184193 for i_b in ti .ndrange (self .n_batches ):
185194 scale = self .aabb_max [i_b ] - self .aabb_min [i_b ]
186195 for i in ti .static (range (3 )):
187- self .scale [i_b ][i ] = ti .select (scale [i ] > 1e-7 , 1.0 / scale [i ], 1 )
196+ self .scale [i_b ][i ] = ti .select (scale [i ] > gs . EPS , 1.0 / scale [i ], 1.0 )
188197
189- self .compute_morton_codes ()
190- self .radix_sort_morton_codes ()
191- self .build_radix_tree ()
192- self .compute_bounds ()
193-
194- @ti .func
198+ @ti .kernel
195199 def compute_morton_codes (self ):
196200 """
197201 Compute the Morton codes for each AABB.
@@ -223,38 +227,43 @@ def expand_bits(self, v):
223227 v = (v * ti .u32 (0x00000005 )) & ti .u32 (0x49249249 )
224228 return v
225229
def radix_sort_morton_codes(self):
    """
    Radix sort the morton codes, using 8 bits at a time.

    Launches the one-round kernel eight times from Python scope. Round ``k``
    sorts on the byte selected by shifting the code right by ``k * 8`` bits,
    so the rounds go from the least significant byte upwards.
    """
    n_rounds = 8
    for round_idx in range(n_rounds):
        self._kernel_radix_sort_morton_codes_one_round(round_idx)
235236
236- # Fill histogram
237- for i_b , i_a in ti .ndrange (self .n_batches , self .n_aabbs ):
@ti.kernel
def _kernel_radix_sort_morton_codes_one_round(self, i: int):
    """
    Run one round of the radix sort on byte ``i`` of the morton codes.

    Steps: reset the per-batch histogram, count the occurrences of each byte
    value (also recording each element's rank among equal bytes), turn the
    histogram into an exclusive prefix sum, scatter the codes into the
    temporary buffer, then copy the buffer back.
    """
    # Reset the histogram
    self.hist.fill(0)

    # Histogram pass: only the batch loop is the outermost loop; elements of
    # one batch are visited in index order so that `offset` is the rank of the
    # element among those with the same byte value (keeps the sort stable).
    for i_b in range(self.n_batches):
        # TODO Parallelize, need to use groups to handle data to remain stable, could be not worth it
        for i_a in range(self.n_aabbs):
            shifted = self.morton_codes[i_b, i_a] >> (i * 8)
            digit = shifted & 0xFF
            self.offset[i_b, i_a] = ti.atomic_add(self.hist[i_b, ti.i32(digit)], 1)

    # Exclusive prefix sum over the 256 buckets, sequential within each batch
    for i_b in ti.ndrange(self.n_batches):
        self.prefix_sum[i_b, 0] = 0
        for bucket in range(1, 256):
            self.prefix_sum[i_b, bucket] = self.prefix_sum[i_b, bucket - 1] + self.hist[i_b, bucket - 1]

    # Scatter every code to its sorted position in the temporary buffer
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        shifted = self.morton_codes[i_b, i_a] >> (i * 8)
        digit = shifted & 0xFF
        dst = ti.i32(self.offset[i_b, i_a] + self.prefix_sum[i_b, ti.i32(digit)])
        self.tmp_morton_codes[i_b, dst] = self.morton_codes[i_b, i_a]

    # Copy the reordered codes back over the originals
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
256265
257- @ti .func
266+ @ti .kernel
258267 def build_radix_tree (self ):
259268 """
260269 Build the radix tree from the sorted morton codes.
@@ -321,31 +330,51 @@ def delta(self, i, j, i_b):
321330 break
322331 return result
323332
def compute_bounds(self):
    """
    Compute the bounds of the BVH nodes.

    Starts from the leaf nodes and works upwards layer by layer.

    The init kernel is launched once; the one-layer kernel is then launched
    repeatedly from Python scope for as long as the `updated` flag reads true.
    """
    self._kernel_compute_bounds_init()
    while True:
        # `updated` is written by the kernels; stop once it reads false.
        if not self.updated[None]:
            break
        self._kernel_compute_bounds_one_layer()
342+
@ti.kernel
def _kernel_compute_bounds_init(self):
    """
    Prepare the layer-by-layer bound computation.

    Sets `updated`, clears both per-internal-node flag fields, copies each
    AABB into its leaf node, and marks every leaf's parent as active.
    """
    self.updated[None] = True
    self.internal_node_active.fill(0)
    self.internal_node_ready.fill(0)

    for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs):
        # Leaf nodes are stored after the n_aabbs - 1 internal nodes.
        leaf = i + self.n_aabbs - 1
        # NOTE(review): the i32 cast of the sorted morton code is used as the
        # index of the source AABB - presumably the low bits carry it; confirm.
        src = ti.i32(self.morton_codes[i_b, i])
        self.nodes[i_b, leaf].bound.min = self.aabbs[i_b, src].min
        self.nodes[i_b, leaf].bound.max = self.aabbs[i_b, src].max
        parent = self.nodes[i_b, leaf].parent
        if parent != -1:
            self.internal_node_active[i_b, parent] = 1
338356
339- cur_idx = self .nodes [i_b , i + self .n_aabbs - 1 ].parent
340- while cur_idx != - 1 :
341- visited = ti .u1 (ti .atomic_or (self .internal_node_visited [i_b , cur_idx ], ti .u8 (1 )))
342- if not visited :
343- break
344- left_bound = self .nodes [i_b , self .nodes [i_b , cur_idx ].left ].bound
345- right_bound = self .nodes [i_b , self .nodes [i_b , cur_idx ].right ].bound
346- self .nodes [i_b , cur_idx ].bound .min = ti .min (left_bound .min , right_bound .min )
347- self .nodes [i_b , cur_idx ].bound .max = ti .max (left_bound .max , right_bound .max )
348- cur_idx = self .nodes [i_b , cur_idx ].parent
@ti.kernel
def _kernel_compute_bounds_one_layer(self):
    """
    Process one layer of the bottom-up bound computation.

    Every active internal node takes the union of its two children's bounds,
    flags its parent as ready, and deactivates itself. Afterwards each ready
    node is promoted to active for the next call. `updated` ends up true only
    if at least one active node was processed.
    """
    self.updated[None] = False
    for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs - 1):
        if self.internal_node_active[i_b, i] != 0:
            left_idx = self.nodes[i_b, i].left
            right_idx = self.nodes[i_b, i].right
            left_box = self.nodes[i_b, left_idx].bound
            right_box = self.nodes[i_b, right_idx].bound
            self.nodes[i_b, i].bound.min = ti.min(left_box.min, right_box.min)
            self.nodes[i_b, i].bound.max = ti.max(left_box.max, right_box.max)
            parent = self.nodes[i_b, i].parent
            if parent != -1:
                # Queue the parent for the next call rather than this one.
                self.internal_node_ready[i_b, parent] = 1
            self.internal_node_active[i_b, i] = 0
            self.updated[None] = True

    # Promote the queued nodes: ready -> active.
    for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs - 1):
        if self.internal_node_ready[i_b, i] != 0:
            self.internal_node_active[i_b, i] = 1
            self.internal_node_ready[i_b, i] = 0
349378
350379 @ti .kernel
351380 def query (self , aabbs : ti .template ()):
0 commit comments