Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c176525
get_weld_contraints api and unit test
LeonLiu4 Jul 7, 2025
3cbc525
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 7, 2025
0cdd727
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 8, 2025
8d35d0e
dictionary return in get_weld
LeonLiu4 Jul 8, 2025
eb5eb84
mend
LeonLiu4 Jul 8, 2025
e463433
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 8, 2025
88fc6b9
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 9, 2025
17c3f82
taichi kernel and changed API
LeonLiu4 Jul 9, 2025
4837d9b
entity get_weld and unit test
LeonLiu4 Jul 9, 2025
e8d4a78
Update test_rigid_physics.py
LeonLiu4 Jul 9, 2025
4d6825e
velocity check in unit test
LeonLiu4 Jul 10, 2025
64ac3eb
mend
LeonLiu4 Jul 10, 2025
dc1124a
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 10, 2025
0a5efde
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 10, 2025
7ab12c5
mend
LeonLiu4 Jul 11, 2025
d53a5a4
Merge branch 'main' into feature/get-weld
YilingQiao Jul 11, 2025
ab2abfb
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 12, 2025
4ade718
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 14, 2025
97feebe
no flattening array
LeonLiu4 Jul 14, 2025
bf506ec
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 15, 2025
b093928
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 15, 2025
69922ce
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 16, 2025
543dceb
no 3D tensor and fixed formatting in unit tests
LeonLiu4 Jul 16, 2025
346a145
changed entity level get_weld
LeonLiu4 Jul 16, 2025
b232d5f
Merge branch 'main' into feature/get-weld
LeonLiu4 Jul 23, 2025
0ea2e96
mend
LeonLiu4 Jul 23, 2025
8bd9fce
Merge branch 'main' into feature/get-weld
LeonLiu4 Aug 4, 2025
f59852a
updated get_weld_constraint api after merging main
LeonLiu4 Aug 4, 2025
f742a9a
api tester, removed torch/tensor options and optional key, cleaned up…
LeonLiu4 Aug 8, 2025
4bb2c58
Merge branch 'main' into feature/get-weld
LeonLiu4 Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,32 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def get_weld_constraints(self, with_entity=None, exclude_self_contact=False):
welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True)
obj_a = welds["obj_a"]
obj_b = welds["obj_b"]

# Create mask for filtering welds involving this entity
mask = (obj_a == self.idx) | (obj_b == self.idx)

# Additional filtering if with_entity is specified
if with_entity is not None:
if self.idx == with_entity.idx:
if exclude_self_contact:
gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.")
# For self-contact, keep only self-welds
mask = mask & ((obj_a == self.idx) & (obj_b == self.idx))
else:
# For cross-entity, keep welds between this entity and with_entity
mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx))

# Apply filtering
for k in ("obj_a", "obj_b"):
welds[k] = welds[k][mask]

return welds

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
"""
Expand Down
82 changes: 82 additions & 0 deletions genesis/engine/solvers/rigid/rigid_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2293,6 +2293,53 @@ def update_verts_for_geom(self, i_g):
self.fixed_verts_state,
)

def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True):
n_eqs = tuple(self.constraint_solver.constraint_state.ti_n_equalities.to_numpy())
n_envs = len(n_eqs)
n_max = max(n_eqs) if n_eqs else 0

if as_tensor:
out_size = n_envs * n_max
else:
cumsum = np.cumsum(n_eqs, dtype=np.int32)
out_size = int(cumsum[-1]) if n_envs else 0

if to_torch:
buf = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device)
else:
buf = np.full((out_size, 3), -1, dtype=np.int32)

if n_max > 0:
kernel_collect_welds(
as_tensor,
buf,
self.constraint_solver.constraint_state,
self.equalities_info,
self._static_rigid_sim_config,
)

if n_envs > 0:
if as_tensor:
buf = buf.reshape((n_envs, n_max, 3))
obj_a = buf[..., 1]
obj_b = buf[..., 2]
else:
if to_torch:
data_chunks = torch.split(buf, n_eqs)
else:
splits = list(np.cumsum(n_eqs, dtype=np.int32)[:-1])
data_chunks = np.split(buf, splits)
obj_a, obj_b = tuple(zip(*((data[:, 1], data[:, 2]) for data in data_chunks)))
else:
if to_torch:
obj_a = torch.empty((0,), dtype=gs.tc_int, device=gs.device)
obj_b = torch.empty((0,), dtype=gs.tc_int, device=gs.device)
else:
obj_a = []
obj_b = []

return {"obj_a": obj_a, "obj_b": obj_b}

# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -6718,3 +6765,38 @@ def kernel_delete_weld_constraint(
constraint_state.ti_n_equalities[i_b] - 1, i_b
]
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1


@ti.kernel
def kernel_collect_welds(
is_padded: ti.template(),
buf: ti.types.ndarray(),
constraint_state: array_class.ConstraintState,
equalities_info: array_class.EqualitiesInfo,
static_rigid_sim_config: ti.template(),
):
B = constraint_state.ti_n_equalities.shape[0]
max_eq = 0
for e in range(B):
n = constraint_state.ti_n_equalities[e]
if n > max_eq:
max_eq = n

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for e in range(B):
base = 0
if ti.static(is_padded):
base = e * max_eq
else:
for pe in range(e):
base += constraint_state.ti_n_equalities[pe]

out = 0
n = constraint_state.ti_n_equalities[e]
for i in range(n):
if equalities_info.eq_type[i, e] == gs.EQUALITY_TYPE.WELD and out < max_eq:
idx = base + out
buf[idx, 0] = e
buf[idx, 1] = equalities_info.eq_obj1id[i, e]
buf[idx, 2] = equalities_info.eq_obj2id[i, e]
out += 1
60 changes: 60 additions & 0 deletions tests/test_rigid_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,66 @@ def test_drone_advanced(show_viewer):
assert abs(quat_1[2] - quat_2[2]) < tol


@pytest.mark.required
@pytest.mark.parametrize("backend", [gs.cpu])
def test_get_weld_constraints_api(show_viewer, tol):
scene = gs.Scene(
sim_options=gs.options.SimOptions(gravity=(0.0, 0.0, 0.0)),
show_viewer=show_viewer,
)
cube1 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.0, 0.0, 0.05)))
cube2 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.2, 0.0, 0.05)))
scene.build(n_envs=1)

link_a = torch.tensor([cube1.base_link.idx], dtype=gs.tc_int, device=gs.device)
link_b = torch.tensor([cube2.base_link.idx], dtype=gs.tc_int, device=gs.device)

scene.sim.rigid_solver.add_weld_constraint(link_a, link_b)
scene.step()

# Test all 4 combinations for solver-level API
combinations = [
(True, True), # as_tensor=True, to_torch=True
(True, False), # as_tensor=True, to_torch=False
(False, True), # as_tensor=False, to_torch=True
(False, False), # as_tensor=False, to_torch=False
]

for as_tensor, to_torch in combinations:
welds = scene.sim.rigid_solver.get_weld_constraints(as_tensor=as_tensor, to_torch=to_torch)

if as_tensor:
# Tensor format: welds["obj_a"][0, 0]
assert_allclose(
[welds["obj_a"][0, 0], welds["obj_b"][0, 0]],
[link_a.item(), link_b.item()],
tol=tol,
)
else:
# Non-tensor format: welds["obj_a"][0][0]
assert_allclose(
[welds["obj_a"][0][0], welds["obj_b"][0][0]],
[link_a.item(), link_b.item()],
tol=tol,
)

# Test entity-level API
welds_single = cube1.get_weld_constraints()
assert_allclose(
[welds_single["obj_a"][0], welds_single["obj_b"][0]],
[link_a.item(), link_b.item()],
tol=tol,
)

# Test entity-level API with with_entity parameter
welds_with_entity = cube1.get_weld_constraints(with_entity=cube2)
assert_allclose(
[welds_with_entity["obj_a"][0], welds_with_entity["obj_b"][0]],
[link_a.item(), link_b.item()],
tol=tol,
)


@pytest.mark.parametrize(
"n_envs, batched, backend",
[
Expand Down