Skip to content

Commit d88faa9

Browse files
authored
[BUG FIX] Fix gravity setters. (#1498)
1 parent 5d0b452 commit d88faa9

File tree

4 files changed

+37
-20
lines changed

4 files changed

+37
-20
lines changed

genesis/engine/simulator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,8 @@ def get_state(self):
403403

404404
def set_gravity(self, gravity, envs_idx=None):
405405
for solver in self._solvers:
406-
solver.set_gravity(gravity, envs_idx)
406+
if solver.is_active():
407+
solver.set_gravity(gravity, envs_idx)
407408

408409
# ------------------------------------------------------------------------------------
409410
# ----------------------------------- properties -------------------------------------

genesis/engine/solvers/base_solver.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,26 @@ def build(self):
4242
@gs.assert_built
4343
def set_gravity(self, gravity, envs_idx=None):
4444
if self._gravity is None:
45+
gs.logger.debug("Gravity is not defined, skipping `set_gravity`.")
4546
return
4647
g = np.asarray(gravity, dtype=gs.np_float)
4748
if envs_idx is None:
4849
if g.ndim == 1:
4950
g = np.tile(g, (self._B, 1))
51+
assert g.shape == (self._B, 3), "Input gravity array should match (n_envs, 3)"
5052
self._gravity.from_numpy(g)
5153
else:
52-
self._gravity[envs_idx] = g
54+
envs_idx = np.atleast_1d(np.array(envs_idx, dtype=gs.np_int))
55+
if g.ndim == 1:
56+
g = np.tile(g, (len(envs_idx), 1))
57+
assert g.shape == (len(envs_idx), 3), "Input gravity array should match (len(envs_idx), 3)"
58+
self._kernel_set_gravity(g, envs_idx)
59+
60+
@ti.kernel
61+
def _kernel_set_gravity(self, gravity: ti.types.ndarray(), envs_idx: ti.types.ndarray()):
62+
for i_b_ in range(envs_idx.shape[0]):
63+
for j in ti.static(range(3)):
64+
self._gravity[envs_idx[i_b_]][j] = gravity[i_b_, j]
5365

5466
def dump_ckpt_to_numpy(self) -> dict[str, np.ndarray]:
5567
arrays: dict[str, np.ndarray] = {}

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,16 +2194,9 @@ def update_vgeoms(self):
21942194

21952195
@gs.assert_built
21962196
def set_gravity(self, gravity, envs_idx=None):
2197-
if not hasattr(self, "_rigid_global_info"):
2198-
super().set_gravity(gravity, envs_idx)
2199-
return
2200-
g = np.asarray(gravity, dtype=gs.np_float)
2201-
if envs_idx is None:
2202-
if g.ndim == 1:
2203-
g = np.tile(g, (self._B, 1))
2204-
self._rigid_global_info.gravity.from_numpy(g)
2205-
else:
2206-
self._rigid_global_info.gravity[envs_idx] = g
2197+
super().set_gravity(gravity, envs_idx)
2198+
if hasattr(self, "_rigid_global_info"):
2199+
self._rigid_global_info.gravity.copy_from(self._gravity)
22072200

22082201
def rigid_entity_inverse_kinematics(
22092202
self,

tests/test_rigid_physics.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,18 +2177,29 @@ def test_gravity(show_viewer, tol):
21772177
)
21782178

21792179
sphere = scene.add_entity(gs.morphs.Sphere())
2180-
scene.build(n_envs=2)
2180+
scene.build(n_envs=3)
21812181

2182-
scene.sim.set_gravity(torch.tensor([0.0, 0.0, -9.8]), envs_idx=0)
2183-
scene.sim.set_gravity(torch.tensor([0.0, 0.0, 9.8]), envs_idx=1)
2182+
scene.sim.set_gravity(torch.tensor([0.0, 0.0, 0.0]))
2183+
scene.sim.set_gravity(torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]), envs_idx=[0, 1])
2184+
scene.sim.set_gravity(torch.tensor([0.0, 0.0, 3.0]), envs_idx=2)
21842185

2185-
for _ in range(200):
2186-
scene.step()
2186+
with np.testing.assert_raises(AssertionError):
2187+
scene.sim.set_gravity(torch.tensor([0.0, -10.0]))
2188+
2189+
with np.testing.assert_raises(AssertionError):
2190+
scene.sim.set_gravity(torch.tensor([[0.0, 0.0, -10.0], [0.0, 0.0, -10.0]]), envs_idx=1)
21872191

2188-
first_pos = sphere.get_dofs_position()[0, 2]
2189-
second_pos = sphere.get_dofs_position()[1, 2]
2192+
scene.step()
21902193

2191-
assert_allclose(first_pos * -1, second_pos, tol=tol)
2194+
assert_allclose(
2195+
[
2196+
[1.0, 0.0, 0.0],
2197+
[0.0, 2.0, 0.0],
2198+
[0.0, 0.0, 3.0],
2199+
],
2200+
sphere.get_links_acc().squeeze(),
2201+
tol=tol,
2202+
)
21922203

21932204

21942205
@pytest.mark.required

0 commit comments

Comments
 (0)