Skip to content

Commit 48c141e

Browse files
committed
fixed minor error
1 parent 0119348 commit 48c141e

File tree

3 files changed

+42
-31
lines changed

3 files changed

+42
-31
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,7 @@ def get_verts(self):
18791879
The vertices of the entity (using collision geoms).
18801880
"""
18811881

1882+
self._update_verts_for_geom()
18821883
if self.is_free:
18831884
tensor = torch.empty(
18841885
self._solver._batch_shape((self.n_verts, 3), True), dtype=gs.tc_float, device=gs.device
@@ -1891,22 +1892,20 @@ def get_verts(self):
18911892
self._kernel_get_fixed_verts(tensor)
18921893
return tensor
18931894

1894-
@ti.kernel
1895-
def _kernel_get_free_verts(self, tensor: ti.types.ndarray()):
1896-
for i_g_, i_b in ti.ndrange(self.n_geoms, self._solver._B):
1895+
@gs.assert_built
1896+
def _update_verts_for_geom(self):
1897+
for i_g_ in range(self.n_geoms):
18971898
i_g = i_g_ + self._geom_start
1898-
self._solver._func_update_verts_for_geom(i_g, i_b)
1899+
self._solver.update_verts_for_geom(i_g)
18991900

1901+
@ti.kernel
1902+
def _kernel_get_free_verts(self, tensor: ti.types.ndarray()):
19001903
for i, j, b in ti.ndrange(self.n_verts, 3, self._solver._B):
19011904
idx_vert = i + self._verts_state_start
19021905
tensor[b, i, j] = self._solver.free_verts_state.pos[idx_vert, b][j]
19031906

19041907
@ti.kernel
19051908
def _kernel_get_fixed_verts(self, tensor: ti.types.ndarray()):
1906-
for i_g_ in range(self.n_geoms):
1907-
i_g = i_g_ + self._geom_start
1908-
self._solver._func_update_verts_for_geom(i_g, 0)
1909-
19101909
for i, j in ti.ndrange(self.n_verts, 3):
19111910
idx_vert = i + self._verts_state_start
19121911
tensor[i, j] = self._solver.fixed_verts_state.pos[idx_vert][j]

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def get_verts(self):
294294
"""
295295
Get the vertices of the link's collision body (concatenation of all `link.geoms`) in the world frame.
296296
"""
297+
self._update_verts_for_geom()
297298
if self.is_free:
298299
tensor = torch.empty(
299300
self._solver._batch_shape((self.n_verts, 3), True), dtype=gs.tc_float, device=gs.device
@@ -306,22 +307,20 @@ def get_verts(self):
306307
self._kernel_get_fixed_verts(tensor)
307308
return tensor
308309

309-
@ti.kernel
310-
def _kernel_get_free_verts(self, tensor: ti.types.ndarray()):
311-
for i_g_, i_b in ti.ndrange(self.n_geoms, self._solver._B):
310+
@gs.assert_built
311+
def _update_verts_for_geom(self):
312+
for i_g_ in range(self.n_geoms):
312313
i_g = i_g_ + self._geom_start
313-
self._solver._func_update_verts_for_geom(i_g, i_b)
314+
self._solver.update_verts_for_geom(i_g)
314315

316+
@ti.kernel
317+
def _kernel_get_free_verts(self, tensor: ti.types.ndarray()):
315318
for i, j, b in ti.ndrange(self.n_verts, 3, self._solver._B):
316319
idx_vert = i + self._verts_state_start
317320
tensor[b, i, j] = self._solver.free_verts_state.pos[idx_vert, b][j]
318321

319322
@ti.kernel
320323
def _kernel_get_fixed_verts(self, tensor: ti.types.ndarray()):
321-
for i_g_ in range(self.n_geoms):
322-
i_g = i_g_ + self._geom_start
323-
self._solver._func_update_verts_for_geom(i_g, 0)
324-
325324
for i, j in ti.ndrange(self.n_verts, 3):
326325
idx_vert = i + self._verts_state_start
327326
tensor[i, j] = self._solver.fixed_verts_state.pos[idx_vert][j]

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4778,21 +4778,34 @@ def kernel_update_verts_for_geom(
47784778
):
47794779
_B = geoms_state.verts_updated.shape[1]
47804780
for i_b in range(_B):
4781-
if not geoms_state.verts_updated[i_g, i_b]:
4782-
if geoms_info.is_free[i_g]:
4783-
for i_v in range(geoms_info.vert_start[i_g], geoms_info.vert_end[i_g]):
4784-
verts_state_idx = verts_info.verts_state_idx[i_v]
4785-
free_verts_state.pos[verts_state_idx, i_b] = gu.ti_transform_by_trans_quat(
4786-
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
4787-
)
4788-
geoms_state.verts_updated[i_g, i_b] = 1
4789-
elif i_b == 0:
4790-
for i_v in range(geoms_info.vert_start[i_g], geoms_info.vert_end[i_g]):
4791-
verts_state_idx = verts_info.verts_state_idx[i_v]
4792-
fixed_verts_state.pos[verts_state_idx] = gu.ti_transform_by_trans_quat(
4793-
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
4794-
)
4795-
geoms_state.verts_updated[i_g, 0] = 1
4781+
func_update_verts_for_geom(i_g, i_b, geoms_state, geoms_info, verts_info, free_verts_state, fixed_verts_state)
4782+
4783+
4784+
@ti.func
4785+
def func_update_verts_for_geom(
4786+
i_g: ti.i32,
4787+
i_b: ti.i32,
4788+
geoms_state: array_class.GeomsState,
4789+
geoms_info: array_class.GeomsInfo,
4790+
verts_info: array_class.VertsInfo,
4791+
free_verts_state: array_class.FreeVertsState,
4792+
fixed_verts_state: array_class.FixedVertsState,
4793+
):
4794+
if not geoms_state.verts_updated[i_g, i_b]:
4795+
if geoms_info.is_free[i_g]:
4796+
for i_v in range(geoms_info.vert_start[i_g], geoms_info.vert_end[i_g]):
4797+
verts_state_idx = verts_info.verts_state_idx[i_v]
4798+
free_verts_state.pos[verts_state_idx, i_b] = gu.ti_transform_by_trans_quat(
4799+
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
4800+
)
4801+
geoms_state.verts_updated[i_g, i_b] = 1
4802+
elif i_b == 0:
4803+
for i_v in range(geoms_info.vert_start[i_g], geoms_info.vert_end[i_g]):
4804+
verts_state_idx = verts_info.verts_state_idx[i_v]
4805+
fixed_verts_state.pos[verts_state_idx] = gu.ti_transform_by_trans_quat(
4806+
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
4807+
)
4808+
geoms_state.verts_updated[i_g, 0] = 1
47964809

47974810

47984811
@ti.func

0 commit comments

Comments
 (0)