|
1 | 1 | from typing import TYPE_CHECKING |
| 2 | + |
2 | 3 | import numpy as np |
3 | 4 | import taichi as ti |
4 | 5 |
|
@@ -83,14 +84,30 @@ def build(self) -> None: |
83 | 84 | self._dx = 1 / 1024 |
84 | 85 | self._stencil_size = int(np.floor(self._dx / self.sph_solver.hash_grid_cell_size) + 2) |
85 | 86 |
|
86 | | - self.reset() |
| 87 | + self.reset(envs_idx=self.sim.scene._envs_idx) |
87 | 88 |
|
88 | | - def reset(self) -> None: |
| 89 | + def reset(self, envs_idx=None) -> None: |
89 | 90 | if self._rigid_mpm and self.mpm_solver.enable_CPIC: |
90 | | - self.mpm_rigid_normal.fill(0) |
| 91 | + if envs_idx is None: |
| 92 | + self.mpm_rigid_normal.fill(0) |
| 93 | + else: |
| 94 | + self._kernel_reset_mpm(envs_idx) |
91 | 95 |
|
92 | 96 | if self._rigid_sph: |
93 | | - self.sph_rigid_normal.fill(0) |
| 97 | + if envs_idx is None: |
| 98 | + self.sph_rigid_normal.fill(0) |
| 99 | + else: |
| 100 | + self._kernel_reset_sph(envs_idx) |
| 101 | + |
| 102 | + @ti.kernel |
| 103 | + def _kernel_reset_mpm(self, envs_idx: ti.types.ndarray()): |
| 104 | + for i_p, i_g, i_b_ in ti.ndrange(self.mpm_solver.n_particles, self.rigid_solver.n_geoms, envs_idx.shape[0]): |
| 105 | + self.mpm_rigid_normal[i_p, i_g, envs_idx[i_b_]] = 0.0 |
| 106 | + |
| 107 | + @ti.kernel |
| 108 | + def _kernel_reset_sph(self, envs_idx: ti.types.ndarray()): |
| 109 | + for i_p, i_g, i_b_ in ti.ndrange(self.sph_solver.n_particles, self.rigid_solver.n_geoms, envs_idx.shape[0]): |
| 110 | + self.sph_rigid_normal[i_p, i_g, envs_idx[i_b_]] = 0.0 |
94 | 111 |
|
95 | 112 | @ti.func |
96 | 113 | def _func_collide_with_rigid(self, f, pos_world, vel, mass, i_b): |
@@ -200,7 +217,7 @@ def mpm_grid_op(self, f: ti.i32, t: ti.f32): |
200 | 217 | vel_mpm = (1 / self.mpm_solver.grid[f, I, i_b].mass) * self.mpm_solver.grid[f, I, i_b].vel_in |
201 | 218 |
|
202 | 219 | # gravity |
203 | | - vel_mpm += self.mpm_solver.substep_dt * self.mpm_solver._gravity[None] |
| 220 | + vel_mpm += self.mpm_solver.substep_dt * self.mpm_solver._gravity[i_b] |
204 | 221 |
|
205 | 222 | pos = (I + self.mpm_solver.grid_offset) * self.mpm_solver.dx |
206 | 223 | mass_mpm = self.mpm_solver.grid[f, I, i_b].mass / self.mpm_solver._p_vol_scale |
@@ -667,7 +684,7 @@ def build(self) -> None: |
667 | 684 | self.init_pcg_fields() |
668 | 685 | self.init_linesearch_fields() |
669 | 686 |
|
670 | | - def reset(self): |
| 687 | + def reset(self, envs_idx=None) -> None: |
671 | 688 | pass |
672 | 689 |
|
673 | 690 | def init_fem_fields(self): |
|
0 commit comments