diff --git a/genesis/__init__.py b/genesis/__init__.py index 5395d72203..fc2437e456 100644 --- a/genesis/__init__.py +++ b/genesis/__init__.py @@ -56,6 +56,9 @@ def init( if _initialized: raise_exception("Genesis already initialized.") + # Make sure evertything is properly destroyed, just in case initialization failed previously + destroy() + # genesis._theme global _theme is_theme_valid = theme in ("dark", "light", "dumb") diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 5c3ad58609..9379b4a2af 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1837,32 +1837,6 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns if zero_velocity: self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) - @gs.assert_built - def get_weld_constraints(self, with_entity=None, exclude_self_contact=False): - welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True) - obj_a = welds["obj_a"] - obj_b = welds["obj_b"] - - # Create mask for filtering welds involving this entity - mask = (obj_a == self.idx) | (obj_b == self.idx) - - # Additional filtering if with_entity is specified - if with_entity is not None: - if self.idx == with_entity.idx: - if exclude_self_contact: - gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.") - # For self-contact, keep only self-welds - mask = mask & ((obj_a == self.idx) & (obj_b == self.idx)) - else: - # For cross-entity, keep welds between this entity and with_entity - mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx)) - - # Apply filtering - for k in ("obj_a", "obj_b"): - welds[k] = welds[k][mask] - - return welds - @gs.assert_built def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): """ diff --git a/genesis/engine/solvers/rigid/collider_decomp.py b/genesis/engine/solvers/rigid/collider_decomp.py index 14dc11e8fe..27e3f601dc 100644 --- a/genesis/engine/solvers/rigid/collider_decomp.py +++ b/genesis/engine/solvers/rigid/collider_decomp.py @@ -545,6 +545,8 @@ def collider_kernel_get_contacts( ): _B = collider_state.active_buffer.shape[1] n_contacts_max = gs.ti_int(0) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): n_contacts = collider_state.n_contacts[i_b] if n_contacts > n_contacts_max: diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index a0b9ca6def..7f07658808 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING + import numpy as np -import taichi as ti import numpy.typing as npt +import taichi as ti +import torch import genesis as gs import genesis.utils.geom as gu @@ -37,6 +39,8 @@ def __init__(self, rigid_solver: "RigidSolver"): self.constraint_state = array_class.get_constraint_state(self, self._solver) + self._eq_const_info_cache = {} + # self.ti_n_equalities = ti.field(gs.ti_int, shape=self._solver._batch_shape()) # self.ti_n_equalities.from_numpy(np.full((self._solver._B,), self._solver.n_equalities, dtype=gs.np_int)) @@ -157,11 +161,13 @@ def __init__(self, rigid_solver: "RigidSolver"): self.reset() def clear(self, envs_idx: npt.NDArray[np.int32] | None = None): + self._eq_const_info_cache.clear() if envs_idx is None: envs_idx = self._solver._scene._envs_idx constraint_solver_kernel_clear(envs_idx, self._solver._static_rigid_sim_config, self.constraint_state) def reset(self, envs_idx=None): + self._eq_const_info_cache.clear() if envs_idx is None: envs_idx = self._solver._scene._envs_idx constraint_solver_kernel_reset( @@ -253,6 +259,137 @@ def resolve(self): ) # timer.stamp("compute force") + def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True): + # Early return if already pre-computed + eq_const_info = self._eq_const_info_cache.get((as_tensor, to_torch)) + if eq_const_info is not None: + return eq_const_info.copy() + + n_eqs = tuple(self.constraint_state.ti_n_equalities.to_numpy()) + n_envs = len(n_eqs) + n_eqs_max = max(n_eqs) + + if as_tensor: + out_size = n_envs * n_eqs_max + else: + *n_eqs_starts, out_size = np.cumsum(n_eqs) + + if to_torch: + iout = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device) + fout = torch.zeros((out_size, 6), dtype=gs.tc_float, device=gs.device) + else: + iout = np.full((out_size, 3), -1, dtype=gs.np_int) + fout = np.zeros((out_size, 6), dtype=gs.np_float) + + if n_eqs_max > 0: + kernel_get_equality_constraints( + as_tensor, + iout, + fout, + self.constraint_state, + self._solver.equalities_info, + self._solver._static_rigid_sim_config, + ) + + if as_tensor: + iout = iout.reshape((n_envs, n_eqs_max, 3)) + eq_type, obj_a, obj_b = (iout[..., i] for i in range(3)) + efc_force = fout.reshape((n_envs, n_eqs_max, 6)) + values = (eq_type, obj_a, obj_b, fout) + else: + if to_torch: + iout_chunks = torch.split(iout, n_eqs) + efc_force = torch.split(fout, n_eqs) + else: + iout_chunks = np.split(iout, n_eqs_starts) + efc_force = np.split(fout, n_eqs_starts) + eq_type, obj_a, obj_b = tuple(zip(*([data[..., i] for i in range(3)] for data in iout_chunks))) + + values = (eq_type, obj_a, obj_b, efc_force) + eq_const_info = dict(zip(("type", "obj_a", "obj_b", "force"), values)) + + # Cache equality constraint information before returning + self._eq_const_info_cache[(as_tensor, to_torch)] = eq_const_info + + return eq_const_info.copy() + + def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True): + eq_const_info = self.get_equality_constraints(as_tensor, to_torch) + eq_type = eq_const_info.pop("type") + + weld_const_info = {} + if as_tensor: + weld_mask = eq_type == gs.EQUALITY_TYPE.WELD + n_envs = len(weld_mask) + n_welds = weld_mask.sum(dim=-1) if to_torch else np.sum(weld_mask, axis=-1) + n_welds_max = max(n_welds) + for key, value in eq_const_info.items(): + shape = (n_envs, n_welds_max, *value.shape[2:]) + if to_torch: + if torch.is_floating_point(value): + weld_const_info[key] = torch.zeros(shape, dtype=value.dtype, device=value.device) + else: + weld_const_info[key] = torch.full(shape, -1, dtype=value.dtype, device=value.device) + else: + if np.issubdtype(value.dtype, np.floating): + weld_const_info[key] = np.zeros(shape, dtype=value.dtype) + else: + weld_const_info[key] = np.full(shape, -1, dtype=value.dtype) + for i_b, (n_welds_i, weld_mask_i) in enumerate(zip(n_welds, weld_mask)): + for eq_value, weld_value in zip(eq_const_info.values(), weld_const_info.values()): + weld_value[i_b, :n_welds_i] = eq_value[i_b, weld_mask_i] + else: + weld_mask_chunks = tuple(eq_type_i == gs.EQUALITY_TYPE.WELD for eq_type_i in eq_type) + for key, value in eq_const_info.items(): + weld_const_info[key] = tuple(data[weld_mask] for weld_mask, data in zip(weld_mask_chunks, value)) + + weld_const_info["link_a"] = weld_const_info.pop("obj_a") + weld_const_info["link_b"] = weld_const_info.pop("obj_b") + + return weld_const_info + + def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): + envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) + link1_idx, link2_idx = int(link1_idx), int(link2_idx) + + if not unsafe: + assert link1_idx >= 0 and link2_idx >= 0 + weld_const_info = self.get_weld_constraints(as_tensor=True, to_torch=True) + link_a = weld_const_info["link_a"] + link_b = weld_const_info["link_b"] + assert not ( + ((link_a == link1_idx) | (link_b == link1_idx)) & ((link_a == link2_idx) | (link_b == link2_idx)) + ).any() + + self._eq_const_info_cache.clear() + overflow = kernel_add_weld_constraint( + link1_idx, + link2_idx, + envs_idx, + self._solver.equalities_info, + self.constraint_state, + self._solver.links_state, + self._solver._static_rigid_sim_config, + ) + if overflow: + gs.logger.warning( + "Ignoring dynamically registered weld constraint to avoid exceeding max number of equality constraints" + f"({self._static_rigid_sim_config.n_equalities_candidate}). Please increase the value of " + "RigidSolver's option 'max_dynamic_constraints'." + ) + + def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): + envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) + self._eq_const_info_cache.clear() + kernel_delete_weld_constraint( + int(link1_idx), + int(link2_idx), + envs_idx, + self._solver.equalities_info, + self.constraint_state, + self._solver._static_rigid_sim_config, + ) + @ti.kernel def constraint_solver_kernel_clear( @@ -486,11 +623,11 @@ def func_equality_connect( imp, aref = gu.imp_aref(sol_params, -penetration, jac_qvel, pos_diff[i_3]) - diag = ti.max(invweight * (1 - imp) / imp, gs.EPS) + diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS) constraint_state.diag[n_con, i_b] = diag constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag + constraint_state.efc_D[n_con, i_b] = 1.0 / diag @ti.func @@ -564,11 +701,11 @@ def func_equality_joint( imp, aref = gu.imp_aref(sol_params, -ti.abs(pos), jac_qvel, pos) - diag = ti.max(invweight * (1 - imp) / imp, gs.EPS) + diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS) constraint_state.diag[n_con, i_b] = diag constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag + constraint_state.efc_D[n_con, i_b] = 1.0 / diag @ti.kernel @@ -1939,3 +2076,129 @@ def func_init_solver( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(n_dofs, _B): constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + + +@ti.kernel +def kernel_add_weld_constraint( + link1_idx: ti.i32, + link2_idx: ti.i32, + envs_idx: ti.types.ndarray(), + equalities_info: array_class.EqualitiesInfo, + constraint_state: array_class.ConstraintState, + links_state: array_class.LinksState, + static_rigid_sim_config: ti.template(), +) -> ti.i32: + overflow = gs.ti_bool(False) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_b_ in ti.ndrange(envs_idx.shape[0]): + i_b = envs_idx[i_b_] + i_e = constraint_state.ti_n_equalities[i_b] + if i_e == static_rigid_sim_config.n_equalities_candidate: + overflow = True + else: + shared_pos = links_state.pos[link1_idx, i_b] + pos1 = gu.ti_inv_transform_by_trans_quat( + shared_pos, links_state.pos[link1_idx, i_b], links_state.quat[link1_idx, i_b] + ) + pos2 = gu.ti_inv_transform_by_trans_quat( + shared_pos, links_state.pos[link2_idx, i_b], links_state.quat[link2_idx, i_b] + ) + + equalities_info.eq_type[i_e, i_b] = gs.ti_int(gs.EQUALITY_TYPE.WELD) + equalities_info.eq_obj1id[i_e, i_b] = link1_idx + equalities_info.eq_obj2id[i_e, i_b] = link2_idx + + for i_3 in ti.static(range(3)): + equalities_info.eq_data[i_e, i_b][i_3 + 3] = pos1[i_3] + equalities_info.eq_data[i_e, i_b][i_3] = pos2[i_3] + + relpose = gu.ti_quat_mul(gu.ti_inv_quat(links_state.quat[link1_idx, i_b]), links_state.quat[link2_idx, i_b]) + + equalities_info.eq_data[i_e, i_b][6] = relpose[0] + equalities_info.eq_data[i_e, i_b][7] = relpose[1] + equalities_info.eq_data[i_e, i_b][8] = relpose[2] + equalities_info.eq_data[i_e, i_b][9] = relpose[3] + + equalities_info.eq_data[i_e, i_b][10] = 1.0 + equalities_info.sol_params[i_e, i_b] = ti.Vector( + [2 * static_rigid_sim_config.substep_dt, 1.0e00, 9.0e-01, 9.5e-01, 1.0e-03, 5.0e-01, 2.0e00] + ) + + constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] + 1 + return overflow + + +@ti.kernel +def kernel_delete_weld_constraint( + link1_idx: ti.i32, + link2_idx: ti.i32, + envs_idx: ti.types.ndarray(), + equalities_info: array_class.EqualitiesInfo, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_b_ in ti.ndrange(envs_idx.shape[0]): + i_b = envs_idx[i_b_] + for i_e in range(static_rigid_sim_config.n_equalities, constraint_state.ti_n_equalities[i_b]): + if ( + equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD + and equalities_info.eq_obj1id[i_e, i_b] == link1_idx + and equalities_info.eq_obj2id[i_e, i_b] == link2_idx + ): + if i_e < constraint_state.ti_n_equalities[i_b] - 1: + equalities_info.eq_type[i_e, i_b] = equalities_info.eq_type[ + constraint_state.ti_n_equalities[i_b] - 1, i_b + ] + constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1 + + +@ti.kernel +def kernel_get_equality_constraints( + is_padded: ti.template(), + iout: ti.types.ndarray(), + fout: ti.types.ndarray(), + constraint_state: array_class.ConstraintState, + equalities_info: array_class.EqualitiesInfo, + static_rigid_sim_config: ti.template(), +): + _B = constraint_state.ti_n_equalities.shape[0] + n_eqs_max = gs.ti_int(0) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + n_eqs = constraint_state.ti_n_equalities[i_b] + if n_eqs > n_eqs_max: + n_eqs_max = n_eqs + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + i_c_start = gs.ti_int(0) + i_e_start = gs.ti_int(0) + if ti.static(is_padded): + i_e_start = i_b * n_eqs_max + else: + for j_b in range(i_b): + i_e_start = i_e_start + constraint_state.ti_n_equalities[j_b] + + for i_e_ in range(constraint_state.ti_n_equalities[i_b]): + i_e = i_e_start + i_e_ + + iout[i_e, 0] = equalities_info.eq_type[i_e_, i_b] + iout[i_e, 1] = equalities_info.eq_obj1id[i_e_, i_b] + iout[i_e, 2] = equalities_info.eq_obj2id[i_e_, i_b] + + if equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.CONNECT: + for i_c_ in ti.static(range(3)): + i_c = i_c_start + i_c_ + fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b] + i_c_start = i_c_start + 3 + elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.WELD: + for i_c_ in ti.static(range(6)): + i_c = i_c_start + i_c_ + fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b] + i_c_start = i_c_start + 6 + elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.JOINT: + fout[i_e, 0] = constraint_state.efc_force[i_c_start, i_b] + i_c_start = i_c_start + 1 diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index afb4aa15c3..b8cd5ed187 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -2165,27 +2165,16 @@ def set_geoms_friction(self, friction, geoms_idx=None, *, unsafe=False): kernel_set_geoms_friction(friction, geoms_idx, self.geoms_info, self._static_rigid_sim_config) def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): - envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) - kernel_add_weld_constraint( - int(link1_idx), - int(link2_idx), - envs_idx, - self.equalities_info, - self.constraint_solver.constraint_state, - self.links_state, - self._static_rigid_sim_config, - ) + return self.constraint_solver.add_weld_constraint(link1_idx, link2_idx, envs_idx, unsafe=unsafe) def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): - envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) - kernel_delete_weld_constraint( - int(link1_idx), - int(link2_idx), - envs_idx, - self.equalities_info, - self.constraint_solver.constraint_state, - self._static_rigid_sim_config, - ) + return self.constraint_solver.delete_weld_constraint(link1_idx, link2_idx, envs_idx, unsafe=unsafe) + + def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True): + return self.constraint_solver.get_weld_constraints(as_tensor, to_torch) + + def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True): + return self.constraint_solver.get_equality_constraints(as_tensor, to_torch) def clear_external_force(self): kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config) @@ -2293,53 +2282,6 @@ def update_verts_for_geom(self, i_g): self.fixed_verts_state, ) - def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True): - n_eqs = tuple(self.constraint_solver.constraint_state.ti_n_equalities.to_numpy()) - n_envs = len(n_eqs) - n_max = max(n_eqs) if n_eqs else 0 - - if as_tensor: - out_size = n_envs * n_max - else: - cumsum = np.cumsum(n_eqs, dtype=np.int32) - out_size = int(cumsum[-1]) if n_envs else 0 - - if to_torch: - buf = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device) - else: - buf = np.full((out_size, 3), -1, dtype=np.int32) - - if n_max > 0: - kernel_collect_welds( - as_tensor, - buf, - self.constraint_solver.constraint_state, - self.equalities_info, - self._static_rigid_sim_config, - ) - - if n_envs > 0: - if as_tensor: - buf = buf.reshape((n_envs, n_max, 3)) - obj_a = buf[..., 1] - obj_b = buf[..., 2] - else: - if to_torch: - data_chunks = torch.split(buf, n_eqs) - else: - splits = list(np.cumsum(n_eqs, dtype=np.int32)[:-1]) - data_chunks = np.split(buf, splits) - obj_a, obj_b = tuple(zip(*((data[:, 1], data[:, 2]) for data in data_chunks))) - else: - if to_torch: - obj_a = torch.empty((0,), dtype=gs.tc_int, device=gs.device) - obj_b = torch.empty((0,), dtype=gs.tc_int, device=gs.device) - else: - obj_a = [] - obj_b = [] - - return {"obj_a": obj_a, "obj_b": obj_b} - # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ @@ -6688,115 +6630,3 @@ def kernel_set_geoms_friction( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_g_ in ti.ndrange(geoms_idx.shape[0]): geoms_info.friction[geoms_idx[i_g_]] = friction[i_g_] - - -@ti.kernel -def kernel_add_weld_constraint( - link1_idx: ti.i32, - link2_idx: ti.i32, - envs_idx: ti.types.ndarray(), - equalities_info: array_class.EqualitiesInfo, - constraint_state: array_class.ConstraintState, - links_state: array_class.LinksState, - static_rigid_sim_config: ti.template(), -): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b_ in ti.ndrange(envs_idx.shape[0]): - i_b = envs_idx[i_b_] - i_e = constraint_state.ti_n_equalities[i_b] - if i_e == static_rigid_sim_config.n_equalities_candidate: - print( - f"{colors.YELLOW}[Genesis] [00:00:00] [WARNING] Ignoring dynamically registered weld constraint " - f"to avoid exceeding max number of equality constraints ({static_rigid_sim_config.n_equalities_candidate}). " - f"Please increase the value of RigidSolver's option 'max_dynamic_constraints'.{formats.RESET}" - ) - else: - shared_pos = links_state.pos[link1_idx, i_b] - pos1 = gu.ti_inv_transform_by_trans_quat( - shared_pos, links_state.pos[link1_idx, i_b], links_state.quat[link1_idx, i_b] - ) - pos2 = gu.ti_inv_transform_by_trans_quat( - shared_pos, links_state.pos[link2_idx, i_b], links_state.quat[link2_idx, i_b] - ) - - equalities_info.eq_type[i_e, i_b] = gs.ti_int(gs.EQUALITY_TYPE.WELD) - equalities_info.eq_obj1id[i_e, i_b] = link1_idx - equalities_info.eq_obj2id[i_e, i_b] = link2_idx - - for i_3 in ti.static(range(3)): - equalities_info.eq_data[i_e, i_b][i_3 + 3] = pos1[i_3] - equalities_info.eq_data[i_e, i_b][i_3] = pos2[i_3] - - relpose = gu.ti_quat_mul(gu.ti_inv_quat(links_state.quat[link1_idx, i_b]), links_state.quat[link2_idx, i_b]) - - equalities_info.eq_data[i_e, i_b][6] = relpose[0] - equalities_info.eq_data[i_e, i_b][7] = relpose[1] - equalities_info.eq_data[i_e, i_b][8] = relpose[2] - equalities_info.eq_data[i_e, i_b][9] = relpose[3] - - equalities_info.eq_data[i_e, i_b][10] = 1.0 - equalities_info.sol_params[i_e, i_b] = ti.Vector( - [2 * static_rigid_sim_config.substep_dt, 1.0e00, 9.0e-01, 9.5e-01, 1.0e-03, 5.0e-01, 2.0e00] - ) - - constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] + 1 - - -@ti.kernel -def kernel_delete_weld_constraint( - link1_idx: ti.i32, - link2_idx: ti.i32, - envs_idx: ti.types.ndarray(), - equalities_info: array_class.EqualitiesInfo, - constraint_state: array_class.ConstraintState, - static_rigid_sim_config: ti.template(), -): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b_ in ti.ndrange(envs_idx.shape[0]): - i_b = envs_idx[i_b_] - for i_e in range(static_rigid_sim_config.n_equalities, constraint_state.ti_n_equalities[i_b]): - if ( - equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD - and equalities_info.eq_obj1id[i_e, i_b] == link1_idx - and equalities_info.eq_obj2id[i_e, i_b] == link2_idx - ): - if i_e < constraint_state.ti_n_equalities[i_b] - 1: - equalities_info.eq_type[i_e, i_b] = equalities_info.eq_type[ - constraint_state.ti_n_equalities[i_b] - 1, i_b - ] - constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1 - - -@ti.kernel -def kernel_collect_welds( - is_padded: ti.template(), - buf: ti.types.ndarray(), - constraint_state: array_class.ConstraintState, - equalities_info: array_class.EqualitiesInfo, - static_rigid_sim_config: ti.template(), -): - B = constraint_state.ti_n_equalities.shape[0] - max_eq = 0 - for e in range(B): - n = constraint_state.ti_n_equalities[e] - if n > max_eq: - max_eq = n - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for e in range(B): - base = 0 - if ti.static(is_padded): - base = e * max_eq - else: - for pe in range(e): - base += constraint_state.ti_n_equalities[pe] - - out = 0 - n = constraint_state.ti_n_equalities[e] - for i in range(n): - if equalities_info.eq_type[i, e] == gs.EQUALITY_TYPE.WELD and out < max_eq: - idx = base + out - buf[idx, 0] = e - buf[idx, 1] = equalities_info.eq_obj1id[i, e] - buf[idx, 2] = equalities_info.eq_obj2id[i, e] - out += 1 diff --git a/pyproject.toml b/pyproject.toml index 04e4dc5f24..89e10594a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dev = [ "syrupy", "huggingface_hub", "wandb", + "ipython", ] docs = [ # Note that currently sphinx 7 does not work, so we must use v6.2.1. Once fixed we can use a later version. diff --git a/tests/test_rigid_physics.py b/tests/test_rigid_physics.py index e189687e1e..9260b8d63e 100644 --- a/tests/test_rigid_physics.py +++ b/tests/test_rigid_physics.py @@ -2395,62 +2395,32 @@ def test_drone_advanced(show_viewer): @pytest.mark.required @pytest.mark.parametrize("backend", [gs.cpu]) -def test_get_weld_constraints_api(show_viewer, tol): +def test_get_constraints_api(show_viewer, tol): scene = gs.Scene( sim_options=gs.options.SimOptions(gravity=(0.0, 0.0, 0.0)), show_viewer=show_viewer, ) - cube1 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.0, 0.0, 0.05))) - cube2 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.2, 0.0, 0.05))) - scene.build(n_envs=1) - - link_a = torch.tensor([cube1.base_link.idx], dtype=gs.tc_int, device=gs.device) - link_b = torch.tensor([cube2.base_link.idx], dtype=gs.tc_int, device=gs.device) - - scene.sim.rigid_solver.add_weld_constraint(link_a, link_b) - scene.step() - - # Test all 4 combinations for solver-level API - combinations = [ - (True, True), # as_tensor=True, to_torch=True - (True, False), # as_tensor=True, to_torch=False - (False, True), # as_tensor=False, to_torch=True - (False, False), # as_tensor=False, to_torch=False - ] + robot = scene.add_entity( + gs.morphs.MJCF( + file="xml/franka_emika_panda/panda.xml", + ), + ) + cube = scene.add_entity(gs.morphs.Box(size=(0.05, 0.05, 0.05), pos=(0.2, 0.0, 0.05))) + scene.build(n_envs=2) - for as_tensor, to_torch in combinations: - welds = scene.sim.rigid_solver.get_weld_constraints(as_tensor=as_tensor, to_torch=to_torch) + link_a, link_b = robot.base_link.idx, cube.base_link.idx + scene.sim.rigid_solver.add_weld_constraint(link_a, link_b, envs_idx=[1]) + with np.testing.assert_raises(AssertionError): + scene.sim.rigid_solver.add_weld_constraint(link_a, link_b, envs_idx=[1]) + for as_tensor, to_torch in ((True, True), (True, False), (False, True), (False, False)): + weld_const_info = scene.sim.rigid_solver.get_weld_constraints(as_tensor, to_torch) + link_a_, link_b_ = weld_const_info["link_a"], weld_const_info["link_b"] if as_tensor: - # Tensor format: welds["obj_a"][0, 0] - assert_allclose( - [welds["obj_a"][0, 0], welds["obj_b"][0, 0]], - [link_a.item(), link_b.item()], - tol=tol, - ) + assert_allclose((link_a_[0], link_b_[0]), ((-1,), (-1,)), tol=0) else: - # Non-tensor format: welds["obj_a"][0][0] - assert_allclose( - [welds["obj_a"][0][0], welds["obj_b"][0][0]], - [link_a.item(), link_b.item()], - tol=tol, - ) - - # Test entity-level API - welds_single = cube1.get_weld_constraints() - assert_allclose( - [welds_single["obj_a"][0], welds_single["obj_b"][0]], - [link_a.item(), link_b.item()], - tol=tol, - ) - - # Test entity-level API with with_entity parameter - welds_with_entity = cube1.get_weld_constraints(with_entity=cube2) - assert_allclose( - [welds_with_entity["obj_a"][0], welds_with_entity["obj_b"][0]], - [link_a.item(), link_b.item()], - tol=tol, - ) + assert_allclose((link_a_[0], link_b_[0]), ((), ()), tol=0) + assert_allclose((link_a_[1], link_b_[1]), ((link_a,), (link_b,)), tol=0) @pytest.mark.parametrize(