Skip to content

Commit b091842

Browse files
authored
[FEATURE] Add 'get_weld_constraints' API. (#1370)
1 parent f823d6a commit b091842

File tree

3 files changed

+168
-0
lines changed

3 files changed

+168
-0
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,6 +1837,32 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
18371837
if zero_velocity:
18381838
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
18391839

1840+
@gs.assert_built
1841+
def get_weld_constraints(self, with_entity=None, exclude_self_contact=False):
1842+
welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True)
1843+
obj_a = welds["obj_a"]
1844+
obj_b = welds["obj_b"]
1845+
1846+
# Create mask for filtering welds involving this entity
1847+
mask = (obj_a == self.idx) | (obj_b == self.idx)
1848+
1849+
# Additional filtering if with_entity is specified
1850+
if with_entity is not None:
1851+
if self.idx == with_entity.idx:
1852+
if exclude_self_contact:
1853+
gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.")
1854+
# For self-contact, keep only self-welds
1855+
mask = mask & ((obj_a == self.idx) & (obj_b == self.idx))
1856+
else:
1857+
# For cross-entity, keep welds between this entity and with_entity
1858+
mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx))
1859+
1860+
# Apply filtering
1861+
for k in ("obj_a", "obj_b"):
1862+
welds[k] = welds[k][mask]
1863+
1864+
return welds
1865+
18401866
@gs.assert_built
18411867
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
18421868
"""

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,6 +2293,53 @@ def update_verts_for_geom(self, i_g):
22932293
self.fixed_verts_state,
22942294
)
22952295

2296+
def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True):
2297+
n_eqs = tuple(self.constraint_solver.constraint_state.ti_n_equalities.to_numpy())
2298+
n_envs = len(n_eqs)
2299+
n_max = max(n_eqs) if n_eqs else 0
2300+
2301+
if as_tensor:
2302+
out_size = n_envs * n_max
2303+
else:
2304+
cumsum = np.cumsum(n_eqs, dtype=np.int32)
2305+
out_size = int(cumsum[-1]) if n_envs else 0
2306+
2307+
if to_torch:
2308+
buf = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device)
2309+
else:
2310+
buf = np.full((out_size, 3), -1, dtype=np.int32)
2311+
2312+
if n_max > 0:
2313+
kernel_collect_welds(
2314+
as_tensor,
2315+
buf,
2316+
self.constraint_solver.constraint_state,
2317+
self.equalities_info,
2318+
self._static_rigid_sim_config,
2319+
)
2320+
2321+
if n_envs > 0:
2322+
if as_tensor:
2323+
buf = buf.reshape((n_envs, n_max, 3))
2324+
obj_a = buf[..., 1]
2325+
obj_b = buf[..., 2]
2326+
else:
2327+
if to_torch:
2328+
data_chunks = torch.split(buf, n_eqs)
2329+
else:
2330+
splits = list(np.cumsum(n_eqs, dtype=np.int32)[:-1])
2331+
data_chunks = np.split(buf, splits)
2332+
obj_a, obj_b = tuple(zip(*((data[:, 1], data[:, 2]) for data in data_chunks)))
2333+
else:
2334+
if to_torch:
2335+
obj_a = torch.empty((0,), dtype=gs.tc_int, device=gs.device)
2336+
obj_b = torch.empty((0,), dtype=gs.tc_int, device=gs.device)
2337+
else:
2338+
obj_a = []
2339+
obj_b = []
2340+
2341+
return {"obj_a": obj_a, "obj_b": obj_b}
2342+
22962343
# ------------------------------------------------------------------------------------
22972344
# ----------------------------------- properties -------------------------------------
22982345
# ------------------------------------------------------------------------------------
@@ -6718,3 +6765,38 @@ def kernel_delete_weld_constraint(
67186765
constraint_state.ti_n_equalities[i_b] - 1, i_b
67196766
]
67206767
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1
6768+
6769+
6770+
@ti.kernel
6771+
def kernel_collect_welds(
6772+
is_padded: ti.template(),
6773+
buf: ti.types.ndarray(),
6774+
constraint_state: array_class.ConstraintState,
6775+
equalities_info: array_class.EqualitiesInfo,
6776+
static_rigid_sim_config: ti.template(),
6777+
):
6778+
B = constraint_state.ti_n_equalities.shape[0]
6779+
max_eq = 0
6780+
for e in range(B):
6781+
n = constraint_state.ti_n_equalities[e]
6782+
if n > max_eq:
6783+
max_eq = n
6784+
6785+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
6786+
for e in range(B):
6787+
base = 0
6788+
if ti.static(is_padded):
6789+
base = e * max_eq
6790+
else:
6791+
for pe in range(e):
6792+
base += constraint_state.ti_n_equalities[pe]
6793+
6794+
out = 0
6795+
n = constraint_state.ti_n_equalities[e]
6796+
for i in range(n):
6797+
if equalities_info.eq_type[i, e] == gs.EQUALITY_TYPE.WELD and out < max_eq:
6798+
idx = base + out
6799+
buf[idx, 0] = e
6800+
buf[idx, 1] = equalities_info.eq_obj1id[i, e]
6801+
buf[idx, 2] = equalities_info.eq_obj2id[i, e]
6802+
out += 1

tests/test_rigid_physics.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,66 @@ def test_drone_advanced(show_viewer):
23932393
assert abs(quat_1[2] - quat_2[2]) < tol
23942394

23952395

2396+
@pytest.mark.required
2397+
@pytest.mark.parametrize("backend", [gs.cpu])
2398+
def test_get_weld_constraints_api(show_viewer, tol):
2399+
scene = gs.Scene(
2400+
sim_options=gs.options.SimOptions(gravity=(0.0, 0.0, 0.0)),
2401+
show_viewer=show_viewer,
2402+
)
2403+
cube1 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.0, 0.0, 0.05)))
2404+
cube2 = scene.add_entity(gs.morphs.Box(size=(0.05,) * 3, pos=(0.2, 0.0, 0.05)))
2405+
scene.build(n_envs=1)
2406+
2407+
link_a = torch.tensor([cube1.base_link.idx], dtype=gs.tc_int, device=gs.device)
2408+
link_b = torch.tensor([cube2.base_link.idx], dtype=gs.tc_int, device=gs.device)
2409+
2410+
scene.sim.rigid_solver.add_weld_constraint(link_a, link_b)
2411+
scene.step()
2412+
2413+
# Test all 4 combinations for solver-level API
2414+
combinations = [
2415+
(True, True), # as_tensor=True, to_torch=True
2416+
(True, False), # as_tensor=True, to_torch=False
2417+
(False, True), # as_tensor=False, to_torch=True
2418+
(False, False), # as_tensor=False, to_torch=False
2419+
]
2420+
2421+
for as_tensor, to_torch in combinations:
2422+
welds = scene.sim.rigid_solver.get_weld_constraints(as_tensor=as_tensor, to_torch=to_torch)
2423+
2424+
if as_tensor:
2425+
# Tensor format: welds["obj_a"][0, 0]
2426+
assert_allclose(
2427+
[welds["obj_a"][0, 0], welds["obj_b"][0, 0]],
2428+
[link_a.item(), link_b.item()],
2429+
tol=tol,
2430+
)
2431+
else:
2432+
# Non-tensor format: welds["obj_a"][0][0]
2433+
assert_allclose(
2434+
[welds["obj_a"][0][0], welds["obj_b"][0][0]],
2435+
[link_a.item(), link_b.item()],
2436+
tol=tol,
2437+
)
2438+
2439+
# Test entity-level API
2440+
welds_single = cube1.get_weld_constraints()
2441+
assert_allclose(
2442+
[welds_single["obj_a"][0], welds_single["obj_b"][0]],
2443+
[link_a.item(), link_b.item()],
2444+
tol=tol,
2445+
)
2446+
2447+
# Test entity-level API with with_entity parameter
2448+
welds_with_entity = cube1.get_weld_constraints(with_entity=cube2)
2449+
assert_allclose(
2450+
[welds_with_entity["obj_a"][0], welds_with_entity["obj_b"][0]],
2451+
[link_a.item(), link_b.item()],
2452+
tol=tol,
2453+
)
2454+
2455+
23962456
@pytest.mark.parametrize(
23972457
"n_envs, batched, backend",
23982458
[

0 commit comments

Comments
 (0)