@@ -112,7 +112,7 @@ class LBVH(RBC):
112112 https://research.nvidia.com/sites/default/files/pubs/2012-06_Maximizing-Parallelism-in/karras2012hpg_paper.pdf
113113 """
114114
115- def __init__ (self , aabb : AABB , max_n_query_result_per_aabb : int = 8 ):
115+ def __init__ (self , aabb : AABB , max_n_query_result_per_aabb : int = 8 , n_radix_sort_groups : int = 256 ):
116116 self .aabbs = aabb .aabbs
117117 self .n_aabbs = aabb .n_aabbs
118118 self .n_batches = aabb .n_batches
@@ -130,12 +130,18 @@ def __init__(self, aabb: AABB, max_n_query_result_per_aabb: int = 8):
130130 # Histogram for radix sort
131131 self .hist = ti .field (ti .u32 , shape = (self .n_batches , 256 ))
132132 # Prefix sum for histogram
133- self .prefix_sum = ti .field (ti .u32 , shape = (self .n_batches , 256 ))
133+ self .prefix_sum = ti .field (ti .u32 , shape = (self .n_batches , 256 + 1 ))
134134 # Offset for radix sort
135135 self .offset = ti .field (ti .u32 , shape = (self .n_batches , self .n_aabbs ))
136136 # Temporary storage for radix sort
137137 self .tmp_morton_codes = ti .field (ti .types .vector (2 , ti .u32 ), shape = (self .n_batches , self .n_aabbs ))
138138
139+ self .n_radix_sort_groups = n_radix_sort_groups
140+ self .hist_group = ti .field (ti .u32 , shape = (self .n_batches , self .n_radix_sort_groups , 256 + 1 ))
141+ self .prefix_sum_group = ti .field (ti .u32 , shape = (self .n_batches , self .n_radix_sort_groups + 1 , 256 ))
142+ self .group_size = self .n_aabbs // self .n_radix_sort_groups
143+ self .visited = ti .field (ti .u8 , shape = (self .n_aabbs ,))
144+
139145 @ti .dataclass
140146 class Node :
141147 """
@@ -176,6 +182,19 @@ def build(self):
176182 self .build_radix_tree ()
177183 self .compute_bounds ()
178184
@ti.func
def filter(self, i_a, i_q):
    """
    Decide whether a candidate hit must be dropped from the query results.

    Returning True discards the (found AABB, query AABB) pair; returning False keeps it.
    This base implementation ignores both arguments and always returns False, so no
    candidate is ever discarded. Subclasses can override it with their own rejection rule.

    i_a: index of the found AABB
    i_q: index of the query AABB
    """
    return False
197+
179198 @ti .kernel
180199 def compute_aabb_centers_and_scales (self ):
181200 for i_b , i_a in ti .ndrange (self .n_batches , self .n_aabbs ):
@@ -230,8 +249,12 @@ def radix_sort_morton_codes(self):
230249 """
231250 Radix sort the morton codes, using 8 bits at a time.
232251 """
233- for i in range (8 ):
234- self ._kernel_radix_sort_morton_codes_one_round (i )
252+ # The last 32 bits are the index of the AABB which are already sorted, no need to sort
253+ for i in range (4 , 8 ):
254+ if self .n_radix_sort_groups == 1 :
255+ self ._kernel_radix_sort_morton_codes_one_round (i )
256+ else :
257+ self ._kernel_radix_sort_morton_codes_one_round_group (i )
235258
236259 @ti .kernel
237260 def _kernel_radix_sort_morton_codes_one_round (self , i : int ):
@@ -262,6 +285,50 @@ def _kernel_radix_sort_morton_codes_one_round(self, i: int):
262285 for i_b , i_a in ti .ndrange (self .n_batches , self .n_aabbs ):
263286 self .morton_codes [i_b , i_a ] = self .tmp_morton_codes [i_b , i_a ]
264287
@ti.kernel
def _kernel_radix_sort_morton_codes_one_round_group(self, i: int):
    """
    Run one 8-bit radix sort pass over the morton codes, with the AABBs of each batch split into groups.

    Every batch is cut into `n_radix_sort_groups` contiguous groups of `group_size` AABBs, the last
    group also taking whatever remains. Each group fills its own 256-bin histogram, the histograms
    are combined by two prefix sums (across groups per bin, then across bins), and every morton code
    is finally scattered to its sorted slot and copied back into `morton_codes`.

    i: index of the byte to sort on. The byte is read from component `1 - (i // 4)` of the morton
       code, shifted right by `(i % 4) * 8` bits.
    """
    # Clear histogram
    self.hist_group.fill(0)

    # Fill the per-group histograms. Inside a group the AABBs are visited in index order, so the
    # count stored in `offset` is the rank of the AABB among the entries of its group that share
    # the same byte value.
    for i_b, i_g in ti.ndrange(self.n_batches, self.n_radix_sort_groups):
        start = i_g * self.group_size
        end = ti.select(
            i_g == self.n_radix_sort_groups - 1,
            self.n_aabbs,
            (i_g + 1) * self.group_size,
        )
        for i_a in range(start, end):
            code = (self.morton_codes[i_b, i_a][1 - (i // 4)] >> ((i % 4) * 8)) & 0xFF
            self.offset[i_b, i_a] = self.hist_group[i_b, i_g, code]
            self.hist_group[i_b, i_g, code] = self.hist_group[i_b, i_g, code] + 1

    # Prefix sum across groups, per byte value: prefix_sum_group[b, g, c] is the number of entries
    # with byte value c in groups 0..g-1, and row n_radix_sort_groups holds the total per value.
    for i_b, i_c in ti.ndrange(self.n_batches, 256):
        self.prefix_sum_group[i_b, 0, i_c] = 0
        for i_g in range(1, self.n_radix_sort_groups + 1):  # sequential prefix sum
            self.prefix_sum_group[i_b, i_g, i_c] = (
                self.prefix_sum_group[i_b, i_g - 1, i_c] + self.hist_group[i_b, i_g - 1, i_c]
            )

    # Prefix sum across byte values: prefix_sum[b, c] is the number of entries with a byte value
    # smaller than c.
    for i_b in range(self.n_batches):
        self.prefix_sum[i_b, 0] = 0
        for i_c in range(1, 256 + 1):  # sequential prefix sum
            self.prefix_sum[i_b, i_c] = (
                self.prefix_sum[i_b, i_c - 1] + self.prefix_sum_group[i_b, self.n_radix_sort_groups, i_c - 1]
            )

    # Reorder morton codes
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        code = (self.morton_codes[i_b, i_a][1 - (i // 4)] >> ((i % 4) * 8)) & 0xFF
        # group_size is 0 when there are fewer AABBs than groups. The histogram pass above then
        # assigns every AABB to the last group, so select that group directly instead of
        # dividing by zero. The branch is resolved at compile time.
        i_g = self.n_radix_sort_groups - 1
        if ti.static(self.group_size > 0):
            i_g = ti.min(i_a // self.group_size, self.n_radix_sort_groups - 1)
        # Destination = entries with a smaller byte value
        #             + entries with the same value in earlier groups
        #             + rank inside the own group
        idx = ti.i32(self.prefix_sum[i_b, code] + self.prefix_sum_group[i_b, i_g, code] + self.offset[i_b, i_a])
        self.tmp_morton_codes[i_b, idx] = self.morton_codes[i_b, i_a]

    # Copy the sorted codes from the temporary buffer back into the original one
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
331+
265332 @ti .kernel
266333 def build_radix_tree (self ):
267334 """
@@ -396,10 +463,13 @@ def query(self, aabbs: ti.template()):
396463 if aabbs [i_b , i_q ].intersects (node .bound ):
397464 # If it's a leaf node, add the AABB index to the query results
398465 if node .left == - 1 and node .right == - 1 :
466+ i_a = ti .i32 (self .morton_codes [i_b , node_idx - (self .n_aabbs - 1 )][1 ])
467+ # Check if the filter condition is met
468+ if self .filter (i_a , i_q ):
469+ continue
399470 idx = ti .atomic_add (self .query_result_count [None ], 1 )
400471 if idx < self .max_n_query_results :
401- code = self .morton_codes [i_b , node_idx - (self .n_aabbs - 1 )][1 ]
402- self .query_result [idx ] = gs .ti_ivec3 (i_b , ti .i32 (code ), i_q ) # Store the AABB index
472+ self .query_result [idx ] = gs .ti_ivec3 (i_b , i_a , i_q ) # Store the AABB index
403473 else :
404474 # Push children onto the stack
405475 if node .right != - 1 :
@@ -408,3 +478,35 @@ def query(self, aabbs: ti.template()):
408478 if node .left != - 1 :
409479 query_stack [stack_depth ] = node .left
410480 stack_depth += 1
481+
482+
@ti.data_oriented
class FEMSurfaceTetLBVH(LBVH):
    """
    Linear BVH specialised for the surface tetrahedra of an FEM solver.

    It reuses everything from LBVH and only replaces `filter`, so that query hits between
    tetrahedra that share a vertex are discarded.
    """

    def __init__(self, fem_solver, aabb: AABB, max_n_query_result_per_aabb: int = 8, n_radix_sort_groups: int = 256):
        """
        Build the underlying LBVH and remember the FEM solver used by `filter`.

        fem_solver: solver whose `surface_elements` and `elements_i` fields are read by `filter`
        aabb: AABBs to build the hierarchy from
        max_n_query_result_per_aabb: forwarded unchanged to LBVH
        n_radix_sort_groups: forwarded unchanged to LBVH
        """
        super().__init__(
            aabb=aabb,
            max_n_query_result_per_aabb=max_n_query_result_per_aabb,
            n_radix_sort_groups=n_radix_sort_groups,
        )
        self.fem_solver = fem_solver

    @ti.func
    def filter(self, i_a, i_q):
        """
        Rejection rule for FEM surface tets, used to avoid self-collisions.

        Returns True (the pair is dropped) when `i_a >= i_q`, or when the two tets have at least
        one vertex index in common; returns False otherwise.
        NOTE(review): the `i_a >= i_q` test presumably keeps each unordered pair only once - confirm.

        i_a: index of the found AABB
        i_q: index of the query AABB
        """
        # Vertex indices of the two tets, looked up through the surface element table
        verts_a = self.fem_solver.elements_i[self.fem_solver.surface_elements[i_a]].el2v
        verts_q = self.fem_solver.elements_i[self.fem_solver.surface_elements[i_q]].el2v

        skip = i_a >= i_q
        # Compare all 4 x 4 vertex combinations; both loops are unrolled at compile time
        for k_a in ti.static(range(4)):
            for k_q in ti.static(range(4)):
                if verts_a[k_a] == verts_q[k_q]:
                    skip = True
        return skip
0 commit comments