Skip to content

Commit 6947269

Browse files
committed
Env wise gravity
1 parent c4295b3 commit 6947269

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

genesis/engine/simulator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def __init__(
105105
self._steps_local = options._steps_local
106106

107107
self._cur_substep_global = 0
108-
self._gravity = np.array(options.gravity)
108+
self._g_np = np.asarray(options.gravity, dtype=np.float32)
109+
self._g_ti = None
109110

110111
# solvers
111112
self.tool_solver = ToolSolver(self.scene, self, self.tool_options)
@@ -184,9 +185,15 @@ def build(self):
184185
self._B = self.scene._B
185186
self._para_level = self.scene._para_level
186187

188+
if self._g_np.ndim == 1:
189+
self._g_np = np.repeat(self._g_np[None], self._B, axis=0)
190+
self._g_ti = ti.Vector.field(3, gs.ti_float, shape=self._B)
191+
self._g_ti.from_numpy(self._g_np)
192+
187193
# solvers
188194
self._rigid_only = self.rigid_solver.is_active()
189195
for solver in self._solvers:
196+
solver._finalize_batch(self._B)
190197
solver.build()
191198
if solver.is_active():
192199
self._active_solvers.append(solver)
@@ -411,7 +418,7 @@ def scene(self):
411418
@property
412419
def gravity(self):
413420
"""The gravity vector."""
414-
return self._gravity
421+
return self._g_ti
415422

416423
@property
417424
def requires_grad(self):

genesis/engine/solvers/base_solver.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@ def __init__(self, scene: "Scene", sim: "Simulator", options):
1717
self._uid = gs.UID()
1818
self._sim = sim
1919
self._scene = scene
20-
self._dt: float = options.dt
21-
self._substep_dt: float = options.dt / sim.substeps
20+
self._dt = options.dt
21+
self._substep_dt = options.dt / sim.substeps
2222

23-
if hasattr(options, "gravity"):
24-
self._gravity = ti.field(dtype=gs.ti_vec3, shape=())
25-
self._gravity.from_numpy(np.array(options.gravity, dtype=gs.np_float))
26-
else:
27-
self._gravity = None
23+
self._gravity_cfg = np.asarray(options.gravity, gs.np_float) if hasattr(options, "gravity") else None
24+
self._gravity = None
2825

2926
self._entities: list[Entity] = gs.List()
3027

@@ -34,6 +31,15 @@ def __init__(self, scene: "Scene", sim: "Simulator", options):
3431
def _add_force_field(self, force_field):
3532
self._ffs.append(force_field)
3633

34+
def _finalize_batch(self, B: int):
35+
if self._gravity_cfg is None or self._gravity is not None:
36+
return
37+
self._gravity = ti.field(dtype=gs.ti_vec3, shape=(B,))
38+
g = self._gravity_cfg
39+
if g.ndim == 1:
40+
g = np.tile(g[None, :], (B, 1))
41+
self._gravity.from_numpy(g[..., None])
42+
3743
# ------------------------------------------------------------------------------------
3844
# ----------------------------------- properties -------------------------------------
3945
# ------------------------------------------------------------------------------------

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2844,9 +2844,7 @@ def _func_update_acc(self, update_cacc: ti.template()):
28442844
i_p = self.links_info[I_l].parent_idx
28452845

28462846
if i_p == -1:
2847-
self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * (
2848-
1 - e_info.gravity_compensation
2849-
)
2847+
self.links_state[i_l, i_b].cdd_vel = -self._gravity[i_b] * (1 - e_info.gravity_compensation)
28502848
self.links_state[i_l, i_b].cdd_ang = ti.Vector.zero(gs.ti_float, 3)
28512849
if ti.static(update_cacc):
28522850
self.links_state[i_l, i_b].cacc_lin = ti.Vector.zero(gs.ti_float, 3)
@@ -2883,7 +2881,7 @@ def _func_update_acc(self, update_cacc: ti.template()):
28832881
i_p = self.links_info[I_l].parent_idx
28842882

28852883
if i_p == -1:
2886-
self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * (1 - e_info.gravity_compensation)
2884+
self.links_state[i_l, i_b].cdd_vel = -self._gravity[i_b] * (1 - e_info.gravity_compensation)
28872885
self.links_state[i_l, i_b].cdd_ang = ti.Vector.zero(gs.ti_float, 3)
28882886
if ti.static(update_cacc):
28892887
self.links_state[i_l, i_b].cacc_lin = ti.Vector.zero(gs.ti_float, 3)
@@ -4532,7 +4530,7 @@ def _kernel_get_links_acc(
45324530
# Mimick IMU accelerometer signal if requested
45334531
if mimick_imu:
45344532
# Subtract gravity
4535-
acc_classic_lin -= self._gravity[None]
4533+
acc_classic_lin -= self._gravity[i_b]
45364534

45374535
# Move the resulting linear acceleration in local links frame
45384536
acc_classic_lin = gu.ti_inv_transform_by_quat(acc_classic_lin, self.links_state[i_l, i_b].quat)

0 commit comments

Comments
 (0)