Skip to content

Commit 5c58597

Browse files
authored
Merge branch 'main' into czh/render
2 parents 4ffabbf + d201735 commit 5c58597

28 files changed

+5251
-2825
lines changed

genesis/engine/bvh.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class LBVH(RBC):
112112
https://research.nvidia.com/sites/default/files/pubs/2012-06_Maximizing-Parallelism-in/karras2012hpg_paper.pdf
113113
"""
114114

115-
def __init__(self, aabb: AABB, max_n_query_result_per_aabb: int = 8):
115+
def __init__(self, aabb: AABB, max_n_query_result_per_aabb: int = 8, n_radix_sort_groups: int = 256):
116116
self.aabbs = aabb.aabbs
117117
self.n_aabbs = aabb.n_aabbs
118118
self.n_batches = aabb.n_batches
@@ -130,12 +130,18 @@ def __init__(self, aabb: AABB, max_n_query_result_per_aabb: int = 8):
130130
# Histogram for radix sort
131131
self.hist = ti.field(ti.u32, shape=(self.n_batches, 256))
132132
# Prefix sum for histogram
133-
self.prefix_sum = ti.field(ti.u32, shape=(self.n_batches, 256))
133+
self.prefix_sum = ti.field(ti.u32, shape=(self.n_batches, 256 + 1))
134134
# Offset for radix sort
135135
self.offset = ti.field(ti.u32, shape=(self.n_batches, self.n_aabbs))
136136
# Temporary storage for radix sort
137137
self.tmp_morton_codes = ti.field(ti.types.vector(2, ti.u32), shape=(self.n_batches, self.n_aabbs))
138138

139+
self.n_radix_sort_groups = n_radix_sort_groups
140+
self.hist_group = ti.field(ti.u32, shape=(self.n_batches, self.n_radix_sort_groups, 256 + 1))
141+
self.prefix_sum_group = ti.field(ti.u32, shape=(self.n_batches, self.n_radix_sort_groups + 1, 256))
142+
self.group_size = self.n_aabbs // self.n_radix_sort_groups
143+
self.visited = ti.field(ti.u8, shape=(self.n_aabbs,))
144+
139145
@ti.dataclass
140146
class Node:
141147
"""
@@ -176,6 +182,19 @@ def build(self):
176182
self.build_radix_tree()
177183
self.compute_bounds()
178184

185+
@ti.func
def filter(self, i_a, i_q):
    """
    Decide whether a candidate (found AABB, query AABB) pair should be dropped.

    The base implementation keeps every pair: it unconditionally returns False.
    `query` calls this hook for each intersecting leaf and skips the leaf when the
    hook returns True, so subclasses can override it to reject specific pairs.

    Parameters
    ----------
    i_a:
        Index of the AABB found in the tree.
    i_q:
        Index of the query AABB.

    Returns
    -------
    False always, i.e. nothing is filtered out.
    """
    return False
197+
179198
@ti.kernel
180199
def compute_aabb_centers_and_scales(self):
181200
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
@@ -230,8 +249,12 @@ def radix_sort_morton_codes(self):
230249
"""
231250
Radix sort the morton codes, using 8 bits at a time.
232251
"""
233-
for i in range(8):
234-
self._kernel_radix_sort_morton_codes_one_round(i)
252+
# The last 32 bits hold the AABB index, which is already sorted, so only the upper 32 bits (rounds 4-7) need sorting
253+
for i in range(4, 8):
254+
if self.n_radix_sort_groups == 1:
255+
self._kernel_radix_sort_morton_codes_one_round(i)
256+
else:
257+
self._kernel_radix_sort_morton_codes_one_round_group(i)
235258

236259
@ti.kernel
237260
def _kernel_radix_sort_morton_codes_one_round(self, i: int):
@@ -262,6 +285,50 @@ def _kernel_radix_sort_morton_codes_one_round(self, i: int):
262285
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
263286
self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
264287

288+
@ti.kernel
def _kernel_radix_sort_morton_codes_one_round_group(self, i: int):
    """
    Run one 8-bit round of the radix sort, with the AABBs of each batch split into groups.

    Round `i` sorts on byte `i % 4` of component `1 - (i // 4)` of each morton code entry.
    Every group covers `self.group_size` consecutive entries; the last group additionally
    takes the remainder up to `self.n_aabbs`.

    Steps:
      1. Per-group histogram of the current byte. `self.offset` records, for every entry,
         how many entries with the same byte value precede it inside its own group.
      2. `self.prefix_sum_group[b, g, c]`: number of entries with byte value `c` in the
         groups before `g`. `self.prefix_sum[b, c]`: number of entries in the batch whose
         byte value is smaller than `c`.
      3. Scatter each entry to `prefix_sum + prefix_sum_group + offset`, then copy the
         temporary buffer back into `self.morton_codes`.

    Parameters
    ----------
    i: int
        Index of the round (which byte of the 64-bit key is being sorted on).
    """
    # Clear histogram
    self.hist_group.fill(0)

    # Fill histogram
    for i_b, i_g in ti.ndrange(self.n_batches, self.n_radix_sort_groups):
        start = i_g * self.group_size
        # The last group also owns the remainder when n_aabbs is not a multiple of group_size
        end = ti.select(
            i_g == self.n_radix_sort_groups - 1,
            self.n_aabbs,
            (i_g + 1) * self.group_size,
        )
        for i_a in range(start, end):
            code = (self.morton_codes[i_b, i_a][1 - (i // 4)] >> ((i % 4) * 8)) & 0xFF
            # Rank of this entry among the entries of its group sharing the same byte value
            self.offset[i_b, i_a] = self.hist_group[i_b, i_g, code]
            self.hist_group[i_b, i_g, code] = self.hist_group[i_b, i_g, code] + 1

    # Compute prefix sum
    for i_b, i_c in ti.ndrange(self.n_batches, 256):
        self.prefix_sum_group[i_b, 0, i_c] = 0
        for i_g in range(1, self.n_radix_sort_groups + 1):  # sequential prefix sum
            self.prefix_sum_group[i_b, i_g, i_c] = (
                self.prefix_sum_group[i_b, i_g - 1, i_c] + self.hist_group[i_b, i_g - 1, i_c]
            )
    for i_b in range(self.n_batches):
        self.prefix_sum[i_b, 0] = 0
        for i_c in range(1, 256 + 1):  # sequential prefix sum
            # prefix_sum_group[..., n_radix_sort_groups, c] is the batch-wide count of byte value c
            self.prefix_sum[i_b, i_c] = (
                self.prefix_sum[i_b, i_c - 1] + self.prefix_sum_group[i_b, self.n_radix_sort_groups, i_c - 1]
            )

    # Reorder morton codes
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        code = (self.morton_codes[i_b, i_a][1 - (i // 4)] >> ((i % 4) * 8)) & 0xFF
        # Default to the last group. When group_size == 0 (n_aabbs < n_radix_sort_groups) the
        # histogram step above assigns every entry to that group, and the compile-time guard
        # keeps the division by zero out of the kernel.
        i_g = self.n_radix_sort_groups - 1
        if ti.static(self.group_size > 0):
            i_g = ti.min(i_a // self.group_size, self.n_radix_sort_groups - 1)
        # Use the group prefix sum to find the correct index
        idx = ti.i32(self.prefix_sum[i_b, code] + self.prefix_sum_group[i_b, i_g, code] + self.offset[i_b, i_a])
        self.tmp_morton_codes[i_b, idx] = self.morton_codes[i_b, i_a]

    # Swap the temporary and original morton codes
    for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
        self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
331+
265332
@ti.kernel
266333
def build_radix_tree(self):
267334
"""
@@ -396,10 +463,13 @@ def query(self, aabbs: ti.template()):
396463
if aabbs[i_b, i_q].intersects(node.bound):
397464
# If it's a leaf node, add the AABB index to the query results
398465
if node.left == -1 and node.right == -1:
466+
i_a = ti.i32(self.morton_codes[i_b, node_idx - (self.n_aabbs - 1)][1])
467+
# Check if the filter condition is met
468+
if self.filter(i_a, i_q):
469+
continue
399470
idx = ti.atomic_add(self.query_result_count[None], 1)
400471
if idx < self.max_n_query_results:
401-
code = self.morton_codes[i_b, node_idx - (self.n_aabbs - 1)][1]
402-
self.query_result[idx] = gs.ti_ivec3(i_b, ti.i32(code), i_q) # Store the AABB index
472+
self.query_result[idx] = gs.ti_ivec3(i_b, i_a, i_q) # Store the AABB index
403473
else:
404474
# Push children onto the stack
405475
if node.right != -1:
@@ -408,3 +478,35 @@ def query(self, aabbs: ti.template()):
408478
if node.left != -1:
409479
query_stack[stack_depth] = node.left
410480
stack_depth += 1
481+
482+
483+
@ti.data_oriented
class FEMSurfaceTetLBVH(LBVH):
    """
    Linear BVH specialised for the surface tetrahedra of an FEM solver.

    It behaves like LBVH, but overrides `filter` so that query results between
    surface tetrahedra that share a vertex are discarded.
    """

    def __init__(self, fem_solver, aabb: AABB, max_n_query_result_per_aabb: int = 8, n_radix_sort_groups: int = 256):
        super().__init__(
            aabb,
            max_n_query_result_per_aabb=max_n_query_result_per_aabb,
            n_radix_sort_groups=n_radix_sort_groups,
        )
        # Solver providing `surface_elements` and `elements_i[...].el2v`, read by `filter`
        self.fem_solver = fem_solver

    @ti.func
    def filter(self, i_a, i_q):
        """
        Reject pairs of FEM surface tetrahedra that must not be reported.

        A pair is rejected (True is returned) when either
          * `i_a >= i_q` (presumably so each unordered pair is reported only once
            and a tetrahedron is never paired with itself), or
          * the two tetrahedra have at least one vertex index in common.

        This is used to avoid self-collisions between FEM surface tetrahedra.

        i_a: index of the found AABB
        i_q: index of the query AABB
        """
        skip = i_a >= i_q
        verts_found = self.fem_solver.elements_i[self.fem_solver.surface_elements[i_a]].el2v
        verts_query = self.fem_solver.elements_i[self.fem_solver.surface_elements[i_q]].el2v
        # Compare all 4 x 4 vertex combinations; both loops are unrolled at compile time
        for k_found in ti.static(range(4)):
            for k_query in ti.static(range(4)):
                if verts_found[k_found] == verts_query[k_query]:
                    skip = True
        return skip

0 commit comments

Comments
 (0)