|
1 | 1 | from typing import TYPE_CHECKING |
2 | 2 | import numpy as np |
3 | 3 | import taichi as ti |
| 4 | +import torch |
4 | 5 |
|
5 | 6 | import genesis as gs |
6 | 7 | from genesis.engine.entities.base_entity import Entity |
@@ -106,8 +107,8 @@ def __init__( |
106 | 107 | self._steps_local = options._steps_local |
107 | 108 |
|
108 | 109 | self._cur_substep_global = 0 |
109 | | - self._g_np = np.asarray(options.gravity, dtype=np.float32) |
110 | | - self._g_ti = None |
| 110 | + self._g = torch.as_tensor(options.gravity, dtype=torch.float32).clone() |
| 111 | + self._g_ti: ti.Field | None = None |
111 | 112 |
|
112 | 113 | # solvers |
113 | 114 | self.tool_solver = ToolSolver(self.scene, self, self.tool_options) |
@@ -189,10 +190,11 @@ def build(self): |
189 | 190 | self._B = self.scene._B |
190 | 191 | self._para_level = self.scene._para_level |
191 | 192 |
|
192 | | - if self._g_np.ndim == 1: |
193 | | - self._g_np = np.repeat(self._g_np[None], self._B, axis=0) |
| 193 | + g_np = self._g.numpy() |
| 194 | + if g_np.ndim == 1: |
| 195 | + g_np = np.repeat(g_np[None], self._B, axis=0) |
194 | 196 | self._g_ti = ti.Vector.field(3, gs.ti_float, shape=self._B) |
195 | | - self._g_ti.from_numpy(self._g_np) |
| 197 | + self._g_ti.from_numpy(g_np) |
196 | 198 |
|
197 | 199 | # solvers |
198 | 200 | self._rigid_only = self.rigid_solver.is_active() |
@@ -425,7 +427,19 @@ def scene(self): |
425 | 427 | @property |
426 | 428 | def gravity(self): |
427 | 429 | """The gravity vector.""" |
428 | | - return self._g_ti |
| 430 | + return self._g_ti if self._g_ti is not None else self._g |
| 431 | + |
| 432 | + @gravity.setter |
| 433 | + def gravity(self, new_g): |
| 434 | + """Set the gravity vector for the simulator.""" |
| 435 | + # store as torch tensor |
| 436 | + self._g = torch.as_tensor(new_g, dtype=gs.tc_float) |
| 437 | + # if we've already built, update the Taichi field in-place |
| 438 | + if self._g_ti is not None: |
| 439 | + g_np = self._g.numpy() |
| 440 | + if g_np.ndim == 1: |
| 441 | + g_np = np.repeat(g_np[None], self._B, axis=0) |
| 442 | + self._g_ti.from_numpy(g_np) |
429 | 443 |
|
430 | 444 | @property |
431 | 445 | def requires_grad(self): |
|
0 commit comments