Skip to content

Commit 0621861

Browse files
committed
Env wise gravity
1 parent 18a4761 commit 0621861

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

genesis/engine/simulator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def __init__(
106106
self._steps_local = options._steps_local
107107

108108
self._cur_substep_global = 0
109-
self._gravity = np.array(options.gravity)
109+
self._g_np = np.asarray(options.gravity, dtype=np.float32)
110+
self._g_ti = None
110111

111112
# solvers
112113
self.tool_solver = ToolSolver(self.scene, self, self.tool_options)
@@ -188,9 +189,15 @@ def build(self):
188189
self._B = self.scene._B
189190
self._para_level = self.scene._para_level
190191

192+
if self._g_np.ndim == 1:
193+
self._g_np = np.repeat(self._g_np[None], self._B, axis=0)
194+
self._g_ti = ti.Vector.field(3, gs.ti_float, shape=self._B)
195+
self._g_ti.from_numpy(self._g_np)
196+
191197
# solvers
192198
self._rigid_only = self.rigid_solver.is_active()
193199
for solver in self._solvers:
200+
solver._finalize_batch(self._B)
194201
solver.build()
195202
if solver.is_active():
196203
self._active_solvers.append(solver)
@@ -418,7 +425,7 @@ def scene(self):
418425
@property
419426
def gravity(self):
420427
"""The gravity vector."""
421-
return self._gravity
428+
return self._g_ti
422429

423430
@property
424431
def requires_grad(self):

genesis/engine/solvers/base_solver.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@ def __init__(self, scene: "Scene", sim: "Simulator", options):
1717
self._uid = gs.UID()
1818
self._sim = sim
1919
self._scene = scene
20-
self._dt: float = options.dt
21-
self._substep_dt: float = options.dt / sim.substeps
20+
self._dt = options.dt
21+
self._substep_dt = options.dt / sim.substeps
2222

23-
if hasattr(options, "gravity"):
24-
self._gravity = ti.field(dtype=gs.ti_vec3, shape=())
25-
self._gravity.from_numpy(np.array(options.gravity, dtype=gs.np_float))
26-
else:
27-
self._gravity = None
23+
self._gravity_cfg = np.asarray(options.gravity, gs.np_float) if hasattr(options, "gravity") else None
24+
self._gravity = None
2825

2926
self._entities: list[Entity] = gs.List()
3027

@@ -34,6 +31,15 @@ def __init__(self, scene: "Scene", sim: "Simulator", options):
3431
def _add_force_field(self, force_field):
3532
self._ffs.append(force_field)
3633

34+
def _finalize_batch(self, B: int):
35+
if self._gravity_cfg is None or self._gravity is not None:
36+
return
37+
self._gravity = ti.field(dtype=gs.ti_vec3, shape=(B,))
38+
g = self._gravity_cfg
39+
if g.ndim == 1:
40+
g = np.tile(g[None, :], (B, 1))
41+
self._gravity.from_numpy(g[..., None])
42+
3743
# ------------------------------------------------------------------------------------
3844
# ----------------------------------- properties -------------------------------------
3945
# ------------------------------------------------------------------------------------

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2932,9 +2932,7 @@ def _func_update_acc(self, update_cacc: ti.template()):
29322932
i_p = self.links_info[I_l].parent_idx
29332933

29342934
if i_p == -1:
2935-
self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * (
2936-
1 - e_info.gravity_compensation
2937-
)
2935+
self.links_state[i_l, i_b].cdd_vel = -self._gravity[i_b] * (1 - e_info.gravity_compensation)
29382936
self.links_state[i_l, i_b].cdd_ang = ti.Vector.zero(gs.ti_float, 3)
29392937
if ti.static(update_cacc):
29402938
self.links_state[i_l, i_b].cacc_lin = ti.Vector.zero(gs.ti_float, 3)
@@ -2971,7 +2969,7 @@ def _func_update_acc(self, update_cacc: ti.template()):
29712969
i_p = self.links_info[I_l].parent_idx
29722970

29732971
if i_p == -1:
2974-
self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * (1 - e_info.gravity_compensation)
2972+
self.links_state[i_l, i_b].cdd_vel = -self._gravity[i_b] * (1 - e_info.gravity_compensation)
29752973
self.links_state[i_l, i_b].cdd_ang = ti.Vector.zero(gs.ti_float, 3)
29762974
if ti.static(update_cacc):
29772975
self.links_state[i_l, i_b].cacc_lin = ti.Vector.zero(gs.ti_float, 3)
@@ -4620,7 +4618,7 @@ def _kernel_get_links_acc(
46204618
# Mimick IMU accelerometer signal if requested
46214619
if mimick_imu:
46224620
# Subtract gravity
4623-
acc_classic_lin -= self._gravity[None]
4621+
acc_classic_lin -= self._gravity[i_b]
46244622

46254623
# Move the resulting linear acceleration in local links frame
46264624
acc_classic_lin = gu.ti_inv_transform_by_quat(acc_classic_lin, self.links_state[i_l, i_b].quat)

0 commit comments

Comments
 (0)