diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 9379b4a2af..5c3ad58609 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1837,6 +1837,32 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns if zero_velocity: self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + @gs.assert_built + def get_weld_constraints(self, with_entity=None, exclude_self_contact=False): + welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True) + obj_a = welds["obj_a"] + obj_b = welds["obj_b"] + + # Create mask for filtering welds involving this entity + mask = (obj_a == self.idx) | (obj_b == self.idx) + + # Additional filtering if with_entity is specified + if with_entity is not None: + if self.idx == with_entity.idx: + if exclude_self_contact: + gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.") + # For self-contact, keep only self-welds + mask = mask & ((obj_a == self.idx) & (obj_b == self.idx)) + else: + # For cross-entity, keep welds between this entity and with_entity + mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx)) + + # Apply filtering + for k in ("obj_a", "obj_b"): + welds[k] = welds[k][mask] + + return welds + @gs.assert_built def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): """ diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 7adb8819e2..afb4aa15c3 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -2293,6 +2293,53 @@ def update_verts_for_geom(self, i_g): self.fixed_verts_state, ) + def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True): + n_eqs = tuple(self.constraint_solver.constraint_state.ti_n_equalities.to_numpy()) + n_envs = len(n_eqs) + n_max = max(n_eqs) if n_eqs else 0 + + if as_tensor: + out_size = n_envs * n_max + else: + cumsum = np.cumsum(n_eqs, dtype=np.int32) + out_size = int(cumsum[-1]) if n_envs else 0 + + if to_torch: + buf = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device) + else: + buf = np.full((out_size, 3), -1, dtype=np.int32) + + if n_max > 0: + kernel_collect_welds( + as_tensor, + buf, + self.constraint_solver.constraint_state, + self.equalities_info, + self._static_rigid_sim_config, + ) + + if n_envs > 0: + if as_tensor: + buf = buf.reshape((n_envs, n_max, 3)) + obj_a = buf[..., 1] + obj_b = buf[..., 2] + else: + if to_torch: + data_chunks = torch.split(buf, n_eqs) + else: + splits = list(np.cumsum(n_eqs, dtype=np.int32)[:-1]) + data_chunks = np.split(buf, splits) + obj_a, obj_b = tuple(zip(*((data[:, 1], data[:, 2]) for data in data_chunks))) + else: + if to_torch: + obj_a = torch.empty((0,), dtype=gs.tc_int, device=gs.device) + obj_b = torch.empty((0,), dtype=gs.tc_int, device=gs.device) + else: + obj_a = [] + obj_b = [] + + return {"obj_a": obj_a, "obj_b": obj_b} + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ @@ -6718,3 +6765,38 @@ def kernel_delete_weld_constraint( constraint_state.ti_n_equalities[i_b] - 1, i_b ] constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1 + + +@ti.kernel +def kernel_collect_welds( + is_padded: ti.template(), + buf: ti.types.ndarray(), + constraint_state: array_class.ConstraintState, + equalities_info: array_class.EqualitiesInfo, + static_rigid_sim_config: ti.template(), +): + B = constraint_state.ti_n_equalities.shape[0] + max_eq = 0 + for e in range(B): + n = constraint_state.ti_n_equalities[e] + if n > max_eq: + max_eq = n + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for e in range(B): + base = 0 + if ti.static(is_padded): + base = e * max_eq + else: + for pe in range(e): + base += constraint_state.ti_n_equalities[pe] + + out = 0 + n = constraint_state.ti_n_equalities[e] + for i in range(n): + if equalities_info.eq_type[i, e] == gs.EQUALITY_TYPE.WELD and out < max_eq: + idx = base + out + buf[idx, 0] = e + buf[idx, 1] = equalities_info.eq_obj1id[i, e] + buf[idx, 2] = equalities_info.eq_obj2id[i, e] + out += 1 diff --git a/tests/test_rigid_physics.py b/tests/test_rigid_physics.py index 3f7db5352c..e189687e1e 100644 --- a/tests/test_rigid_physics.py +++ b/tests/test_rigid_physics.py @@ -2393,6 +2393,66 @@ def test_drone_advanced(show_viewer): assert abs(quat_1[2] - quat_2[2]) < tol +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu]) +def test_get_weld_constraints_api(show_viewer, tol): + scene = gs.Scene( + sim_options=gs.options.SimOptions(gravity=(0.0, 0.0, 0.0)), + show_viewer=show_viewer, + ) + cube1 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.0, 0.0, 0.05))) + cube2 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.2, 0.0, 0.05))) + scene.build(n_envs=1) + + link_a = torch.tensor([cube1.base_link.idx], dtype=gs.tc_int, device=gs.device) + link_b = torch.tensor([cube2.base_link.idx], dtype=gs.tc_int, device=gs.device) + + scene.sim.rigid_solver.add_weld_constraint(link_a, link_b) + scene.step() + + # Test all 4 combinations for solver-level API + combinations = [ + (True, True), # as_tensor=True, to_torch=True + (True, False), # as_tensor=True, to_torch=False + (False, True), # as_tensor=False, to_torch=True + (False, False), # as_tensor=False, to_torch=False + ] + + for as_tensor, to_torch in combinations: + welds = scene.sim.rigid_solver.get_weld_constraints(as_tensor=as_tensor, to_torch=to_torch) + + if as_tensor: + # Tensor format: welds["obj_a"][0, 0] + assert_allclose( + [welds["obj_a"][0, 0], welds["obj_b"][0, 0]], + [link_a.item(), link_b.item()], + tol=tol, + ) + else: + # Non-tensor format: welds["obj_a"][0][0] + assert_allclose( + [welds["obj_a"][0][0], welds["obj_b"][0][0]], + [link_a.item(), link_b.item()], + tol=tol, + ) + + # Test entity-level API + welds_single = cube1.get_weld_constraints() + assert_allclose( + [welds_single["obj_a"][0], welds_single["obj_b"][0]], + [link_a.item(), link_b.item()], + tol=tol, + ) + + # Test entity-level API with with_entity parameter + welds_with_entity = cube1.get_weld_constraints(with_entity=cube2) + assert_allclose( + [welds_with_entity["obj_a"][0], welds_with_entity["obj_b"][0]], + [link_a.item(), link_b.item()], + tol=tol, + ) + + @pytest.mark.parametrize( "n_envs, batched, backend", [