Skip to content

Commit ebcd09e

Browse files
committed
gravity is now a torch tensor and unit test is a sphere
1 parent c1c189a commit ebcd09e

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

genesis/engine/simulator.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import TYPE_CHECKING
22
import numpy as np
33
import taichi as ti
4+
import torch
45

56
import genesis as gs
67
from genesis.engine.entities.base_entity import Entity
@@ -106,8 +107,8 @@ def __init__(
106107
self._steps_local = options._steps_local
107108

108109
self._cur_substep_global = 0
109-
self._g_np = np.asarray(options.gravity, dtype=np.float32)
110-
self._g_ti = None
110+
self._g = torch.as_tensor(options.gravity, dtype=torch.float32).clone()
111+
self._g_ti: ti.Field | None = None
111112

112113
# solvers
113114
self.tool_solver = ToolSolver(self.scene, self, self.tool_options)
@@ -189,10 +190,11 @@ def build(self):
189190
self._B = self.scene._B
190191
self._para_level = self.scene._para_level
191192

192-
if self._g_np.ndim == 1:
193-
self._g_np = np.repeat(self._g_np[None], self._B, axis=0)
193+
g_np = self._g.numpy()
194+
if g_np.ndim == 1:
195+
g_np = np.repeat(g_np[None], self._B, axis=0)
194196
self._g_ti = ti.Vector.field(3, gs.ti_float, shape=self._B)
195-
self._g_ti.from_numpy(self._g_np)
197+
self._g_ti.from_numpy(g_np)
196198

197199
# solvers
198200
self._rigid_only = self.rigid_solver.is_active()
@@ -425,7 +427,19 @@ def scene(self):
425427
@property
426428
def gravity(self):
427429
"""The gravity vector."""
428-
return self._g_ti
430+
return self._g_ti if self._g_ti is not None else self._g
431+
432+
@gravity.setter
433+
def gravity(self, new_g):
434+
"""Set the gravity vector for the simulator."""
435+
# store as torch tensor
436+
self._g = torch.as_tensor(new_g, dtype=gs.tc_float)
437+
# if we've already built, update the Taichi field in-place
438+
if self._g_ti is not None:
439+
g_np = self._g.numpy()
440+
if g_np.ndim == 1:
441+
g_np = np.repeat(g_np[None], self._B, axis=0)
442+
self._g_ti.from_numpy(g_np)
429443

430444
@property
431445
def requires_grad(self):

tests/test_rigid_physics.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2202,32 +2202,27 @@ def test_urdf_mimic(show_viewer, tol):
22022202

22032203
@pytest.mark.required
22042204
@pytest.mark.parametrize("backend", [gs.cpu])
2205-
def test_gravity(show_viewer):
2206-
base_rpm = 14468.429183500699
2205+
def test_gravity(show_viewer, tol):
22072206
scene = gs.Scene(
22082207
show_viewer=show_viewer,
22092208
sim_options=gs.options.SimOptions(
22102209
dt=0.01,
22112210
substeps=1,
2212-
gravity=[(0.0, 0.0, -9.8), (0.0, 0.0, -10.00)],
2211+
gravity=[(0.0, 0.0, -9.8), (0.0, 0.0, 9.8)],
22132212
),
22142213
)
22152214

2216-
drone = scene.add_entity(
2217-
gs.morphs.Drone(file="urdf/drones/cf2x.urdf", pos=(0, 0, 1.0)),
2218-
)
2215+
sphere = scene.add_entity(gs.morphs.Sphere())
22192216

22202217
scene.build(n_envs=2)
22212218

2222-
for _ in range(500):
2223-
drone.set_propellels_rpm([[base_rpm, base_rpm, base_rpm, base_rpm], [base_rpm, base_rpm, base_rpm, base_rpm]])
2219+
for _ in range(200):
22242220
scene.step()
22252221

2226-
first_pos = drone.get_dofs_position()[0, 2]
2227-
second_pos = drone.get_dofs_position()[1, 2]
2228-
assert_allclose(
2229-
second_pos, first_pos - 2.5, tol=scene.sim_options.dt
2230-
) # Relax the tolerance due to time integration's error
2222+
first_pos = sphere.get_dofs_position()[0, 2]
2223+
second_pos = sphere.get_dofs_position()[1, 2]
2224+
2225+
assert_allclose(first_pos * -1, second_pos, tol=tol)
22312226

22322227

22332228
@pytest.mark.required

0 commit comments

Comments
 (0)