Skip to content

Commit b60d18c

Browse files
authored
[BUG FIX] Fix 'RigidGeom.get_(pos|quat)' invalid shape. (Genesis-Embodied-AI#2218)
1 parent 4fbc610 commit b60d18c

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

genesis/engine/entities/rigid_entity/rigid_geom.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,15 +361,15 @@ def get_pos(self, envs_idx=None):
361361
"""
362362
Get the position of the geom in world frame.
363363
"""
364-
tensor = ti_to_torch(self._solver.geoms_state.pos, envs_idx, self._idx, transpose=True, copy=True)
364+
tensor = ti_to_torch(self._solver.geoms_state.pos, envs_idx, self._idx, transpose=True, copy=True)[..., 0, :]
365365
return tensor[0] if self._solver.n_envs == 0 else tensor
366366

367367
@gs.assert_built
368368
def get_quat(self, envs_idx=None):
369369
"""
370370
Get the quaternion of the geom in world frame.
371371
"""
372-
tensor = ti_to_torch(self._solver.geoms_state.quat, envs_idx, self._idx, transpose=True, copy=True)
372+
tensor = ti_to_torch(self._solver.geoms_state.quat, envs_idx, self._idx, transpose=True, copy=True)[..., 0, :]
373373
return tensor[0] if self._solver.n_envs == 0 else tensor
374374

375375
@gs.assert_built
@@ -874,15 +874,15 @@ def get_pos(self, envs_idx=None):
874874
"""
875875
Get the position of the geom in world frame.
876876
"""
877-
tensor = ti_to_torch(self._solver.vgeoms_state.pos, envs_idx, self._idx, transpose=True, copy=True)
877+
tensor = ti_to_torch(self._solver.vgeoms_state.pos, envs_idx, self._idx, transpose=True, copy=True)[..., 0, :]
878878
return tensor[0] if self._solver.n_envs == 0 else tensor
879879

880880
@gs.assert_built
881881
def get_quat(self, envs_idx=None):
882882
"""
883883
Get the quaternion of the geom in world frame.
884884
"""
885-
tensor = ti_to_torch(self._solver.vgeoms_state.quat, envs_idx, self._idx, transpose=True, copy=True)
885+
tensor = ti_to_torch(self._solver.vgeoms_state.quat, envs_idx, self._idx, transpose=True, copy=True)[..., 0, :]
886886
return tensor[0] if self._solver.n_envs == 0 else tensor
887887

888888
@gs.assert_built

genesis/utils/misc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,8 @@ def assign_indexed_tensor(
897897
value: "np.typing.ArrayLike",
898898
dim_names: tuple[str, ...] | list[str] | None = None,
899899
) -> None:
900+
if isinstance(tensor, np.ndarray):
901+
value = torch.as_tensor(value)
900902
try:
901903
tensor[indices] = value
902904
except (TypeError, RuntimeError):

tests/test_rigid_physics.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,9 +3157,17 @@ def must_cast(value, dtype):
31573157
(-1, n_envs, gs_robot.get_quat, gs_robot.set_quat, None),
31583158
(-1, -1, gs_robot.get_mass, gs_robot.set_mass, None),
31593159
(-1, -1, gs_robot.get_AABB, None, None),
3160+
(-1, -1, gs_robot.get_vAABB, None, None),
31603161
# LINK
3162+
(-1, -1, gs_link.get_pos, None, None),
3163+
(-1, -1, gs_link.get_quat, None, None),
31613164
(-1, -1, gs_link.get_mass, gs_link.set_mass, None),
31623165
(-1, -1, gs_link.get_AABB, None, None),
3166+
(-1, -1, gs_link.get_vAABB, None, None),
3167+
# GEOM
3168+
(-1, -1, gs_link.get_pos, None, None),
3169+
(-1, -1, gs_link.get_quat, None, None),
3170+
(-1, -1, gs_link.get_vAABB, None, None),
31633171
):
31643172
getter, spec = (getter_or_spec, None) if callable(getter_or_spec) else (None, getter_or_spec)
31653173

@@ -3387,7 +3395,8 @@ def test_extended_broadcasting():
33873395

33883396

33893397
@pytest.mark.required
3390-
def test_geom_pos_quat(show_viewer):
3398+
@pytest.mark.parametrize("n_envs", [0, 2])
3399+
def test_geom_pos_quat(n_envs, show_viewer):
33913400
scene = gs.Scene(
33923401
sim_options=gs.options.SimOptions(
33933402
gravity=(0.0, 0.0, -10.0),
@@ -3401,12 +3410,22 @@ def test_geom_pos_quat(show_viewer):
34013410
pos=(0.0, 0.0, 2.0),
34023411
)
34033412
)
3404-
scene.build()
3413+
scene.build(n_envs=n_envs)
3414+
batch_shape = (n_envs,) if n_envs > 0 else ()
3415+
3416+
box.set_dofs_position(np.random.rand(*batch_shape, 6))
3417+
scene.rigid_solver.update_vgeoms()
34053418

34063419
for link in box.links:
34073420
for vgeom, geom in zip(link.vgeoms, link.geoms):
3408-
assert_allclose(geom.get_pos(), vgeom.get_pos(), atol=gs.EPS)
3409-
assert_allclose(geom.get_quat(), vgeom.get_quat(), atol=gs.EPS)
3421+
geom_pos, geom_quat = geom.get_pos(), geom.get_quat()
3422+
assert geom_pos.shape == (*batch_shape, 3)
3423+
assert geom_quat.shape == (*batch_shape, 4)
3424+
vgeom_pos, vgeom_quat = vgeom.get_pos(), vgeom.get_quat()
3425+
assert vgeom_pos.shape == (*batch_shape, 3)
3426+
assert vgeom_quat.shape == (*batch_shape, 4)
3427+
assert_allclose(geom_pos, vgeom_pos, atol=gs.EPS)
3428+
assert_allclose(geom_quat, vgeom_quat, atol=gs.EPS)
34103429

34113430

34123431
@pytest.mark.required

0 commit comments

Comments
 (0)