diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 95c8f70a12..6d9e883a56 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1,6 +1,6 @@ from copy import copy from itertools import chain -from typing import Literal +from typing import Literal, TYPE_CHECKING import numpy as np import taichi as ti @@ -25,6 +25,10 @@ from .rigid_joint import RigidJoint from .rigid_link import RigidLink +if TYPE_CHECKING: + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver + from genesis.engine.scene import Scene + @ti.data_oriented class RigidEntity(Entity): @@ -32,17 +36,20 @@ class RigidEntity(Entity): Entity class in rigid body systems. One rigid entity can be a robot, a terrain, a floating rigid body, etc. """ + # override typing + _solver: "RigidSolver" + def __init__( self, scene: "Scene", - solver: "Solver", + solver: "RigidSolver", material: Material, morph: Morph, surface: Surface, idx: int, idx_in_solver, - link_start=0, - joint_start=0, + link_start: int = 0, + joint_start: int = 0, q_start=0, dof_start=0, geom_start=0, @@ -55,13 +62,13 @@ def __init__( vvert_start=0, vface_start=0, equality_start=0, - visualize_contact=False, + visualize_contact: bool = False, ): super().__init__(idx, scene, morph, solver, material, surface) self._idx_in_solver = idx_in_solver - self._link_start = link_start - self._joint_start = joint_start + self._link_start: int = link_start + self._joint_start: int = joint_start self._q_start = q_start self._dof_start = dof_start self._geom_start = geom_start @@ -77,11 +84,11 @@ def __init__( self._base_links_idx = torch.tensor([self.base_link_idx], dtype=gs.tc_int, device=gs.device) - self._visualize_contact = visualize_contact + self._visualize_contact: bool = visualize_contact - self._is_free = morph.is_free + self._is_free: bool = morph.is_free - self._is_built = False + self._is_built: bool = False self._load_model() @@ -2853,7 +2860,7 @@ def equalities(self): return self._equalities @property - def is_free(self): + def is_free(self) -> bool: """Whether the entity is free to move.""" return self._is_free diff --git a/genesis/engine/entities/rigid_entity/rigid_geom.py b/genesis/engine/entities/rigid_entity/rigid_geom.py index 533340b92b..5bce714df7 100644 --- a/genesis/engine/entities/rigid_entity/rigid_geom.py +++ b/genesis/engine/entities/rigid_entity/rigid_geom.py @@ -4,6 +4,7 @@ import igl import numpy as np +from numpy.typing import NDArray import skimage import taichi as ti import torch @@ -16,7 +17,9 @@ from genesis.utils.misc import tensor_to_array if TYPE_CHECKING: + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver from genesis.engine.materials.rigid import Rigid as RigidMaterial + from genesis.engine.mesh import Mesh from .rigid_entity import RigidEntity from .rigid_link import RigidLink @@ -32,18 +35,18 @@ def __init__( self, link: "RigidLink", idx, - cell_start, - vert_start, - face_start, - edge_start, - verts_state_start, - mesh, - type, - friction, + cell_start: int, + vert_start: int, + face_start: int, + edge_start: int, + verts_state_start: int, + mesh: "Mesh", + type: gs.GEOM_TYPE, + friction: float, sol_params, init_pos, init_quat, - needs_coup, + needs_coup: bool, contype, conaffinity, center_init=None, @@ -52,30 +55,30 @@ def __init__( self._link: "RigidLink" = link self._entity: "RigidEntity" = link.entity self._material: "RigidMaterial" = link.entity.material - self._solver = link.entity.solver - self._mesh = mesh + self._solver: "RigidSolver" = link.entity.solver + self._mesh: "Mesh" = mesh self._uid = gs.UID() self._idx = idx - self._type = type - self._friction = friction + self._type: gs.GEOM_TYPE = type + self._friction: float = friction self._sol_params = sol_params - self._needs_coup = needs_coup + self._needs_coup: bool = needs_coup self._contype = contype self._conaffinity = conaffinity - self._is_convex = mesh.is_convex - self._cell_start = cell_start - self._vert_start = vert_start - self._face_start = face_start - self._edge_start = edge_start - self._verts_state_start = verts_state_start + self._is_convex: bool = mesh.is_convex + self._cell_start: int = cell_start + self._vert_start: int = vert_start + self._face_start: int = face_start + self._edge_start: int = edge_start + self._verts_state_start: int = verts_state_start - self._coup_softness = self._material.coup_softness - self._coup_friction = self._material.coup_friction - self._coup_restitution = self._material.coup_restitution + self._coup_softness: float = self._material.coup_softness + self._coup_friction: float = self._material.coup_friction + self._coup_restitution: float = self._material.coup_restitution - self._init_pos = init_pos - self._init_quat = init_quat + self._init_pos: np.ndarray = init_pos + self._init_quat: np.ndarray = init_quat self._init_verts = mesh.verts self._init_faces = mesh.faces @@ -455,14 +458,14 @@ def uid(self): return self._uid @property - def idx(self): + def idx(self) -> int: """ Get the global index of the geom in RigidSolver. """ return self._idx @property - def type(self): + def type(self) -> gs.GEOM_TYPE: """ Get the type of the geom. """ @@ -504,25 +507,25 @@ def entity(self) -> "RigidEntity": return self._entity @property - def solver(self): + def solver(self) -> "RigidSolver": """ Get the solver that the geom belongs to.s """ return self._solver @property - def is_convex(self): + def is_convex(self) -> bool: """ Get whether the geom is convex. """ return self._is_convex @property - def mesh(self): + def mesh(self) -> "Mesh": return self._mesh @property - def needs_coup(self): + def needs_coup(self) -> bool: """ Get whether the geom needs coupling with other non-rigid entities. """ @@ -550,35 +553,35 @@ def conaffinity(self): return self._conaffinity @property - def coup_softness(self): + def coup_softness(self) -> float: """ Get the softness coefficient of the geom for coupling. """ return self._coup_softness @property - def coup_friction(self): + def coup_friction(self) -> float: """ Get the friction coefficient of the geom for coupling. """ return self._coup_friction @property - def coup_restitution(self): + def coup_restitution(self) -> float: """ Get the restitution coefficient of the geom for coupling. """ return self._coup_restitution @property - def init_pos(self): + def init_pos(self) -> np.ndarray: """ Get the initial position of the geom. """ return self._init_pos @property - def init_quat(self): + def init_quat(self) -> np.ndarray: """ Get the initial quaternion of the geom. """ @@ -725,14 +728,14 @@ def n_cells(self): return np.prod(self._sdf_res) @property - def n_verts(self): + def n_verts(self) -> int: """ Number of vertices of the geom. """ return len(self._init_verts) @property - def n_faces(self): + def n_faces(self) -> int: """ Number of faces of the geom. """ diff --git a/genesis/engine/entities/rigid_entity/rigid_link.py b/genesis/engine/entities/rigid_entity/rigid_link.py index 85303e86ab..c1acd6ba10 100644 --- a/genesis/engine/entities/rigid_entity/rigid_link.py +++ b/genesis/engine/entities/rigid_entity/rigid_link.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING import numpy as np +from numpy.typing import ArrayLike, NDArray import taichi as ti import torch -from numpy.typing import ArrayLike import genesis as gs import trimesh @@ -26,30 +26,30 @@ class RigidLink(RBC): def __init__( self, - entity, - name, - idx, - joint_start, - n_joints, - geom_start, - cell_start, - vert_start, - face_start, - edge_start, - verts_state_start, - vgeom_start, - vvert_start, - vface_start, - pos, - quat, - inertial_pos, - inertial_quat, - inertial_i, - inertial_mass, - parent_idx, - root_idx, - invweight, - visualize_contact, + entity: "RigidEntity", + name: str, + idx: int, + joint_start: int, + n_joints: int, + geom_start: int, + cell_start: int, + vert_start: int, + face_start: int, + edge_start: int, + verts_state_start: int, + vgeom_start: int, + vvert_start: int, + vface_start: int, + pos: ArrayLike, + quat: ArrayLike, + inertial_pos: ArrayLike | None, + inertial_quat: ArrayLike | None, + inertial_i: ArrayLike | None, # may be None, eg. for plane; NDArray is 3x3 matrix + inertial_mass: float | None, # may be None, eg. for plane + parent_idx: int, + root_idx: int | None, + invweight: float | None, + visualize_contact: bool, ): self._name: str = name self._entity: "RigidEntity" = entity @@ -86,13 +86,13 @@ def __init__( if inertial_quat is not None: inertial_quat = np.asarray(inertial_quat, dtype=gs.np_float) self._inertial_quat: ArrayLike | None = inertial_quat - self._inertial_mass = inertial_mass - self._inertial_i = inertial_i + self._inertial_mass: float | None = inertial_mass + self._inertial_i: ArrayLike | None = inertial_i self._visualize_contact = visualize_contact self._geoms: list[RigidGeom] = gs.List() - self._vgeoms = gs.List() + self._vgeoms: list[RigidVisGeom] = gs.List() def _build(self): for geom in self._geoms: @@ -429,35 +429,35 @@ def uid(self): return self._uid @property - def name(self): + def name(self) -> str: """ The name of the link. """ return self._name @property - def entity(self): + def entity(self) -> "RigidEntity": """ The entity that the link belongs to. """ return self._entity @property - def solver(self): + def solver(self) -> "RigidSolver": """ The solver that the link belongs to. """ return self._solver @property - def visualize_contact(self): + def visualize_contact(self) -> bool: """ Whether to visualize the contact of the link. """ return self._visualize_contact @property - def joints(self): + def joints(self) -> list["Joint"]: """ The sequence of joints that connects the link to its parent link. """ @@ -606,84 +606,84 @@ def quat(self) -> ArrayLike: return self._quat @property - def inertial_pos(self): + def inertial_pos(self) -> ArrayLike | None: """ The initial position of the link's inertial frame. """ return self._inertial_pos @property - def inertial_quat(self): + def inertial_quat(self) -> ArrayLike | None: """ The initial quaternion of the link's inertial frame. """ return self._inertial_quat @property - def inertial_mass(self): + def inertial_mass(self) -> float | None: """ The initial mass of the link. """ return self._inertial_mass @property - def inertial_i(self): + def inertial_i(self) -> ArrayLike | None: """ The inerial matrix of the link. """ return self._inertial_i @property - def geoms(self): + def geoms(self) -> list[RigidGeom]: """ The list of the link's collision geometries (`RigidGeom`). """ return self._geoms @property - def vgeoms(self): + def vgeoms(self) -> list[RigidVisGeom]: """ The list of the link's visualization geometries (`RigidVisGeom`). """ return self._vgeoms @property - def n_geoms(self): + def n_geoms(self) -> int: """ Number of the link's collision geometries. """ return len(self._geoms) @property - def geom_start(self): + def geom_start(self) -> int: """ The start index of the link's collision geometries in the RigidSolver. """ return self._geom_start @property - def geom_end(self): + def geom_end(self) -> int: """ The end index of the link's collision geometries in the RigidSolver. """ return self._geom_start + self.n_geoms @property - def n_vgeoms(self): + def n_vgeoms(self) -> int: """ Number of the link's visualization geometries (`vgeom`). """ return len(self._vgeoms) @property - def vgeom_start(self): + def vgeom_start(self) -> int: """ The start index of the link's vgeom in the RigidSolver. """ return self._vgeom_start @property - def vgeom_end(self): + def vgeom_end(self) -> int: """ The end index of the link's vgeom in the RigidSolver. """ @@ -697,42 +697,42 @@ def n_cells(self): return sum([geom.n_cells for geom in self._geoms]) @property - def n_verts(self): + def n_verts(self) -> int: """ Number of vertices of all the link's geoms. """ return sum([geom.n_verts for geom in self._geoms]) @property - def n_vverts(self): + def n_vverts(self) -> int: """ Number of vertices of all the link's vgeoms. """ return sum([vgeom.n_vverts for vgeom in self._vgeoms]) @property - def n_faces(self): + def n_faces(self) -> int: """ Number of faces of all the link's geoms. """ return sum([geom.n_faces for geom in self._geoms]) @property - def n_vfaces(self): + def n_vfaces(self) -> int: """ Number of faces of all the link's vgeoms. """ return sum([vgeom.n_vfaces for vgeom in self._vgeoms]) @property - def n_edges(self): + def n_edges(self) -> int: """ Number of edges of all the link's geoms. """ return sum([geom.n_edges for geom in self._geoms]) @property - def is_built(self): + def is_built(self) -> bool: """ Whether the entity the link belongs to is built. """ diff --git a/genesis/engine/materials/rigid.py b/genesis/engine/materials/rigid.py index c0310d4786..a41bcd9d6d 100644 --- a/genesis/engine/materials/rigid.py +++ b/genesis/engine/materials/rigid.py @@ -88,51 +88,51 @@ def __init__( self._gravity_compensation = float(gravity_compensation) @property - def gravity_compensation(self): + def gravity_compensation(self) -> float: """Gravity compensation factor. 1.0 cancels gravity.""" return self._gravity_compensation @property - def friction(self): + def friction(self) -> float: """Friction coefficient used within the rigid solver.""" return self._friction @property - def needs_coup(self): + def needs_coup(self) -> bool: """Whether this material requires solver coupling.""" return self._needs_coup @property - def coup_friction(self): + def coup_friction(self) -> float: """Friction coefficient used in coupling interactions.""" return self._coup_friction @property - def coup_softness(self): + def coup_softness(self) -> float: """Softness parameter controlling the influence range of coupling.""" return self._coup_softness @property - def coup_restitution(self): + def coup_restitution(self) -> float: """Restitution coefficient used during contact in coupling.""" return self._coup_restitution @property - def sdf_cell_size(self): + def sdf_cell_size(self) -> float: """Size of each SDF grid cell in meters.""" return self._sdf_cell_size @property - def sdf_min_res(self): + def sdf_min_res(self) -> int: """Minimum allowed resolution for the SDF grid.""" return self._sdf_min_res @property - def sdf_max_res(self): + def sdf_max_res(self) -> int: """Maximum allowed resolution for the SDF grid.""" return self._sdf_max_res @property - def rho(self): + def rho(self) -> float: """Density of the rigid material.""" return self._rho diff --git a/genesis/engine/mesh.py b/genesis/engine/mesh.py index 950916668c..adbdb8681c 100644 --- a/genesis/engine/mesh.py +++ b/genesis/engine/mesh.py @@ -425,7 +425,7 @@ def trimesh(self): return self._mesh @property - def is_convex(self): + def is_convex(self) -> bool: """ Whether the mesh is convex. """ diff --git a/genesis/engine/scene.py b/genesis/engine/scene.py index b9387e598c..a15e66088d 100644 --- a/genesis/engine/scene.py +++ b/genesis/engine/scene.py @@ -1288,7 +1288,7 @@ def requires_grad(self): return self._sim.requires_grad @property - def is_built(self): + def is_built(self) -> bool: """Whether the scene has been built.""" return self._is_built diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index efd3b0b0ac..e6c0e66ad2 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -101,12 +101,12 @@ def __init__( self.sf_options = sf_options self.pbd_options = pbd_options - self._dt = options.dt - self._substep_dt = options.dt / options.substeps - self._substeps = options.substeps - self._substeps_local = options.substeps_local - self._requires_grad = options.requires_grad - self._steps_local = options._steps_local + self._dt: float = options.dt + self._substep_dt: float = options.dt / options.substeps + self._substeps: int = options.substeps + self._substeps_local: int | None = options.substeps_local + self._requires_grad: bool = options.requires_grad + self._steps_local: int | None = options._steps_local self._cur_substep_global = 0 self._gravity = np.array(options.gravity, dtype=gs.np_float) @@ -419,7 +419,7 @@ def set_gravity(self, gravity, envs_idx=None): # ------------------------------------------------------------------------------------ @property - def dt(self): + def dt(self) -> float: """The time duration for each simulation step.""" return self._dt @@ -444,7 +444,7 @@ def requires_grad(self): return self._requires_grad @property - def n_entities(self): + def n_entities(self) -> int: """The number of entities in the simulator.""" return len(self._entities) diff --git a/genesis/engine/solvers/rigid/collider_decomp.py b/genesis/engine/solvers/rigid/collider_decomp.py index 27e3f601dc..8edea10e2d 100644 --- a/genesis/engine/solvers/rigid/collider_decomp.py +++ b/genesis/engine/solvers/rigid/collider_decomp.py @@ -275,7 +275,7 @@ def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None: def clear(self, envs_idx=None): if envs_idx is None: envs_idx = self._solver._scene._envs_idx - collider_kernel_clear( + kernel_collider_clear( envs_idx, self._solver.links_state, self._solver.links_info, @@ -494,8 +494,9 @@ def collider_kernel_reset( collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3) +# only used with hibernation ?? @ti.kernel -def collider_kernel_clear( +def kernel_collider_clear( envs_idx: ti.types.ndarray(), links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -511,8 +512,8 @@ def collider_kernel_clear( # advect hibernated contacts for i_c in range(collider_state.n_contacts[i_b]): - i_la = collider_state.contact_data[i_c, i_b].link_a - i_lb = collider_state.contact_data[i_c, i_b].link_b + i_la = collider_state.contact_data.link_a[i_c, i_b] + i_lb = collider_state.contact_data.link_b[i_c, i_b] I_la = [i_la, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_la I_lb = [i_lb, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_lb @@ -525,7 +526,20 @@ def collider_kernel_clear( ): i_c_hibernated = collider_state.n_contacts_hibernated[i_b] if i_c != i_c_hibernated: - collider_state.contact_data[i_c_hibernated, i_b] = collider_state.contact_data[i_c, i_b] + # Copying all fields of class StructContactData: + # fmt: off + collider_state.contact_data.geom_a[i_c_hibernated, i_b] = collider_state.contact_data.geom_a[i_c, i_b] + collider_state.contact_data.geom_b[i_c_hibernated, i_b] = collider_state.contact_data.geom_b[i_c, i_b] + collider_state.contact_data.penetration[i_c_hibernated, i_b] = collider_state.contact_data.penetration[i_c, i_b] + collider_state.contact_data.normal[i_c_hibernated, i_b] = collider_state.contact_data.normal[i_c, i_b] + collider_state.contact_data.pos[i_c_hibernated, i_b] = collider_state.contact_data.pos[i_c, i_b] + collider_state.contact_data.friction[i_c_hibernated, i_b] = collider_state.contact_data.friction[i_c, i_b] + collider_state.contact_data.sol_params[i_c_hibernated, i_b] = collider_state.contact_data.sol_params[i_c, i_b] + collider_state.contact_data.force[i_c_hibernated, i_b] = collider_state.contact_data.force[i_c, i_b] + collider_state.contact_data.link_a[i_c_hibernated, i_b] = collider_state.contact_data.link_a[i_c, i_b] + collider_state.contact_data.link_b[i_c_hibernated, i_b] = collider_state.contact_data.link_b[i_c, i_b] + # fmt: on + collider_state.n_contacts_hibernated[i_b] = i_c_hibernated + 1 collider_state.n_contacts[i_b] = collider_state.n_contacts_hibernated[i_b] @@ -1258,7 +1272,7 @@ def func_broad_phase( if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info): # Clear collision normal cache if not in contact - if ti.static(not static_rigid_sim_config._enable_mujoco_compatibility): + if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): # self.contact_cache[i_ga, i_gb, i_b].i_va_ws = -1 collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero( gs.ti_float, 3 diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index 430e4c42b1..df50380f33 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -9,6 +9,7 @@ import genesis.utils.geom as gu import genesis.utils.array_class as array_class import genesis.engine.solvers.rigid.rigid_solver_decomp as rigid_solver +from genesis.engine.solvers.rigid.contact_island import ContactIsland if TYPE_CHECKING: from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver @@ -160,6 +161,10 @@ def __init__(self, rigid_solver: "RigidSolver"): self.reset() + # Creating a dummy ContactIsland, needed as param for some functions, + # and not used when hibernation is not enabled. + self.contact_island = ContactIsland(self._collider) + def clear(self, envs_idx: npt.NDArray[np.int32] | None = None): self._eq_const_info_cache.clear() if envs_idx is None: diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py index 1f88338b14..3831ecb28a 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py @@ -1,26 +1,35 @@ +from typing import TYPE_CHECKING + import numpy as np import taichi as ti import genesis as gs import genesis.utils.geom as gu +import genesis.utils.array_class as array_class + from .contact_island import ContactIsland +from .rigid_solver_decomp_util import func_wakeup_entity_and_its_temp_island + +if TYPE_CHECKING: + from genesis.engine.colliders.collider import Collider + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver @ti.data_oriented class ConstraintSolverIsland: - def __init__(self, rigid_solver): - self._solver = rigid_solver - self._collider = rigid_solver.collider - self._B = rigid_solver._B - self._para_level = rigid_solver._para_level - - self._solver_type = rigid_solver._options.constraint_solver - self.iterations = rigid_solver._options.iterations - self.tolerance = rigid_solver._options.tolerance - self.ls_iterations = rigid_solver._options.ls_iterations - self.ls_tolerance = rigid_solver._options.ls_tolerance - self.sparse_solve = True + def __init__(self, rigid_solver: "RigidSolver"): + self._solver: "RigidSolver" = rigid_solver + self._collider: "Collider" = rigid_solver.collider + self._B: int = rigid_solver._B + self._para_level: gs.PARA_LEVEL = rigid_solver._para_level + + self._solver_type: gs.constraint_solver = rigid_solver._options.constraint_solver + self.iterations: int = rigid_solver._options.iterations + self.tolerance: float = rigid_solver._options.tolerance + self.ls_iterations: int = rigid_solver._options.ls_iterations + self.ls_tolerance: float = rigid_solver._options.ls_tolerance + self.sparse_solve: bool = True # 4 constraints per contact and 1 constraints per joint limit (upper and lower, if not inf) self.len_constraints = ( @@ -29,6 +38,8 @@ def __init__(self, rigid_solver): ) self.len_constraints_ = max(1, self.len_constraints) + self.constraint_state = array_class.get_constraint_state(self, self._solver) + self.jac = ti.field( dtype=gs.ti_float, shape=self._solver._batch_shape((self.len_constraints_, self._solver.n_dofs_)) ) @@ -110,36 +121,37 @@ def _kernel_clear(self, envs_idx: ti.types.ndarray()): @ti.kernel def resolve(self): for i_b in range(self._B): - for island in range(self.contact_island.n_island[i_b]): + for i_island in range(self.contact_island.n_islands[i_b]): is_active = True if ti.static(self._solver._use_hibernation): - is_active = not self.contact_island.island_hibernated[island, i_b] + is_active = not self.contact_island.island_hibernated[i_island, i_b] if is_active: - self.add_collision_constraints(island, i_b) - self.add_joint_limit_constraints(island, i_b) - self._func_init_solver(island, i_b) - self._func_solve(island, i_b) - self._func_update_qacc(island, i_b) - self._func_update_contact_force(island, i_b) + self.add_collision_constraints_and_wakeup_entities(i_island, i_b) + self.add_joint_limit_constraints(i_island, i_b) + self._func_init_solver(i_island, i_b) + self._func_solve(i_island, i_b) + self._func_update_qacc(i_island, i_b) + self._func_update_contact_force(i_island, i_b) def handle_constraints(self): self.contact_island.construct() self.resolve() @ti.func - def add_collision_constraints(self, island, i_b): + def add_collision_constraints_and_wakeup_entities(self, i_island: int, i_b: int): self.n_constraints[i_b] = 0 - for i_island_col in range(self.contact_island.island_col[island, i_b].n): - i_col_ = self.contact_island.island_col[island, i_b].start + i_island_col + for i_island_col in range(self.contact_island.island_col[i_island, i_b].n): + i_col_ = self.contact_island.island_col[i_island, i_b].start + i_island_col i_col = self.contact_island.constraint_id[i_col_, i_b] - contact_data = self._collider._collider_state.contact_data[i_col, i_b] - link_a = contact_data.link_a - link_b = contact_data.link_b + # get links indices of the contact_data + link_a = self._collider._collider_state.contact_data.link_a[i_col, i_b] + link_b = self._collider._collider_state.contact_data.link_b[i_col, i_b] link_a_maybe_batch = [link_a, i_b] if ti.static(self._solver._options.batch_links_info) else link_a link_b_maybe_batch = [link_b, i_b] if ti.static(self._solver._options.batch_links_info) else link_b - d1, d2 = gu.ti_orthogonals(contact_data.normal) + contact_normal = self._collider._collider_state.contact_data.normal[i_col, i_b] + d1, d2 = gu.ti_orthogonals(contact_normal) invweight = self._solver.links_info.invweight[link_a_maybe_batch][0] + self._solver.links_info.invweight[ link_b_maybe_batch @@ -147,7 +159,8 @@ def add_collision_constraints(self, island, i_b): for i in range(4): d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) - n = d * contact_data.friction - contact_data.normal + contact_friction = self._collider._collider_state.contact_data.friction[i_col, i_b] + n = d * contact_friction - contact_normal n_con = ti.atomic_add(self.n_constraints[i_b], 1) if ti.static(self.sparse_solve): @@ -178,7 +191,8 @@ def add_collision_constraints(self, island, i_b): cdot_vel = self._solver.dofs_state.cdof_vel[i_d, i_b] t_quat = gu.ti_identity_quat() - t_pos = contact_data.pos - self._solver.links_state.COM[link, i_b] + contact_pos = self._collider._collider_state.contact_data.pos[i_col, i_b] + t_pos = contact_pos - self._solver.links_state.COM[link, i_b] _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) diff = sign * vel @@ -194,12 +208,12 @@ def add_collision_constraints(self, island, i_b): if ti.static(self.sparse_solve): self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs - imp, aref = gu.imp_aref( - contact_data.sol_params, -contact_data.penetration, jac_qvel, -contact_data.penetration - ) + contact_sol_params = self._collider._collider_state.contact_data.sol_params[i_col, i_b] + contact_penetration = self._collider._collider_state.contact_data.penetration[i_col, i_b] + imp, aref = gu.imp_aref(contact_sol_params, -contact_penetration, jac_qvel, -contact_penetration) - diag = invweight + contact_data.friction * contact_data.friction * invweight - diag *= 2 * contact_data.friction * contact_data.friction * (1 - imp) / ti.max(imp, gs.EPS) + diag = invweight + contact_friction * contact_friction * invweight + diag *= 2 * contact_friction * contact_friction * (1 - imp) / ti.max(imp, gs.EPS) self.diag[n_con, i_b] = diag self.aref[n_con, i_b] = aref @@ -207,22 +221,40 @@ def add_collision_constraints(self, island, i_b): self.efc_D[n_con, i_b] = 1 / ti.max(diag, gs.EPS) if ti.static(self._solver._use_hibernation): - # wake up entities - self._solver._func_wakeup_entity(self._solver.links_info[link_a_maybe_batch].entity_idx, i_b) - self._solver._func_wakeup_entity(self._solver.links_info[link_b_maybe_batch].entity_idx, i_b) + entity_idx_a = self._solver.links_info.entity_idx[link_a_maybe_batch] + entity_idx_b = self._solver.links_info.entity_idx[link_b_maybe_batch] + + is_entity_a_hibernated = self._solver.entities_state.hibernated[entity_idx_a, i_b] + is_entity_b_hibernated = self._solver.entities_state.hibernated[entity_idx_b, i_b] + if is_entity_a_hibernated or is_entity_b_hibernated: + # wake up entities + any_hibernated_entity_idx = entity_idx_a if is_entity_a_hibernated else entity_idx_b + + func_wakeup_entity_and_its_temp_island( + any_hibernated_entity_idx, + i_b, + self._solver.entities_state, + self._solver.entities_info, + self._solver.dofs_state, + self._solver.links_state, + self._solver.geoms_state, + self._solver.data_manager.rigid_global_info, + self.contact_island, + ) @ti.func - def add_joint_limit_constraints(self, island, i_b): - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): + def add_joint_limit_constraints(self, i_island: int, i_b: int): + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] for i_l in range(self.entities_info.link_start[i_e], self.entities_info.link_end[i_e]): I_l = [i_l, i_b] if ti.static(self._solver._options.batch_links_info) else i_l - l_info = self._solver.links_info[I_l] + l_info_start = self._solver.links_info.joint_start[I_l] + l_info_end = self._solver.links_info.joint_end[I_l] - for i_j in range(l_info.joint_start, l_info.joint_end): + for i_j in range(l_info_start, l_info_end): I_j = [i_j, i_b] if ti.static(self._solver._options.batch_joints_info) else i_j if ( @@ -504,59 +536,63 @@ def _kernel_reset(self, envs_idx: ti.types.ndarray()): # timer.stamp('compute force') @ti.func - def _func_update_contact_force(self, island, i_b): - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + def _func_update_contact_force(self, i_island: int, i_b: int): + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] for i_l in range(self.entities_info.link_start[i_e], self.entities_info.link_end[i_e]): self._solver.links_state.contact_force[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - for i_island_col in range(self.contact_island.island_col[island, i_b].n): - i_col_ = self.contact_island.island_col[island, i_b].start + i_island_col + for i_island_col in range(self.contact_island.island_col[i_island, i_b].n): + i_col_ = self.contact_island.island_col[i_island, i_b].start + i_island_col i_col = self.contact_island.constraint_id[i_col_, i_b] - contact_data = self._collider._collider_state.contact_data[i_col, i_b] + contact_normal = self._collider._collider_state.contact_data.normal[i_col, i_b] + contact_friction = self._collider._collider_state.contact_data.friction[i_col, i_b] force = ti.Vector.zero(gs.ti_float, 3) - d1, d2 = gu.ti_orthogonals(contact_data.normal) + d1, d2 = gu.ti_orthogonals(contact_normal) for i in range(4): d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) - n = d * contact_data.friction - contact_data.normal + n = d * contact_friction - contact_normal force += n * self.efc_force[i_island_col * 4 + i, i_b] - self._collider._collider_state.contact_data[i_col, i_b].force = force + self._collider._collider_state.contact_data.force[i_col, i_b] = force - self._solver.links_state.contact_force[contact_data.link_a, i_b] = ( - self._solver.links_state.contact_force[contact_data.link_a, i_b] - force + link_a = self._collider._collider_state.contact_data.link_a[i_col, i_b] + link_b = self._collider._collider_state.contact_data.link_b[i_col, i_b] + + self._solver.links_state.contact_force[link_a, i_b] = ( + self._solver.links_state.contact_force[link_a, i_b] - force ) - self._solver.links_state.contact_force[contact_data.link_b, i_b] = ( - self._solver.links_state.contact_force[contact_data.link_b, i_b] + force + self._solver.links_state.contact_force[link_b, i_b] = ( + self._solver.links_state.contact_force[link_b, i_b] + force ) @ti.func - def _func_update_qacc(self, island, i_b): - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + def _func_update_qacc(self, i_island: int, i_b: int): + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] for i_d in range(self.entities_info.dof_start[i_e], self.entities_info.dof_end[i_e]): self._solver.dofs_state.acc[i_d, i_b] = self.qacc[i_d, i_b] self.qacc_ws[i_d, i_b] = self.qacc[i_d, i_b] @ti.func - def _func_solve(self, island, i_b): + def _func_solve(self, i_island: int, i_b: int): # this safeguard seems not necessary in normal execution # if self.n_constraints[i_b] > 0 or self.cost_ws[i_b] < self.cost[i_b]: if self.n_constraints[i_b] > 0: tol_scaled = (self._solver.meaninertia[i_b] * ti.max(1, self._solver.n_dofs)) * self.tolerance for it in range(self.iterations): - self._func_solve_body(island, i_b) + self._func_solve_body(i_island, i_b) if self.improved[i_b] < 1: break gradient = gs.ti_float(0.0) n_dof = 0 - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] n_dof = n_dof + self.entities_info.dof_end[i_e] - self.entities_info.dof_start[i_e] for i_d in range(self.entities_info.dof_start[i_e], self.entities_info.dof_end[i_e]): @@ -1011,18 +1047,18 @@ def initialize_Ma(self, Ma, qacc, island, i_b): Ma[i_d1, i_b] = Ma_ @ti.func - def _func_init_solver(self, island, i_b): + def _func_init_solver(self, i_island: int, i_b: int): # check if warm start self.initialize_Jaref(self.qacc_ws, i_b) - self.initialize_Ma(self.Ma_ws, self.qacc_ws, island, i_b) - self._func_update_constraint(island, i_b, self.qacc_ws, self.Ma_ws, self.cost_ws) + self.initialize_Ma(self.Ma_ws, self.qacc_ws, i_island, i_b) + self._func_update_constraint(i_island, i_b, self.qacc_ws, self.Ma_ws, self.cost_ws) self.initialize_Jaref(self._solver.dofs_state.acc, i_b) - self.initialize_Ma(self.Ma, self._solver.dofs_state.acc, island, i_b) - self._func_update_constraint(island, i_b, self._solver.dofs_state.acc, self.Ma, self.cost) + self.initialize_Ma(self.Ma, self._solver.dofs_state.acc, i_island, i_b) + self._func_update_constraint(i_island, i_b, self._solver.dofs_state.acc, self.Ma, self.cost) - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] for i_d in range(self.entities_info.dof_start[i_e], self.entities_info.dof_end[i_e]): if self.cost_ws[i_b] < self.cost[i_b]: @@ -1033,15 +1069,15 @@ def _func_init_solver(self, island, i_b): self.initialize_Jaref(self.qacc, i_b) # end warm start - self._func_update_constraint(island, i_b, self.qacc, self.Ma, self.cost) + self._func_update_constraint(i_island, i_b, self.qacc, self.Ma, self.cost) if ti.static(self._solver_type == gs.constraint_solver.Newton): - self._func_nt_hessian_direct(island, i_b) + self._func_nt_hessian_direct(i_island, i_b) - self._func_update_gradient(island, i_b) + self._func_update_gradient(i_island, i_b) - for i_island_entity in range(self.contact_island.island_entity[island, i_b].n): - i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity + for i_island_entity in range(self.contact_island.island_entity[i_island, i_b].n): + i_e_ = self.contact_island.island_entity[i_island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] for i_d in range(self.entities_info.dof_start[i_e], self.entities_info.dof_end[i_e]): self.search[i_d, i_b] = -self.Mgrad[i_d, i_b] diff --git a/genesis/engine/solvers/rigid/contact_island.py b/genesis/engine/solvers/rigid/contact_island.py index 7e5bc887f8..323a8396eb 100644 --- a/genesis/engine/solvers/rigid/contact_island.py +++ b/genesis/engine/solvers/rigid/contact_island.py @@ -1,15 +1,23 @@ +from typing import TYPE_CHECKING + import numpy as np import taichi as ti import genesis as gs import genesis.utils.geom as gu +if TYPE_CHECKING: + from genesis.engine.solvers.rigid.collider_decomp import Collider + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver + +INVALID_NEXT_HIBERNATED_ENTITY_IDX = -1 + @ti.data_oriented class ContactIsland: - def __init__(self, collider): - self.solver = collider._solver - self.collider = collider + def __init__(self, collider: "Collider"): + self.solver: "RigidSolver" = collider._solver + self.collider: "Collider" = collider struct_agg_list = ti.types.struct( curr=gs.ti_int, @@ -17,47 +25,59 @@ def __init__(self, collider): start=gs.ti_int, ) - self.ci_edges = ti.field( - dtype=gs.ti_int, shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None], 2)) - ) + max_contact_pairs = self.collider._collider_info._max_contact_pairs[None] + max_contact_pairs = max(max_contact_pairs, 1) # can't create 0-sized fields - self.edge_id = ti.field( - dtype=gs.ti_int, - shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None] * 2)), - ) + self.ci_edges = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_contact_pairs, 2))) - self.constraint_list = ti.field( - dtype=gs.ti_int, shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None])) - ) + # maps half-edges (half-edges are referenced by entity_edge range) to actual edge index + # description: half_edge_ref_to_edge_idx + self.edge_id = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_contact_pairs * 2))) - self.constraint_id = ti.field( - dtype=gs.ti_int, - shape=self.solver._batch_shape((self.collider._collider_info._max_contact_pairs[None] * 2)), - ) + # maps collider_state.contact_data index to island idx + self.constraint_list = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_contact_pairs))) + # analogous to edge_id: maps island's constraint local-index to world's contact index + self.constraint_id = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((max_contact_pairs * 2))) + + # per-entity range of half-edges (indexing into edge_id) + # description: entity_idx_to_half_edge_ref_range self.entity_edge = struct_agg_list.field( shape=self.solver._batch_shape(self.solver.n_entities), needs_grad=False, layout=ti.Layout.SOA ) + # records number of collision edges per island + # description: island_idx_to_contact_ref_range self.island_col = struct_agg_list.field( shape=self.solver._batch_shape(self.solver.n_entities), needs_grad=False, layout=ti.Layout.SOA ) self.island_hibernated = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((self.solver.n_entities))) + # description: island_idx_to_entity_ref_range self.island_entity = struct_agg_list.field( shape=self.solver._batch_shape(self.solver.n_entities), needs_grad=False, layout=ti.Layout.SOA ) + # map per-island entity local-index to world's entity index + # description: entity_ref_to_entity_idx self.entity_id = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape((self.solver.n_entities))) - self.n_edge = ti.field(dtype=gs.ti_int, shape=self.solver._B) - self.n_island = ti.field(dtype=gs.ti_int, shape=self.solver._B) + # num all collision edges in world + self.n_edges = ti.field(dtype=gs.ti_int, shape=self.solver._B) + self.n_islands = ti.field(dtype=gs.ti_int, shape=self.solver._B) self.n_stack = ti.field(dtype=gs.ti_int, shape=self.solver._B) + # description: entity_idx_to_island_idx self.entity_island = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape(self.solver.n_entities)) self.stack = ti.field(dtype=gs.ti_int, shape=self.solver._batch_shape(self.solver.n_entities)) + # Used to make islands persist through hibernation: + self.entity_idx_to_next_entity_idx_in_hibernated_island = ti.field( + dtype=gs.ti_int, shape=self.solver._batch_shape(self.solver.n_entities) + ) + self.entity_idx_to_next_entity_idx_in_hibernated_island.fill(INVALID_NEXT_HIBERNATED_ENTITY_IDX) + @ti.kernel def clear(self): ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) @@ -69,57 +89,76 @@ def clear(self): ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): - self.n_edge[i_b] = 0 - self.n_island[i_b] = 0 + self.n_edges[i_b] = 0 + self.n_islands[i_b] = 0 @ti.func def add_edge(self, link_a, link_b, i_b): link_a_maybe_batch = [link_a, i_b] if ti.static(self.solver._options.batch_links_info) else link_a link_b_maybe_batch = [link_b, i_b] if ti.static(self.solver._options.batch_links_info) else link_b - ea = self.solver.links_info[link_a_maybe_batch].entity_idx - eb = self.solver.links_info[link_b_maybe_batch].entity_idx + ea = self.solver.links_info.entity_idx[link_a_maybe_batch] + eb = self.solver.links_info.entity_idx[link_b_maybe_batch] + # update num edges per entity self.entity_edge[ea, i_b].n = self.entity_edge[ea, i_b].n + 1 self.entity_edge[eb, i_b].n = self.entity_edge[eb, i_b].n + 1 - n_edge = self.n_edge[i_b] + # fill in collider-info edges with indices to connected entities. + n_edge = self.n_edges[i_b] self.ci_edges[n_edge, 0, i_b] = ea self.ci_edges[n_edge, 1, i_b] = eb - self.n_edge[i_b] = n_edge + 1 + self.n_edges[i_b] = n_edge + 1 @ti.kernel - def add_island(self): + def add_contact_edges_to_islands(self): ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): for i_col in range(self.collider._collider_state.n_contacts[i_b]): - impact = self.collider._collider_state.contact_data[i_col, i_b] - self.add_edge(impact.link_a, impact.link_b, i_b) + # get links indices of the impact + link_a = self.collider._collider_state.contact_data.link_a[i_col, i_b] + link_b = self.collider._collider_state.contact_data.link_b[i_col, i_b] + self.add_edge(link_a, link_b, i_b) + + @ti.kernel + def add_hiberanted_edges_to_islands(self): + _B = self.solver._B + n_entities = self.solver.n_entities + ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + for i_e in range(n_entities): + next_entity_idx = self.entity_idx_to_next_entity_idx_in_hibernated_island[i_e, i_b] + if next_entity_idx != INVALID_NEXT_HIBERNATED_ENTITY_IDX and next_entity_idx != i_e: + any_link_a = self.solver.entities_info.link_start[i_e] + any_link_b = self.solver.entities_info.link_start[next_entity_idx] + self.add_edge(any_link_a, any_link_b, i_b) def construct(self): self.clear() - self.add_island() - self.preprocess_island() - self.construct_island() - self.postprocess_island() + self.add_contact_edges_to_islands() + self.add_hiberanted_edges_to_islands() + self.preprocess_island_and_map_entities_to_edges() + self.construct_islands() + self.postprocess_island_and_assign_contact_data() @ti.kernel - def postprocess_island(self): + def postprocess_island_and_assign_contact_data(self): ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): for i_col in range(self.collider._collider_state.n_contacts[i_b]): - impact = self.collider._collider_state.contact_data[i_col, i_b] - link_a = impact.link_a - link_b = impact.link_b + # get links indices of the impact + link_a = self.collider._collider_state.contact_data.link_a[i_col, i_b] + link_b = self.collider._collider_state.contact_data.link_b[i_col, i_b] link_a_maybe_batch = [link_a, i_b] if ti.static(self.solver._options.batch_links_info) else link_a link_b_maybe_batch = [link_b, i_b] if ti.static(self.solver._options.batch_links_info) else link_b - ea = self.solver.links_info[link_a_maybe_batch].entity_idx - eb = self.solver.links_info[link_b_maybe_batch].entity_idx + ea = self.solver.links_info.entity_idx[link_a_maybe_batch] + eb = self.solver.links_info.entity_idx[link_b_maybe_batch] island_a = self.entity_island[ea, i_b] island_b = self.entity_island[eb, i_b] + # handle collisions between dynamic and fixed entities (island_idx == -1) island = island_a if island_a == -1: island = island_b @@ -128,7 +167,7 @@ def postprocess_island(self): self.constraint_list[i_col, i_b] = island constraint_list_start = 0 - for i in range(self.n_island[i_b]): + for i in range(self.n_islands[i_b]): self.island_col[i, i_b].start = constraint_list_start constraint_list_start = constraint_list_start + self.island_col[i, i_b].n self.island_col[i, i_b].curr = self.island_col[i, i_b].start @@ -150,7 +189,7 @@ def postprocess_island(self): self.island_hibernated[self.entity_island[i, i_b], i_b] = 0 entity_list_start = 0 - for i in range(self.n_island[i_b]): + for i in range(self.n_islands[i_b]): self.island_entity[i, i_b].start = entity_list_start self.island_entity[i, i_b].curr = self.island_entity[i, i_b].start entity_list_start = entity_list_start + self.island_entity[i, i_b].n @@ -162,19 +201,21 @@ def postprocess_island(self): self.island_entity[island, i_b].curr = self.island_entity[island, i_b].curr + 1 @ti.kernel - def preprocess_island(self): + def preprocess_island_and_map_entities_to_edges(self): ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): entity_list_start = 0 for i in range(self.solver.n_entities): self.entity_edge[i, i_b].start = entity_list_start + self.entity_edge[i, i_b].curr = entity_list_start entity_list_start = entity_list_start + self.entity_edge[i, i_b].n - self.entity_edge[i, i_b].curr = self.entity_edge[i, i_b].start - for i in range(self.n_edge[i_b]): + # process added collider-info edges + for i in range(self.n_edges[i_b]): ea = self.ci_edges[i, 0, i_b] eb = self.ci_edges[i, 1, i_b] + # map entity's half-edge index to edge index. self.edge_id[self.entity_edge[ea, i_b].curr, i_b] = i self.edge_id[self.entity_edge[eb, i_b].curr, i_b] = i @@ -182,41 +223,53 @@ def preprocess_island(self): self.entity_edge[eb, i_b].curr = self.entity_edge[eb, i_b].curr + 1 @ti.kernel - def construct_island(self): + def construct_islands(self): + """ + This assigns entities to islands, by setting their entity_island[entity_idx, batch_idx] = island_idx. + """ ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): for i_v in range(self.solver.n_entities): + # only create islands for entities with collisions and with dofs if self.entity_edge[i_v, i_b].n > 0 and self.solver.entities_info.n_dofs[i_v] > 0: if self.entity_island[i_v, i_b] != -1: continue self.n_stack[i_b] = 0 self.stack[self.n_stack[i_b], i_b] = i_v self.n_stack[i_b] = self.n_stack[i_b] + 1 + self.entity_island[i_v, i_b] = self.n_islands[i_b] + # FIXME: Add proper mechanism to detection overflow in Taichi-scope + # but raise exception in Python-scope while self.n_stack[i_b] > 0: self.n_stack[i_b] = self.n_stack[i_b] - 1 v = self.stack[self.n_stack[i_b], i_b] - if self.entity_island[v, i_b] != -1: - continue - self.entity_island[v, i_b] = self.n_island[i_b] for i_edge in range(self.entity_edge[v, i_b].n): - _id = self.entity_edge[v, i_b].start + i_edge - edge = self.edge_id[_id, i_b] - next_v = self.ci_edges[edge, 0, i_b] + _id = self.entity_edge[v, i_b].start + i_edge # half-edge index + edge = self.edge_id[_id, i_b] # edge index + next_v = self.ci_edges[edge, 0, i_b] # other entity index, connected by edge if next_v == v: next_v = self.ci_edges[edge, 1, i_b] - if self.solver.entities_info.n_dofs[next_v] > 0 and next_v != v: + if ( + self.solver.entities_info.n_dofs[next_v] > 0 + and next_v != v + and self.entity_island[next_v, i_b] == -1 + ): # 2nd condition must not happen ? self.stack[self.n_stack[i_b], i_b] = next_v self.n_stack[i_b] = self.n_stack[i_b] + 1 + self.entity_island[next_v, i_b] = self.n_islands[i_b] + # FIXME: Add proper mechanism to detection overflow in Taichi-scope + # but raise exception in Python-scope - self.n_island[i_b] = self.n_island[i_b] + 1 + self.n_islands[i_b] = self.n_islands[i_b] + 1 + # create single-entity islands for entities without collisions if self.solver._enable_joint_limit: ti.loop_config(serialize=self.solver._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self.solver._B): for i_v in range(self.solver.n_entities): if self.solver.entities_info.n_dofs[i_v] > 0 and self.entity_island[i_v, i_b] == -1: - self.entity_island[i_v, i_b] = self.n_island[i_b] - self.n_island[i_b] = self.n_island[i_b] + 1 + self.entity_island[i_v, i_b] = self.n_islands[i_b] + self.n_islands[i_b] = self.n_islands[i_b] + 1 diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 73b0adb5f0..9c112d0a2e 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1,29 +1,35 @@ -from typing import Literal, TYPE_CHECKING from dataclasses import dataclass +from typing import Literal, TYPE_CHECKING import numpy as np -import torch import numpy.typing as npt import taichi as ti +import torch import genesis as gs -from genesis.engine.entities.base_entity import Entity -from genesis.options.solvers import RigidOptions import genesis.utils.geom as gu -from genesis.utils import linalg as lu -from genesis.utils.misc import ti_field_to_torch, DeprecationError, ALLOCATE_TENSOR_WARNING +import genesis.utils.array_class as array_class + from genesis.engine.entities import AvatarEntity, DroneEntity, RigidEntity +from genesis.engine.entities.base_entity import Entity +from genesis.engine.solvers.rigid.contact_island import ContactIsland from genesis.engine.states.solvers import RigidSolverState +from genesis.options.solvers import RigidOptions from genesis.styles import colors, formats -import genesis.utils.array_class as array_class +from genesis.utils import linalg as lu +from genesis.utils.misc import ti_field_to_torch, DeprecationError, ALLOCATE_TENSOR_WARNING +from ....utils.sdf_decomp import SDF from ..base_solver import Solver -from .collider_decomp import Collider from .constraint_solver_decomp import ConstraintSolver from .constraint_solver_decomp_island import ConstraintSolverIsland -from ....utils.sdf_decomp import SDF +from .contact_island import INVALID_NEXT_HIBERNATED_ENTITY_IDX +from .collider_decomp import Collider +from .rigid_solver_decomp_util import func_wakeup_entity_and_its_temp_island if TYPE_CHECKING: + import genesis.engine.solvers.rigid.array_class + from genesis.engine.scene import Scene from genesis.engine.simulator import Simulator @@ -67,6 +73,9 @@ def _sanitize_sol_params( @ti.data_oriented class RigidSolver(Solver): + # override typing + _entities: list[RigidEntity] = gs.List() + # ------------------------------------------------------------------------------------ # --------------------------------- Initialization ----------------------------------- # ------------------------------------------------------------------------------------ @@ -149,11 +158,10 @@ def add_entity(self, idx, material, morph, surface, visualize_contact) -> Entity EntityClass = AvatarEntity if visualize_contact: gs.raise_exception("AvatarEntity does not support 'visualize_contact=True'.") + elif isinstance(morph, gs.morphs.Drone): + EntityClass = DroneEntity else: - if isinstance(morph, gs.morphs.Drone): - EntityClass = DroneEntity - else: - EntityClass = RigidEntity + EntityClass = RigidEntity if morph.is_free: verts_state_start = self.n_free_verts @@ -185,6 +193,7 @@ def add_entity(self, idx, material, morph, surface, visualize_contact) -> Entity vface_start=self.n_vfaces, visualize_contact=visualize_contact, ) + assert isinstance(entity, RigidEntity) self._entities.append(entity) return entity @@ -249,6 +258,9 @@ def build(self): self.n_equalities_candidate = max(1, self.n_equalities + self._options.max_dynamic_constraints) + # Note optional hibernation_threshold_acc/vel params at the bottom of the initialization list. + # This is caused by this code being also run by AvatarSolver, which inherits from this class + # but does not have all the attributes of the base class. self._static_rigid_sim_config = self.StaticRigidSimConfig( para_level=self.sim._para_level, use_hibernation=getattr(self, "_use_hibernation", False), @@ -273,6 +285,8 @@ def build(self): ls_tolerance=getattr(self._options, "ls_tolerance", 1e-6), n_equalities=self._n_equalities, n_equalities_candidate=self.n_equalities_candidate, + hibernation_thresh_acc=getattr(self, "_hibernation_thresh_acc", 0.0), + hibernation_thresh_vel=getattr(self, "_hibernation_thresh_vel", 0.0), ) # when the migration is finished, we will remove the about two lines @@ -289,8 +303,8 @@ def build(self): self.awake_dofs = self._rigid_global_info.awake_dofs self.n_awake_links = self._rigid_global_info.n_awake_links self.awake_links = self._rigid_global_info.awake_links - self.n_awake_entities = self._rigid_global_info.data_manager.n_awake_entities - self.awake_entities = self._rigid_global_info.data_manager.awake_entities + self.n_awake_entities = self._rigid_global_info.n_awake_entities + self.awake_entities = self._rigid_global_info.awake_entities self._init_mass_mat() self._init_dof_fields() @@ -704,8 +718,8 @@ def _init_vvert_fields(self): ) def _init_geom_fields(self): - self.geoms_info = self.data_manager.geoms_info - self.geoms_state = self.data_manager.geoms_state + self.geoms_info: array_class.GeomsInfo = self.data_manager.geoms_info + self.geoms_state: array_class.GeomsState = self.data_manager.geoms_state self.geoms_init_AABB = self._rigid_global_info.geoms_init_AABB self._geoms_render_T = np.empty((self.n_geoms_, self._B, 4, 4), dtype=np.float32) @@ -762,8 +776,8 @@ def _init_geom_fields(self): ) def _init_vgeom_fields(self): - self.vgeoms_info = self.data_manager.vgeoms_info - self.vgeoms_state = self.data_manager.vgeoms_state + self.vgeoms_info: array_class.VGeomsInfo = self.data_manager.vgeoms_info + self.vgeoms_state: array_class.VGeomsState = self.data_manager.vgeoms_state self._vgeoms_render_T = np.empty((self.n_vgeoms_, self._B, 4, 4), dtype=np.float32) if self.n_vgeoms > 0: @@ -912,11 +926,14 @@ def substep(self): dofs_info=self.dofs_info, geoms_state=self.geoms_state, geoms_info=self.geoms_info, + entities_state=self.entities_state, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + contact_island=self.constraint_solver.contact_island, ) # timer.stamp("kernel_step_1") + if isinstance(self.sim.coupler, SAPCoupler): self.update_qvel() else: @@ -936,6 +953,7 @@ def substep(self): collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + contact_island=self.constraint_solver.contact_island, ) # timer.stamp("kernel_step_2") @@ -980,9 +998,12 @@ def _func_forward_dynamics(self): dofs_state=self.dofs_state, dofs_info=self.dofs_info, joints_info=self.joints_info, + entities_state=self.entities_state, entities_info=self.entities_info, + geoms_state=self.geoms_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + contact_island=self.constraint_solver.contact_island, ) def _func_update_acc(self): @@ -1253,7 +1274,6 @@ def substep_pre_coupling_grad(self, f): pass def substep_post_coupling(self, f): - from genesis.engine.couplers import SAPCoupler if self.is_active() and isinstance(self.sim.coupler, SAPCoupler): @@ -1272,6 +1292,7 @@ def substep_post_coupling(self, f): collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + contact_island=self.constraint_solver.contact_island, ) def substep_post_coupling_grad(self, f): @@ -3041,9 +3062,12 @@ def kernel_forward_dynamics( dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, joints_info: array_class.JointsInfo, + entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, + geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island: ti.template(), # ContactIsland ): func_forward_dynamics( links_state=links_state, @@ -3051,9 +3075,12 @@ def kernel_forward_dynamics( dofs_state=dofs_state, dofs_info=dofs_info, joints_info=joints_info, + entities_state=entities_state, entities_info=entities_info, + geoms_state=geoms_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + contact_island=contact_island, ) @@ -3756,9 +3783,12 @@ def func_forward_dynamics( dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, joints_info: array_class.JointsInfo, + entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, + geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island: ti.template(), ): func_compute_mass_matrix( implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), @@ -3779,13 +3809,17 @@ def func_forward_dynamics( static_rigid_sim_config=static_rigid_sim_config, ) func_torque_and_passive_force( + entities_state=entities_state, entities_info=entities_info, dofs_state=dofs_state, dofs_info=dofs_info, + links_state=links_state, links_info=links_info, joints_info=joints_info, + geoms_state=geoms_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + contact_island=contact_island, ) func_update_acc( update_cacc=False, @@ -3903,9 +3937,11 @@ def kernel_step_1( dofs_info: array_class.DofsInfo, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, + entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island: ti.template(), ): if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): _B = links_state.pos.shape[1] @@ -3932,9 +3968,12 @@ def kernel_step_1( dofs_state=dofs_state, dofs_info=dofs_info, joints_info=joints_info, + entities_state=entities_state, entities_info=entities_info, + geoms_state=geoms_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + contact_island=contact_island, ) @@ -4014,6 +4053,7 @@ def kernel_step_2( collider_state: array_class.ColliderState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island: ti.template(), # ContactIsland ): # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it # would not corresponds to anyting physical. There is no other way than doing this right before integration, @@ -4048,15 +4088,16 @@ def kernel_step_2( ) if ti.static(static_rigid_sim_config.use_hibernation): - func_hibernate( + func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_buffer( dofs_state=dofs_state, entities_state=entities_state, entities_info=entities_info, links_state=links_state, geoms_state=geoms_state, collider_state=collider_state, - rigid_global_info=rigid_global_info, + unused__rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + contact_island=contact_island, ) func_aggregate_awake_entities( entities_state=entities_state, @@ -4928,59 +4969,88 @@ def kernel_update_vgeoms( @ti.func -def func_hibernate( - dofs_state, - entities_state, - entities_info, - links_state, - geoms_state, - collider_state, - rigid_global_info, +def func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_buffer( + dofs_state: array_class.DofsState, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + links_state: array_class.LinksState, + geoms_state: array_class.GeomsState, + collider_state: array_class.ColliderState, + unused__rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), -): + contact_island: ti.template(), # ContactIsland, +) -> None: n_entities = entities_state.hibernated.shape[0] _B = entities_state.hibernated.shape[1] + ci = contact_island + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e, i_b in ti.ndrange(n_entities, _B): - if ( - not entities_state.hibernated[i_e, i_b] and entities_info.n_dofs[i_e] > 0 - ): # We do not hibernate fixed entity - hibernate = True - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - if ( - ti.abs(dofs_state.acc[i_d, i_b]) > static_rigid_sim_config.hibernation_thresh_acc - or ti.abs(dofs_state.vel[i_d, i_b]) > static_rigid_sim_config.hibernation_thresh_vel - ): - hibernate = False - break + for i_b in range(_B): + for island_idx in range(ci.n_islands[i_b]): + was_island_hibernated = ci.island_hibernated[island_idx, i_b] + + if not was_island_hibernated: + are_all_entities_okay_for_hibernation = True + entity_ref_range = ci.island_entity[island_idx, i_b] + for i_entity_ref_offset_ in range(entity_ref_range.n): + entity_ref = entity_ref_range.start + i_entity_ref_offset_ + entity_idx = ci.entity_id[entity_ref, i_b] + + # Hibernated entities already have zero dofs_state.acc/vel + is_entity_hibernated = entities_state.hibernated[entity_idx, i_b] + if is_entity_hibernated: + continue + + for i_d in range(entities_info.dof_start[entity_idx], entities_info.dof_end[entity_idx]): + max_acc = static_rigid_sim_config.hibernation_thresh_acc + max_vel = static_rigid_sim_config.hibernation_thresh_vel + if ti.abs(dofs_state.acc[i_d, i_b]) > max_acc or ti.abs(dofs_state.vel[i_d, i_b]) > max_vel: + are_all_entities_okay_for_hibernation = False + break + + if not are_all_entities_okay_for_hibernation: + break - if hibernate: - func_hibernate_entity( - i_e, - i_b, - entities_state=entities_state, - entities_info=entities_info, - dofs_state=dofs_state, - links_state=links_state, - geoms_state=geoms_state, - ) - else: - # update collider sort_buffer - for i_g in range(entities_info.geom_start[i_e], entities_info.geom_end[i_e]): - collider_state.sort_buffer.value[geoms_state.min_buffer_idx[i_g, i_b], i_b] = geoms_state.aabb_min[ - i_g, i_b - ][0] - collider_state.sort_buffer.value[geoms_state.max_buffer_idx[i_g, i_b], i_b] = geoms_state.aabb_max[ - i_g, i_b - ][0] + if not are_all_entities_okay_for_hibernation: + # update collider sort_buffer with aabb extents along x-axis + for i_entity_ref_offset_ in range(entity_ref_range.n): + entity_ref = entity_ref_range.start + i_entity_ref_offset_ + entity_idx = ci.entity_id[entity_ref, i_b] + for i_g in range(entities_info.geom_start[entity_idx], entities_info.geom_end[entity_idx]): + min_idx, min_val = geoms_state.min_buffer_idx[i_g, i_b], geoms_state.aabb_min[i_g, i_b][0] + max_idx, max_val = geoms_state.max_buffer_idx[i_g, i_b], geoms_state.aabb_max[i_g, i_b][0] + collider_state.sort_buffer.value[min_idx, i_b] = min_val + collider_state.sort_buffer.value[max_idx, i_b] = max_val + else: + # perform hibernation + prev_entity_ref = entity_ref_range.start + entity_ref_range.n - 1 + prev_entity_idx = ci.entity_id[prev_entity_ref, i_b] + + for i_entity_ref_offset_ in range(entity_ref_range.n): + entity_ref = entity_ref_range.start + i_entity_ref_offset_ + entity_idx = ci.entity_id[entity_ref, i_b] + + func_hibernate_entity_and_zero_dof_velocities( + entity_idx, + i_b, + entities_state=entities_state, + entities_info=entities_info, + dofs_state=dofs_state, + links_state=links_state, + geoms_state=geoms_state, + ) + + # store entities in the hibernated islands by daisy chaining them + ci.entity_idx_to_next_entity_idx_in_hibernated_island[prev_entity_idx, i_b] = entity_idx + prev_entity_idx = entity_idx @ti.func def func_aggregate_awake_entities( - entities_state, - entities_info, - rigid_global_info, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): @@ -4993,29 +5063,38 @@ def func_aggregate_awake_entities( for i_e, i_b in ti.ndrange(n_entities, _B): if entities_state.hibernated[i_e, i_b] or entities_info.n_dofs[i_e] == 0: continue - n_awake_entities = ti.atomic_add(rigid_global_info.n_awake_entities[i_b], 1) - rigid_global_info.awake_entities[n_awake_entities, i_b] = i_e - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - n_awake_dofs = ti.atomic_add(rigid_global_info.n_awake_dofs[i_b], 1) - rigid_global_info.awake_dofs[n_awake_dofs, i_b] = i_d + next_awake_entity_idx = ti.atomic_add(rigid_global_info.n_awake_entities[i_b], 1) + rigid_global_info.awake_entities[next_awake_entity_idx, i_b] = i_e - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - n_awake_links = ti.atomic_add(rigid_global_info.n_awake_links[i_b], 1) - rigid_global_info.awake_links[n_awake_links, i_b] = i_l + n_dofs = entities_info.n_dofs[i_e] + entity_dofs_base_idx: ti.int32 = entities_info.dof_start[i_e] + awake_dofs_base_idx = ti.atomic_add(rigid_global_info.n_awake_dofs[i_b], n_dofs) + for i in range(n_dofs): + rigid_global_info.awake_dofs[awake_dofs_base_idx + i, i_b] = entity_dofs_base_idx + i + + n_links = entities_info.n_links[i_e] + entity_links_base_idx: ti.int32 = entities_info.link_start[i_e] + awake_links_base_idx = ti.atomic_add(rigid_global_info.n_awake_links[i_b], n_links) + for i in range(n_links): + rigid_global_info.awake_links[awake_links_base_idx + i, i_b] = entity_links_base_idx + i @ti.func -def func_hibernate_entity( - i_e, - i_b, - entities_state, - entities_info, - dofs_state, - links_state, - geoms_state, -): +def func_hibernate_entity_and_zero_dof_velocities( + i_e: int, + i_b: int, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + links_state: array_class.LinksState, + geoms_state: array_class.GeomsState, +) -> None: + """ + Mark RigidEnity, individual DOFs in DofsState, RigidLinks, and RigidGeoms as hibernated. + Also, zero out DOF velocitities and accelerations. + """ entities_state.hibernated[i_e, i_b] = True for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): @@ -5030,32 +5109,6 @@ def func_hibernate_entity( geoms_state.hibernated[i_g, i_b] = True -@ti.func -def func_wakeup_entity( - i_e, - i_b, - entities_state: array_class.EntitiesState, - entities_info: array_class.EntitiesInfo, - dofs_state: array_class.DofsState, - links_state: array_class.LinksState, - rigid_global_info: array_class.RigidGlobalInfo, -): - if entities_state.hibernated[i_e, i_b]: - entities_state.hibernated[i_e, i_b] = False - n_awake_entities = ti.atomic_add(rigid_global_info.n_awake_entities[i_b], 1) - rigid_global_info.awake_entities[n_awake_entities, i_b] = i_e - - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - dofs_state.hibernated[i_d, i_b] = False - n_awake_dofs = ti.atomic_add(rigid_global_info.n_awake_dofs[i_b], 1) - rigid_global_info.awake_dofs[n_awake_dofs, i_b] = i_d - - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - links_state.hibernated[i_l, i_b] = False - n_awake_links = ti.atomic_add(rigid_global_info.n_awake_links[i_b], 1) - rigid_global_info.awake_links[n_awake_links, i_b] = i_l - - @ti.kernel def kernel_apply_links_external_force( force: ti.types.ndarray(), @@ -5165,13 +5218,17 @@ def func_clear_external_force( @ti.func def func_torque_and_passive_force( + entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, + links_state: array_class.LinksState, links_info: array_class.LinksInfo, joints_info: array_class.JointsInfo, + geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island: ti.template(), ): n_entities = entities_info.n_links.shape[0] _B = dofs_state.ctrl_mode.shape[1] @@ -5257,13 +5314,22 @@ def func_torque_and_passive_force( if ti.abs(force) > gs.EPS: wakeup = True - if ti.static(static_rigid_sim_config.use_hibernation): - if wakeup: - func_wakeup_entity(i_e, i_b) + if ti.static(static_rigid_sim_config.use_hibernation) and entities_state.hibernated[i_e, i_b] and wakeup: + func_wakeup_entity_and_its_temp_island( + i_e, + i_b, + entities_state, + entities_info, + dofs_state, + links_state, + geoms_state, + rigid_global_info, + contact_island, + ) if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(rigid_global_info._B): + for i_b in range(_B): for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): i_d = rigid_global_info.awake_dofs[i_d_, i_b] I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d @@ -5271,7 +5337,7 @@ def func_torque_and_passive_force( dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(rigid_global_info._B): + for i_b in range(_B): for i_l_ in range(rigid_global_info.n_awake_links[i_b]): i_l = rigid_global_info.awake_links[i_l_, i_b] I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l @@ -5465,7 +5531,7 @@ def func_update_force( for i_b in range(_B): for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l in range(entities_info.link_end[i_e] - 1 - entities_info.link_start[i_e]): + for i in range(entities_info.n_links[i_e]): i_l = entities_info.link_end[i_e] - 1 - i I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -5603,7 +5669,7 @@ def func_compute_qacc( if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in ti.range(_B): + for i_b in range(_B): for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): i_e = rigid_global_info.awake_entities[i_e_, i_b] for i_d1_ in range(entities_info.n_dofs[i_e]): @@ -5645,11 +5711,11 @@ def func_integrate( I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j dof_start = joints_info.dof_start[I_j] q_start = joints_info.q_start[I_j] q_end = joints_info.q_end[I_j] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] if joint_type == gs.JOINT_TYPE.FREE: diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp_util.py b/genesis/engine/solvers/rigid/rigid_solver_decomp_util.py new file mode 100644 index 0000000000..8545b5aecc --- /dev/null +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp_util.py @@ -0,0 +1,57 @@ +import taichi as ti + +import genesis.utils.array_class as array_class + +from genesis.engine.solvers.rigid.contact_island import INVALID_NEXT_HIBERNATED_ENTITY_IDX + + +@ti.func +def func_wakeup_entity_and_its_temp_island( + i_e, + i_b, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + links_state: array_class.LinksState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + contact_island: ti.template(), +): + # Note: Original function handled non-hibernated & fixed entities. + # Now, we require a properly hibernated entity to be passed in. + island_idx = contact_island.entity_island[i_e, i_b] + + entity_ref_range = contact_island.island_entity[island_idx, i_b] + for ei in range(entity_ref_range.n): + entity_ref = entity_ref_range.start + ei + entity_idx = contact_island.entity_id[entity_ref, i_b] + + is_entity_hibernated = entities_state.hibernated[entity_idx, i_b] + + if is_entity_hibernated: + contact_island.entity_idx_to_next_entity_idx_in_hibernated_island[entity_idx, i_b] = ( + INVALID_NEXT_HIBERNATED_ENTITY_IDX + ) + + entities_state.hibernated[entity_idx, i_b] = False + n_awake_entities = ti.atomic_add(rigid_global_info.n_awake_entities[i_b], 1) + rigid_global_info.awake_entities[n_awake_entities, i_b] = entity_idx + + n_dofs = entities_info.n_dofs[entity_idx] + base_entity_dof_idx = entities_info.dof_start[entity_idx] + base_awake_dof_idx = ti.atomic_add(rigid_global_info.n_awake_dofs[i_b], n_dofs) + for i in range(n_dofs): + i_d = base_entity_dof_idx + i + dofs_state.hibernated[i_d, i_b] = False + rigid_global_info.awake_dofs[base_awake_dof_idx + i, i_b] = i_d + + n_links = entities_info.n_links[entity_idx] + base_entity_link_idx = entities_info.link_start[entity_idx] + base_awake_link_idx = ti.atomic_add(rigid_global_info.n_awake_links[i_b], n_links) + for i in range(n_links): + i_l = base_entity_link_idx + i + links_state.hibernated[i_l, i_b] = False + rigid_global_info.awake_links[base_awake_link_idx + i, i_b] = i_l + + for i_g in range(entities_info.geom_start[entity_idx], entities_info.geom_end[entity_idx]): + geoms_state.hibernated[i_g, i_b] = False diff --git a/genesis/options/morphs.py b/genesis/options/morphs.py index d46413065d..3d6762c724 100644 --- a/genesis/options/morphs.py +++ b/genesis/options/morphs.py @@ -77,6 +77,9 @@ class Morph(Options): **This is only used for RigidEntity.** is_free : bool, optional Whether the entity is free to move. Defaults to True. **This is only used for RigidEntity.** + This determines whether the entity's geoms have their vertices put into StructFreeVertsState or + StructFixedVertsState, and effectively whether they're stored per batch-element, or stored once and shared + for the entire batch. That affects correct processing of collision detection. """ # Note: pos, euler, quat store only initial varlues at creation time, and are unaffected by sim diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 408be39be7..56ec37ba15 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -25,6 +25,8 @@ class StructRigidGlobalInfo: awake_dofs: V_ANNOTATION n_awake_entities: V_ANNOTATION awake_entities: V_ANNOTATION + n_awake_links: V_ANNOTATION + awake_links: V_ANNOTATION qpos0: V_ANNOTATION qpos: V_ANNOTATION links_T: V_ANNOTATION @@ -48,6 +50,8 @@ def get_rigid_global_info(solver): "awake_dofs": V(dtype=gs.ti_int, shape=f_batch(solver.n_dofs_)), "n_awake_entities": V(dtype=gs.ti_int, shape=f_batch()), "awake_entities": V(dtype=gs.ti_int, shape=f_batch(solver.n_entities_)), + "n_awake_links": V(dtype=gs.ti_int, shape=f_batch()), + "awake_links": V(dtype=gs.ti_int, shape=f_batch(solver.n_links)), "qpos0": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_qs_)), "qpos": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_qs_)), "links_T": V_MAT(n=4, m=4, dtype=gs.ti_float, shape=solver.n_links), @@ -245,6 +249,7 @@ def __init__(self): @dataclasses.dataclass class StructContactData: + # WARNING: cannot add/remove fields here without also updating collider_decomp.py::kernel_collider_clear geom_a: V_ANNOTATION geom_b: V_ANNOTATION penetration: V_ANNOTATION