Skip to content

Commit b6abad0

Browse files
committed
Fix compilation for tests; always create ConstraintSolver.contact_island
1 parent dc26c05 commit b6abad0

File tree

3 files changed

+17
-19
lines changed

3 files changed

+17
-19
lines changed

genesis/engine/solvers/rigid/constraint_solver_decomp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import genesis.utils.geom as gu
1010
import genesis.utils.array_class as array_class
1111
import genesis.engine.solvers.rigid.rigid_solver_decomp as rigid_solver
12+
from genesis.engine.solvers.rigid.contact_island import ContactIsland
1213

1314
if TYPE_CHECKING:
1415
from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver
@@ -160,6 +161,10 @@ def __init__(self, rigid_solver: "RigidSolver"):
160161

161162
self.reset()
162163

164+
# Creating a dummy ContactIsland, needed as param for some functions,
165+
# and not used when hibernation is not enabled.
166+
self.contact_island = ContactIsland(self._collider)
167+
163168
def clear(self, envs_idx: npt.NDArray[np.int32] | None = None):
164169
self._eq_const_info_cache.clear()
165170
if envs_idx is None:

genesis/engine/solvers/rigid/contact_island.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,25 @@ def __init__(self, collider: "Collider"):
3232
start=gs.ti_int,
3333
)
3434

35-
self.ci_edges = ti.field(
36-
dtype=gs.ti_int, shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None], 2))
37-
)
35+
max_num_contact_pairs = self.collider._collider_info._max_contact_pairs[None]
36+
max_num_contact_pairs = max(max_num_contact_pairs, 1) # can't create 0-sized fields
37+
38+
self.ci_edges = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_num_contact_pairs, 2)))
3839

3940
# maps half-edges (half-edges are referenced by entity_edge range) to actual edge index
4041
# description: half_edge_ref_to_edge_idx
4142
self.edge_id = ti.field(
4243
dtype=gs.ti_int,
43-
shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None] * 2)),
44+
shape=self.solver._batch_shape((max_num_contact_pairs * 2)),
4445
)
4546

4647
# maps collider_state.contact_data index to island idx
47-
self.constraint_list = ti.field(
48-
dtype=gs.ti_int, shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None]))
49-
)
48+
self.constraint_list = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_num_contact_pairs)))
5049

5150
# analogous to edge_id: maps island's constraint local-index to world's contact index
5251
self.constraint_id = ti.field(
5352
dtype=gs.ti_int,
54-
shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None] * 2)),
53+
shape=self.solver._batch_shape((max_num_contact_pairs * 2)),
5554
)
5655

5756
# per-entity range of half-edges (indexing into edge_id)

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -921,15 +921,6 @@ def _init_constraint_solver(self):
921921
def substep(self):
922922
# from genesis.utils.tools import create_timer
923923

924-
# Note: ContactIsland param is needed when supporting hibernation. But the attribute does not exist
925-
# in the solver when hibernation is disabled. In that case, we create a dummy ContactIsland object
926-
# needed for compilation, but not being used in the kernel.
927-
if not hasattr(self, "optional_contact_island"):
928-
if hasattr(self, "constraint_solver") and hasattr(self.constraint_solver, "contact_island"):
929-
self.optional_contact_island = self.constraint_solver.contact_island
930-
else:
931-
self.optional_contact_island = ContactIsland(self.collider)
932-
933924
# timer = create_timer("rigid", level=1, ti_sync=True, skip_first_call=True)
934925
kernel_step_1(
935926
links_state=self.links_state,
@@ -944,7 +935,7 @@ def substep(self):
944935
entities_info=self.entities_info,
945936
rigid_global_info=self._rigid_global_info,
946937
static_rigid_sim_config=self._static_rigid_sim_config,
947-
contact_island=self.optional_contact_island,
938+
contact_island=self.constraint_solver.contact_island,
948939
)
949940
# timer.stamp("kernel_step_1")
950941
self._func_constraint_force()
@@ -963,7 +954,7 @@ def substep(self):
963954
collider_state=self.collider._collider_state,
964955
rigid_global_info=self._rigid_global_info,
965956
static_rigid_sim_config=self._static_rigid_sim_config,
966-
contact_island=self.optional_contact_island,
957+
contact_island=self.constraint_solver.contact_island,
967958
)
968959
# timer.stamp("kernel_step_2")
969960

@@ -1015,6 +1006,7 @@ def _func_forward_dynamics(self):
10151006
geoms_state=self.geoms_state,
10161007
rigid_global_info=self._rigid_global_info,
10171008
static_rigid_sim_config=self._static_rigid_sim_config,
1009+
contact_island=self.constraint_solver.contact_island,
10181010
)
10191011

10201012
def _func_update_acc(self):
@@ -3023,6 +3015,7 @@ def kernel_forward_dynamics(
30233015
geoms_state: array_class.GeomsState,
30243016
rigid_global_info: array_class.RigidGlobalInfo,
30253017
static_rigid_sim_config: ti.template(),
3018+
contact_island: ti.template(), # ContactIsland
30263019
):
30273020
func_forward_dynamics(
30283021
links_state=links_state,
@@ -3035,6 +3028,7 @@ def kernel_forward_dynamics(
30353028
geoms_state=geoms_state,
30363029
rigid_global_info=rigid_global_info,
30373030
static_rigid_sim_config=static_rigid_sim_config,
3031+
contact_island=contact_island,
30383032
)
30393033

30403034

0 commit comments

Comments
 (0)