Skip to content

Commit 9d9e7af

Browse files
authored
[MISC] Improve reset (#1350)
1 parent b0f6479 commit 9d9e7af

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

genesis/engine/coupler.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import TYPE_CHECKING
2+
23
import numpy as np
34
import taichi as ti
45

@@ -83,14 +84,30 @@ def build(self) -> None:
8384
self._dx = 1 / 1024
8485
self._stencil_size = int(np.floor(self._dx / self.sph_solver.hash_grid_cell_size) + 2)
8586

86-
self.reset()
87+
self.reset(envs_idx=self.sim.scene._envs_idx)
8788

88-
def reset(self) -> None:
89+
def reset(self, envs_idx=None) -> None:
8990
if self._rigid_mpm and self.mpm_solver.enable_CPIC:
90-
self.mpm_rigid_normal.fill(0)
91+
if envs_idx is None:
92+
self.mpm_rigid_normal.fill(0)
93+
else:
94+
self._kernel_reset_mpm(envs_idx)
9195

9296
if self._rigid_sph:
93-
self.sph_rigid_normal.fill(0)
97+
if envs_idx is None:
98+
self.sph_rigid_normal.fill(0)
99+
else:
100+
self._kernel_reset_sph(envs_idx)
101+
102+
@ti.kernel
103+
def _kernel_reset_mpm(self, envs_idx: ti.types.ndarray()):
104+
for i_p, i_g, i_b_ in ti.ndrange(self.mpm_solver.n_particles, self.rigid_solver.n_geoms, envs_idx.shape[0]):
105+
self.mpm_rigid_normal[i_p, i_g, envs_idx[i_b_]] = 0.0
106+
107+
@ti.kernel
108+
def _kernel_reset_sph(self, envs_idx: ti.types.ndarray()):
109+
for i_p, i_g, i_b_ in ti.ndrange(self.sph_solver.n_particles, self.rigid_solver.n_geoms, envs_idx.shape[0]):
110+
self.sph_rigid_normal[i_p, i_g, envs_idx[i_b_]] = 0.0
94111

95112
@ti.func
96113
def _func_collide_with_rigid(self, f, pos_world, vel, mass, i_b):
@@ -667,7 +684,7 @@ def build(self) -> None:
667684
self.init_pcg_fields()
668685
self.init_linesearch_fields()
669686

670-
def reset(self):
687+
def reset(self, envs_idx=None) -> None:
671688
pass
672689

673690
def init_fem_fields(self):

genesis/engine/scene.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from genesis.engine.force_fields import ForceField
1313
from genesis.engine.materials.base import Material
1414
from genesis.engine.entities import Emitter
15+
from genesis.engine.states.solvers import SimState
1516
from genesis.engine.simulator import Simulator
1617
from genesis.options import (
1718
AvatarOptions,
@@ -695,23 +696,27 @@ def _parallelize(
695696
self._para_level = gs.PARA_LEVEL.ALL
696697

697698
@gs.assert_built
698-
def reset(self, state: dict | None = None, envs_idx=None):
699+
def reset(self, state: SimState | None = None, envs_idx=None):
699700
"""
700701
Resets the scene to its initial state.
701702
702703
Parameters
703704
----------
704-
state : dict | None
705-
The state to reset the scene to. If None, the scene will be reset to its initial state. If this is given, the scene's registerered initial state will be updated to this state.
705+
state : SimState | None
706+
The state to reset the scene to. If None, the scene will be reset to its initial state.
707+
If this is given, the scene's registerered initial state will be updated to this state.
708+
envs_idx : None | array_like, optional
709+
The indices of the environments. If None, all environments will be considered. Defaults to None.
706710
"""
707-
gs.logger.info(f"Resetting Scene ~~~<{self._uid}>~~~.")
708-
self._reset(state, envs_idx)
711+
gs.logger.debug(f"Resetting Scene ~~~<{self._uid}>~~~.")
712+
self._reset(state, envs_idx=envs_idx)
709713

710-
def _reset(self, state=None, envs_idx=None):
714+
def _reset(self, state: SimState | None = None, *, envs_idx=None):
711715
if self._is_built:
712716
if state is None:
713717
state = self._init_state
714718
else:
719+
assert isinstance(state, SimState), "state must be a SimState object"
715720
self._init_state = state
716721
self._sim.reset(state, envs_idx)
717722
else:

genesis/engine/simulator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .coupler import Coupler, SAPCoupler
2424
from .entities import HybridEntity
25+
from .solvers.base_solver import Solver
2526
from .solvers import (
2627
AvatarSolver,
2728
FEMSolver,
@@ -205,12 +206,12 @@ def build(self):
205206
if isinstance(entity, HybridEntity):
206207
entity.build()
207208

208-
def reset(self, state, envs_idx=None):
209+
def reset(self, state: SimState, envs_idx=None):
209210
for solver, solver_state in zip(self._solvers, state):
210-
solver.set_state(0, solver_state, envs_idx)
211+
if solver.n_entities > 0:
212+
solver.set_state(0, solver_state, envs_idx)
211213

212-
# TODO: keeping as is for now, since coupler is currently for non-batched scenes
213-
self.coupler.reset()
214+
self.coupler.reset(envs_idx=envs_idx)
214215

215216
# TODO: keeping as is for now
216217
self.reset_grad()

0 commit comments

Comments
 (0)