Skip to content

Commit e1cc023

Browse files
Libero0809 YilingQiao
authored and committed
[BUG FIX] Fix BVH build's radix sort. (Genesis-Embodied-AI#1305)
1 parent c43cf73 commit e1cc023

File tree

5 files changed

+112
-51
lines changed

5 files changed

+112
-51
lines changed

genesis/engine/bvh.py

Lines changed: 73 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import genesis as gs
22
import taichi as ti
33
from genesis.repr_base import RBC
4+
import numpy as np
45

56

67
@ti.data_oriented
@@ -157,19 +158,27 @@ class Node:
157158
# Nodes of the BVH, first n_aabbs - 1 are internal nodes, last n_aabbs are leaf nodes
158159
self.nodes = self.Node.field(shape=(self.n_batches, self.n_aabbs * 2 - 1))
159160
# Whether an internal node has been visited during traversal
160-
self.internal_node_visited = ti.field(ti.u8, shape=(self.n_batches, self.n_aabbs - 1))
161+
self.internal_node_active = ti.field(ti.u1, shape=(self.n_batches, self.n_aabbs - 1))
162+
self.internal_node_ready = ti.field(ti.u1, shape=(self.n_batches, self.n_aabbs - 1))
163+
self.updated = ti.field(ti.u1, shape=())
161164

162165
# Query results, vec3 of batch id, self id, query id
163166
self.query_result = ti.field(gs.ti_ivec3, shape=(self.max_n_query_results))
164167
# Count of query results
165168
self.query_result_count = ti.field(ti.i32, shape=())
166169

167-
@ti.kernel
168170
def build(self):
169171
"""
170172
Build the BVH from the axis-aligned bounding boxes (AABBs).
171173
"""
174+
self.compute_aabb_centers_and_scales()
175+
self.compute_morton_codes()
176+
self.radix_sort_morton_codes()
177+
self.build_radix_tree()
178+
self.compute_bounds()
172179

180+
@ti.kernel
181+
def compute_aabb_centers_and_scales(self):
173182
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
174183
self.aabb_centers[i_b, i_a] = (self.aabbs[i_b, i_a].min + self.aabbs[i_b, i_a].max) / 2
175184

@@ -184,14 +193,9 @@ def build(self):
184193
for i_b in ti.ndrange(self.n_batches):
185194
scale = self.aabb_max[i_b] - self.aabb_min[i_b]
186195
for i in ti.static(range(3)):
187-
self.scale[i_b][i] = ti.select(scale[i] > 1e-7, 1.0 / scale[i], 1)
196+
self.scale[i_b][i] = ti.select(scale[i] > gs.EPS, 1.0 / scale[i], 1.0)
188197

189-
self.compute_morton_codes()
190-
self.radix_sort_morton_codes()
191-
self.build_radix_tree()
192-
self.compute_bounds()
193-
194-
@ti.func
198+
@ti.kernel
195199
def compute_morton_codes(self):
196200
"""
197201
Compute the Morton codes for each AABB.
@@ -223,38 +227,43 @@ def expand_bits(self, v):
223227
v = (v * ti.u32(0x00000005)) & ti.u32(0x49249249)
224228
return v
225229

226-
@ti.func
227230
def radix_sort_morton_codes(self):
228231
"""
229232
Radix sort the morton codes, using 8 bits at a time.
230233
"""
231-
for i in ti.static(range(8)):
232-
# Clear histogram
233-
for i_b, j in ti.ndrange(self.n_batches, 256):
234-
self.hist[i_b, j] = 0
234+
for i in range(8):
235+
self._kernel_radix_sort_morton_codes_one_round(i)
235236

236-
# Fill histogram
237-
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
237+
@ti.kernel
238+
def _kernel_radix_sort_morton_codes_one_round(self, i: int):
239+
# Clear histogram
240+
self.hist.fill(0)
241+
242+
# Fill histogram
243+
for i_b in range(self.n_batches):
244+
# This is now sequential
245+
# TODO Parallelize, need to use groups to handle data to remain stable, could be not worth it
246+
for i_a in range(self.n_aabbs):
238247
code = (self.morton_codes[i_b, i_a] >> (i * 8)) & 0xFF
239248
self.offset[i_b, i_a] = ti.atomic_add(self.hist[i_b, ti.i32(code)], 1)
240249

241-
# Compute prefix sum
242-
for i_b in ti.ndrange(self.n_batches):
243-
self.prefix_sum[i_b, 0] = 0
244-
for j in range(1, 256): # sequential prefix sum
245-
self.prefix_sum[i_b, j] = self.prefix_sum[i_b, j - 1] + self.hist[i_b, j - 1]
250+
# Compute prefix sum
251+
for i_b in ti.ndrange(self.n_batches):
252+
self.prefix_sum[i_b, 0] = 0
253+
for j in range(1, 256): # sequential prefix sum
254+
self.prefix_sum[i_b, j] = self.prefix_sum[i_b, j - 1] + self.hist[i_b, j - 1]
246255

247-
# Reorder morton codes
248-
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
249-
code = (self.morton_codes[i_b, i_a] >> (i * 8)) & 0xFF
250-
idx = ti.i32(self.offset[i_b, i_a] + self.prefix_sum[i_b, ti.i32(code)])
251-
self.tmp_morton_codes[i_b, idx] = self.morton_codes[i_b, i_a]
256+
# Reorder morton codes
257+
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
258+
code = (self.morton_codes[i_b, i_a] >> (i * 8)) & 0xFF
259+
idx = ti.i32(self.offset[i_b, i_a] + self.prefix_sum[i_b, ti.i32(code)])
260+
self.tmp_morton_codes[i_b, idx] = self.morton_codes[i_b, i_a]
252261

253-
# Swap the temporary and original morton codes
254-
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
255-
self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
262+
# Swap the temporary and original morton codes
263+
for i_b, i_a in ti.ndrange(self.n_batches, self.n_aabbs):
264+
self.morton_codes[i_b, i_a] = self.tmp_morton_codes[i_b, i_a]
256265

257-
@ti.func
266+
@ti.kernel
258267
def build_radix_tree(self):
259268
"""
260269
Build the radix tree from the sorted morton codes.
@@ -321,31 +330,51 @@ def delta(self, i, j, i_b):
321330
break
322331
return result
323332

324-
@ti.func
325333
def compute_bounds(self):
326334
"""
327335
Compute the bounds of the BVH nodes.
328336
329-
Starts from the leaf nodes and works upwards.
337+
Starts from the leaf nodes and works upwards layer by layer.
330338
"""
331-
for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs - 1):
332-
self.internal_node_visited[i_b, i] = ti.u8(0)
339+
self._kernel_compute_bounds_init()
340+
while self.updated[None]:
341+
self._kernel_compute_bounds_one_layer()
342+
343+
@ti.kernel
344+
def _kernel_compute_bounds_init(self):
345+
self.updated[None] = True
346+
self.internal_node_active.fill(0)
347+
self.internal_node_ready.fill(0)
333348

334349
for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs):
335350
idx = ti.i32(self.morton_codes[i_b, i])
336351
self.nodes[i_b, i + self.n_aabbs - 1].bound.min = self.aabbs[i_b, idx].min
337352
self.nodes[i_b, i + self.n_aabbs - 1].bound.max = self.aabbs[i_b, idx].max
353+
parent_idx = self.nodes[i_b, i + self.n_aabbs - 1].parent
354+
if parent_idx != -1:
355+
self.internal_node_active[i_b, parent_idx] = 1
338356

339-
cur_idx = self.nodes[i_b, i + self.n_aabbs - 1].parent
340-
while cur_idx != -1:
341-
visited = ti.u1(ti.atomic_or(self.internal_node_visited[i_b, cur_idx], ti.u8(1)))
342-
if not visited:
343-
break
344-
left_bound = self.nodes[i_b, self.nodes[i_b, cur_idx].left].bound
345-
right_bound = self.nodes[i_b, self.nodes[i_b, cur_idx].right].bound
346-
self.nodes[i_b, cur_idx].bound.min = ti.min(left_bound.min, right_bound.min)
347-
self.nodes[i_b, cur_idx].bound.max = ti.max(left_bound.max, right_bound.max)
348-
cur_idx = self.nodes[i_b, cur_idx].parent
357+
@ti.kernel
358+
def _kernel_compute_bounds_one_layer(self):
359+
self.updated[None] = False
360+
for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs - 1):
361+
if self.internal_node_active[i_b, i] == 0:
362+
continue
363+
left_bound = self.nodes[i_b, self.nodes[i_b, i].left].bound
364+
right_bound = self.nodes[i_b, self.nodes[i_b, i].right].bound
365+
self.nodes[i_b, i].bound.min = ti.min(left_bound.min, right_bound.min)
366+
self.nodes[i_b, i].bound.max = ti.max(left_bound.max, right_bound.max)
367+
parent_idx = self.nodes[i_b, i].parent
368+
if parent_idx != -1:
369+
self.internal_node_ready[i_b, parent_idx] = 1
370+
self.internal_node_active[i_b, i] = 0
371+
self.updated[None] = True
372+
373+
for i_b, i in ti.ndrange(self.n_batches, self.n_aabbs - 1):
374+
if self.internal_node_ready[i_b, i] == 0:
375+
continue
376+
self.internal_node_active[i_b, i] = 1
377+
self.internal_node_ready[i_b, i] = 0
349378

350379
@ti.kernel
351380
def query(self, aabbs: ti.template()):

genesis/engine/coupler.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -649,6 +649,7 @@ def __init__(
649649
self._n_linesearch_iterations = options.n_linesearch_iterations
650650
self._linesearch_c = options.linesearch_c
651651
self._linesearch_tau = options.linesearch_tau
652+
self.default_deformable_g = 1.0e8 # default deformable geometry size
652653

653654
def build(self) -> None:
654655
self._B = self.sim._B
@@ -698,6 +699,7 @@ def init_fem_fields(self):
698699
self.max_fem_floor_contact_pairs = fem_solver.n_surfaces * fem_solver._B
699700
self.n_fem_floor_contact_pairs = ti.field(gs.ti_int, shape=())
700701
self.fem_floor_contact_pairs = self.fem_floor_contact_pair_type.field(shape=(self.max_fem_floor_contact_pairs,))
702+
701703
# Lookup table for marching tetrahedra edges
702704
kMarchingTetsEdgeTable_np = np.array(
703705
[
@@ -934,15 +936,12 @@ def fem_floor_detection(self, f: ti.i32):
934936
)
935937
self.fem_floor_contact_pairs[i_c].barycentric = barycentric
936938

937-
C = ti.static(1.0e8)
938-
deformable_g = C
939939
rigid_g = self.fem_pressure_gradient[i_b, i_e].z
940940
# TODO A better way to handle corner cases where pressure and pressure gradient are ill defined
941941
if total_area < gs.EPS or rigid_g < gs.EPS:
942942
self.fem_floor_contact_pairs[i_c].active = 0
943943
continue
944-
g = 1.0 / (1.0 / deformable_g + 1.0 / rigid_g) # harmonic average
945-
deformable_k = total_area * C
944+
g = self.default_deformable_g * rigid_g / (self.default_deformable_g + rigid_g) # harmonic average
946945
rigid_k = total_area * g
947946
rigid_phi0 = -pressure / g
948947
rigid_fn0 = total_area * pressure

genesis/engine/materials/FEM/elastic.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def build_linear_corotated(self, fem_solver):
8282

8383
@ti.func
8484
def pre_compute_linear_corotated(self, J, F, i_e, i_b):
85+
# Computing Polar Decomposition instead of calling `R, P = ti.polar_decompose(F)` since `P` is not needed here
8586
U, S, V = ti.svd(F)
8687
R = U @ V.transpose()
8788
self.R[i_b, i_e] = R

genesis/engine/solvers/fem_solver.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,32 @@ def _func_compute_ele_energy(self, f: ti.i32):
528528

529529
self.elements_el_energy[i_b, i_e].energy += 0.5 * damping_beta_over_dt * St_x_diff.dot(H_St_x_diff)
530530

531+
# add linearized damping energy
532+
if self._damping_beta > gs.EPS:
533+
damping_beta_over_dt = self._damping_beta / self._substep_dt
534+
i_v = self.elements_i[i_e].el2v
535+
S = ti.Matrix.zero(gs.ti_float, 4, 3)
536+
B = self.elements_i[i_e].B
537+
S[:3, :] = B
538+
S[3, :] = -B[0, :] - B[1, :] - B[2, :]
539+
540+
x_diff = ti.Vector.zero(gs.ti_float, 12)
541+
for i in ti.static(range(4)):
542+
x_diff[i * 3 : i * 3 + 3] = (
543+
self.elements_v[f + 1, i_v[i], i_b].pos - self.elements_v[f, i_v[i], i_b].pos
544+
)
545+
St_x_diff = ti.Vector.zero(gs.ti_float, 9)
546+
for i, j in ti.static(ti.ndrange(3, 4)):
547+
St_x_diff[i * 3 : i * 3 + 3] += S[j, i] * x_diff[j * 3 : j * 3 + 3]
548+
549+
H_St_x_diff = ti.Vector.zero(gs.ti_float, 9)
550+
for i, j in ti.static(ti.ndrange(3, 3)):
551+
H_St_x_diff[i * 3 : i * 3 + 3] += (
552+
self.elements_el_hessian[i_b, i, j, i_e] @ St_x_diff[j * 3 : j * 3 + 3]
553+
)
554+
555+
self.elements_el_energy[i_b, i_e].energy += 0.5 * damping_beta_over_dt * St_x_diff.dot(H_St_x_diff)
556+
531557
@ti.kernel
532558
def accumulate_vertex_force_preconditioner(self, f: ti.i32):
533559
damping_alpha_dt = self._damping_alpha * self._substep_dt

tests/test_bvh.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,16 @@
1313
def lbvh():
1414
"""Fixture for a LBVH tree"""
1515

16-
n_aabbs = 20
16+
n_aabbs = 500
1717
n_batches = 10
1818
aabb = AABB(n_batches=n_batches, n_aabbs=n_aabbs)
19-
min = np.random.rand(n_batches, n_aabbs, 3).astype(np.float32)
19+
min = np.random.rand(n_batches, n_aabbs, 3).astype(np.float32) * 20.0
2020
max = min + np.random.rand(n_batches, n_aabbs, 3).astype(np.float32)
2121

2222
aabb.aabbs.min.from_numpy(min)
2323
aabb.aabbs.max.from_numpy(max)
2424

25-
lbvh = LBVH(aabb)
25+
lbvh = LBVH(aabb, max_n_query_result_per_aabb=32)
2626
lbvh.build()
2727
return lbvh
2828

@@ -70,6 +70,7 @@ def test_expand_bits():
7070
), f"Expected {str_expanded_x}, got {''.join(f'00{bit}' for bit in str_x)}"
7171

7272

73+
@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
7374
def test_build_tree(lbvh):
7475
nodes = lbvh.nodes.to_numpy()
7576
n_aabbs = lbvh.n_aabbs
@@ -116,13 +117,18 @@ def test_build_tree(lbvh):
116117
assert_allclose(parent_max, parent_max_expected, atol=1e-6, rtol=1e-5)
117118

118119

120+
@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
119121
def test_query(lbvh):
120122
aabbs = lbvh.aabbs
121123

122124
# Query the tree
123125
lbvh.query(aabbs)
124126

125127
query_result_count = lbvh.query_result_count.to_numpy()
128+
if query_result_count > lbvh.max_n_query_results:
129+
raise ValueError(
130+
f"Query result count {query_result_count} exceeds max_n_query_results {lbvh.max_n_query_results}"
131+
)
126132
query_result = lbvh.query_result.to_numpy()
127133

128134
n_aabbs = lbvh.n_aabbs

0 commit comments

Comments
 (0)