Skip to content

Commit 91cac2a

Browse files
authored
[MISC] Refactor weld constraint API. (#1536)
1 parent fe3bd62 commit 91cac2a

File tree

7 files changed

+300
-257
lines changed

7 files changed

+300
-257
lines changed

genesis/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def init(
5656
if _initialized:
5757
raise_exception("Genesis already initialized.")
5858

59+
# Make sure evertything is properly destroyed, just in case initialization failed previously
60+
destroy()
61+
5962
# genesis._theme
6063
global _theme
6164
is_theme_valid = theme in ("dark", "light", "dumb")

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,32 +1837,6 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
18371837
if zero_velocity:
18381838
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
18391839

1840-
@gs.assert_built
1841-
def get_weld_constraints(self, with_entity=None, exclude_self_contact=False):
1842-
welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True)
1843-
obj_a = welds["obj_a"]
1844-
obj_b = welds["obj_b"]
1845-
1846-
# Create mask for filtering welds involving this entity
1847-
mask = (obj_a == self.idx) | (obj_b == self.idx)
1848-
1849-
# Additional filtering if with_entity is specified
1850-
if with_entity is not None:
1851-
if self.idx == with_entity.idx:
1852-
if exclude_self_contact:
1853-
gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.")
1854-
# For self-contact, keep only self-welds
1855-
mask = mask & ((obj_a == self.idx) & (obj_b == self.idx))
1856-
else:
1857-
# For cross-entity, keep welds between this entity and with_entity
1858-
mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx))
1859-
1860-
# Apply filtering
1861-
for k in ("obj_a", "obj_b"):
1862-
welds[k] = welds[k][mask]
1863-
1864-
return welds
1865-
18661840
@gs.assert_built
18671841
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
18681842
"""

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,8 @@ def collider_kernel_get_contacts(
545545
):
546546
_B = collider_state.active_buffer.shape[1]
547547
n_contacts_max = gs.ti_int(0)
548+
549+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
548550
for i_b in range(_B):
549551
n_contacts = collider_state.n_contacts[i_b]
550552
if n_contacts > n_contacts_max:

genesis/engine/solvers/rigid/constraint_solver_decomp.py

Lines changed: 268 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import TYPE_CHECKING
2+
23
import numpy as np
3-
import taichi as ti
44
import numpy.typing as npt
5+
import taichi as ti
6+
import torch
57

68
import genesis as gs
79
import genesis.utils.geom as gu
@@ -37,6 +39,8 @@ def __init__(self, rigid_solver: "RigidSolver"):
3739

3840
self.constraint_state = array_class.get_constraint_state(self, self._solver)
3941

42+
self._eq_const_info_cache = {}
43+
4044
# self.ti_n_equalities = ti.field(gs.ti_int, shape=self._solver._batch_shape())
4145
# self.ti_n_equalities.from_numpy(np.full((self._solver._B,), self._solver.n_equalities, dtype=gs.np_int))
4246

@@ -157,11 +161,13 @@ def __init__(self, rigid_solver: "RigidSolver"):
157161
self.reset()
158162

159163
def clear(self, envs_idx: npt.NDArray[np.int32] | None = None):
164+
self._eq_const_info_cache.clear()
160165
if envs_idx is None:
161166
envs_idx = self._solver._scene._envs_idx
162167
constraint_solver_kernel_clear(envs_idx, self._solver._static_rigid_sim_config, self.constraint_state)
163168

164169
def reset(self, envs_idx=None):
170+
self._eq_const_info_cache.clear()
165171
if envs_idx is None:
166172
envs_idx = self._solver._scene._envs_idx
167173
constraint_solver_kernel_reset(
@@ -253,6 +259,137 @@ def resolve(self):
253259
)
254260
# timer.stamp("compute force")
255261

262+
def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True):
263+
# Early return if already pre-computed
264+
eq_const_info = self._eq_const_info_cache.get((as_tensor, to_torch))
265+
if eq_const_info is not None:
266+
return eq_const_info.copy()
267+
268+
n_eqs = tuple(self.constraint_state.ti_n_equalities.to_numpy())
269+
n_envs = len(n_eqs)
270+
n_eqs_max = max(n_eqs)
271+
272+
if as_tensor:
273+
out_size = n_envs * n_eqs_max
274+
else:
275+
*n_eqs_starts, out_size = np.cumsum(n_eqs)
276+
277+
if to_torch:
278+
iout = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device)
279+
fout = torch.zeros((out_size, 6), dtype=gs.tc_float, device=gs.device)
280+
else:
281+
iout = np.full((out_size, 3), -1, dtype=gs.np_int)
282+
fout = np.zeros((out_size, 6), dtype=gs.np_float)
283+
284+
if n_eqs_max > 0:
285+
kernel_get_equality_constraints(
286+
as_tensor,
287+
iout,
288+
fout,
289+
self.constraint_state,
290+
self._solver.equalities_info,
291+
self._solver._static_rigid_sim_config,
292+
)
293+
294+
if as_tensor:
295+
iout = iout.reshape((n_envs, n_eqs_max, 3))
296+
eq_type, obj_a, obj_b = (iout[..., i] for i in range(3))
297+
efc_force = fout.reshape((n_envs, n_eqs_max, 6))
298+
values = (eq_type, obj_a, obj_b, fout)
299+
else:
300+
if to_torch:
301+
iout_chunks = torch.split(iout, n_eqs)
302+
efc_force = torch.split(fout, n_eqs)
303+
else:
304+
iout_chunks = np.split(iout, n_eqs_starts)
305+
efc_force = np.split(fout, n_eqs_starts)
306+
eq_type, obj_a, obj_b = tuple(zip(*([data[..., i] for i in range(3)] for data in iout_chunks)))
307+
308+
values = (eq_type, obj_a, obj_b, efc_force)
309+
eq_const_info = dict(zip(("type", "obj_a", "obj_b", "force"), values))
310+
311+
# Cache equality constraint information before returning
312+
self._eq_const_info_cache[(as_tensor, to_torch)] = eq_const_info
313+
314+
return eq_const_info.copy()
315+
316+
def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True):
317+
eq_const_info = self.get_equality_constraints(as_tensor, to_torch)
318+
eq_type = eq_const_info.pop("type")
319+
320+
weld_const_info = {}
321+
if as_tensor:
322+
weld_mask = eq_type == gs.EQUALITY_TYPE.WELD
323+
n_envs = len(weld_mask)
324+
n_welds = weld_mask.sum(dim=-1) if to_torch else np.sum(weld_mask, axis=-1)
325+
n_welds_max = max(n_welds)
326+
for key, value in eq_const_info.items():
327+
shape = (n_envs, n_welds_max, *value.shape[2:])
328+
if to_torch:
329+
if torch.is_floating_point(value):
330+
weld_const_info[key] = torch.zeros(shape, dtype=value.dtype, device=value.device)
331+
else:
332+
weld_const_info[key] = torch.full(shape, -1, dtype=value.dtype, device=value.device)
333+
else:
334+
if np.issubdtype(value.dtype, np.floating):
335+
weld_const_info[key] = np.zeros(shape, dtype=value.dtype)
336+
else:
337+
weld_const_info[key] = np.full(shape, -1, dtype=value.dtype)
338+
for i_b, (n_welds_i, weld_mask_i) in enumerate(zip(n_welds, weld_mask)):
339+
for eq_value, weld_value in zip(eq_const_info.values(), weld_const_info.values()):
340+
weld_value[i_b, :n_welds_i] = eq_value[i_b, weld_mask_i]
341+
else:
342+
weld_mask_chunks = tuple(eq_type_i == gs.EQUALITY_TYPE.WELD for eq_type_i in eq_type)
343+
for key, value in eq_const_info.items():
344+
weld_const_info[key] = tuple(data[weld_mask] for weld_mask, data in zip(weld_mask_chunks, value))
345+
346+
weld_const_info["link_a"] = weld_const_info.pop("obj_a")
347+
weld_const_info["link_b"] = weld_const_info.pop("obj_b")
348+
349+
return weld_const_info
350+
351+
def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False):
352+
envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
353+
link1_idx, link2_idx = int(link1_idx), int(link2_idx)
354+
355+
if not unsafe:
356+
assert link1_idx >= 0 and link2_idx >= 0
357+
weld_const_info = self.get_weld_constraints(as_tensor=True, to_torch=True)
358+
link_a = weld_const_info["link_a"]
359+
link_b = weld_const_info["link_b"]
360+
assert not (
361+
((link_a == link1_idx) | (link_b == link1_idx)) & ((link_a == link2_idx) | (link_b == link2_idx))
362+
).any()
363+
364+
self._eq_const_info_cache.clear()
365+
overflow = kernel_add_weld_constraint(
366+
link1_idx,
367+
link2_idx,
368+
envs_idx,
369+
self._solver.equalities_info,
370+
self.constraint_state,
371+
self._solver.links_state,
372+
self._solver._static_rigid_sim_config,
373+
)
374+
if overflow:
375+
gs.logger.warning(
376+
"Ignoring dynamically registered weld constraint to avoid exceeding max number of equality constraints"
377+
f"({self._static_rigid_sim_config.n_equalities_candidate}). Please increase the value of "
378+
"RigidSolver's option 'max_dynamic_constraints'."
379+
)
380+
381+
def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False):
382+
envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
383+
self._eq_const_info_cache.clear()
384+
kernel_delete_weld_constraint(
385+
int(link1_idx),
386+
int(link2_idx),
387+
envs_idx,
388+
self._solver.equalities_info,
389+
self.constraint_state,
390+
self._solver._static_rigid_sim_config,
391+
)
392+
256393

257394
@ti.kernel
258395
def constraint_solver_kernel_clear(
@@ -486,11 +623,11 @@ def func_equality_connect(
486623

487624
imp, aref = gu.imp_aref(sol_params, -penetration, jac_qvel, pos_diff[i_3])
488625

489-
diag = ti.max(invweight * (1 - imp) / imp, gs.EPS)
626+
diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS)
490627

491628
constraint_state.diag[n_con, i_b] = diag
492629
constraint_state.aref[n_con, i_b] = aref
493-
constraint_state.efc_D[n_con, i_b] = 1 / diag
630+
constraint_state.efc_D[n_con, i_b] = 1.0 / diag
494631

495632

496633
@ti.func
@@ -564,11 +701,11 @@ def func_equality_joint(
564701

565702
imp, aref = gu.imp_aref(sol_params, -ti.abs(pos), jac_qvel, pos)
566703

567-
diag = ti.max(invweight * (1 - imp) / imp, gs.EPS)
704+
diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS)
568705

569706
constraint_state.diag[n_con, i_b] = diag
570707
constraint_state.aref[n_con, i_b] = aref
571-
constraint_state.efc_D[n_con, i_b] = 1 / diag
708+
constraint_state.efc_D[n_con, i_b] = 1.0 / diag
572709

573710

574711
@ti.kernel
@@ -1939,3 +2076,129 @@ def func_init_solver(
19392076
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
19402077
for i_d, i_b in ti.ndrange(n_dofs, _B):
19412078
constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]
2079+
2080+
2081+
@ti.kernel
2082+
def kernel_add_weld_constraint(
2083+
link1_idx: ti.i32,
2084+
link2_idx: ti.i32,
2085+
envs_idx: ti.types.ndarray(),
2086+
equalities_info: array_class.EqualitiesInfo,
2087+
constraint_state: array_class.ConstraintState,
2088+
links_state: array_class.LinksState,
2089+
static_rigid_sim_config: ti.template(),
2090+
) -> ti.i32:
2091+
overflow = gs.ti_bool(False)
2092+
2093+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
2094+
for i_b_ in ti.ndrange(envs_idx.shape[0]):
2095+
i_b = envs_idx[i_b_]
2096+
i_e = constraint_state.ti_n_equalities[i_b]
2097+
if i_e == static_rigid_sim_config.n_equalities_candidate:
2098+
overflow = True
2099+
else:
2100+
shared_pos = links_state.pos[link1_idx, i_b]
2101+
pos1 = gu.ti_inv_transform_by_trans_quat(
2102+
shared_pos, links_state.pos[link1_idx, i_b], links_state.quat[link1_idx, i_b]
2103+
)
2104+
pos2 = gu.ti_inv_transform_by_trans_quat(
2105+
shared_pos, links_state.pos[link2_idx, i_b], links_state.quat[link2_idx, i_b]
2106+
)
2107+
2108+
equalities_info.eq_type[i_e, i_b] = gs.ti_int(gs.EQUALITY_TYPE.WELD)
2109+
equalities_info.eq_obj1id[i_e, i_b] = link1_idx
2110+
equalities_info.eq_obj2id[i_e, i_b] = link2_idx
2111+
2112+
for i_3 in ti.static(range(3)):
2113+
equalities_info.eq_data[i_e, i_b][i_3 + 3] = pos1[i_3]
2114+
equalities_info.eq_data[i_e, i_b][i_3] = pos2[i_3]
2115+
2116+
relpose = gu.ti_quat_mul(gu.ti_inv_quat(links_state.quat[link1_idx, i_b]), links_state.quat[link2_idx, i_b])
2117+
2118+
equalities_info.eq_data[i_e, i_b][6] = relpose[0]
2119+
equalities_info.eq_data[i_e, i_b][7] = relpose[1]
2120+
equalities_info.eq_data[i_e, i_b][8] = relpose[2]
2121+
equalities_info.eq_data[i_e, i_b][9] = relpose[3]
2122+
2123+
equalities_info.eq_data[i_e, i_b][10] = 1.0
2124+
equalities_info.sol_params[i_e, i_b] = ti.Vector(
2125+
[2 * static_rigid_sim_config.substep_dt, 1.0e00, 9.0e-01, 9.5e-01, 1.0e-03, 5.0e-01, 2.0e00]
2126+
)
2127+
2128+
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] + 1
2129+
return overflow
2130+
2131+
2132+
@ti.kernel
2133+
def kernel_delete_weld_constraint(
2134+
link1_idx: ti.i32,
2135+
link2_idx: ti.i32,
2136+
envs_idx: ti.types.ndarray(),
2137+
equalities_info: array_class.EqualitiesInfo,
2138+
constraint_state: array_class.ConstraintState,
2139+
static_rigid_sim_config: ti.template(),
2140+
):
2141+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
2142+
for i_b_ in ti.ndrange(envs_idx.shape[0]):
2143+
i_b = envs_idx[i_b_]
2144+
for i_e in range(static_rigid_sim_config.n_equalities, constraint_state.ti_n_equalities[i_b]):
2145+
if (
2146+
equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD
2147+
and equalities_info.eq_obj1id[i_e, i_b] == link1_idx
2148+
and equalities_info.eq_obj2id[i_e, i_b] == link2_idx
2149+
):
2150+
if i_e < constraint_state.ti_n_equalities[i_b] - 1:
2151+
equalities_info.eq_type[i_e, i_b] = equalities_info.eq_type[
2152+
constraint_state.ti_n_equalities[i_b] - 1, i_b
2153+
]
2154+
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1
2155+
2156+
2157+
@ti.kernel
2158+
def kernel_get_equality_constraints(
2159+
is_padded: ti.template(),
2160+
iout: ti.types.ndarray(),
2161+
fout: ti.types.ndarray(),
2162+
constraint_state: array_class.ConstraintState,
2163+
equalities_info: array_class.EqualitiesInfo,
2164+
static_rigid_sim_config: ti.template(),
2165+
):
2166+
_B = constraint_state.ti_n_equalities.shape[0]
2167+
n_eqs_max = gs.ti_int(0)
2168+
2169+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
2170+
for i_b in range(_B):
2171+
n_eqs = constraint_state.ti_n_equalities[i_b]
2172+
if n_eqs > n_eqs_max:
2173+
n_eqs_max = n_eqs
2174+
2175+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
2176+
for i_b in range(_B):
2177+
i_c_start = gs.ti_int(0)
2178+
i_e_start = gs.ti_int(0)
2179+
if ti.static(is_padded):
2180+
i_e_start = i_b * n_eqs_max
2181+
else:
2182+
for j_b in range(i_b):
2183+
i_e_start = i_e_start + constraint_state.ti_n_equalities[j_b]
2184+
2185+
for i_e_ in range(constraint_state.ti_n_equalities[i_b]):
2186+
i_e = i_e_start + i_e_
2187+
2188+
iout[i_e, 0] = equalities_info.eq_type[i_e_, i_b]
2189+
iout[i_e, 1] = equalities_info.eq_obj1id[i_e_, i_b]
2190+
iout[i_e, 2] = equalities_info.eq_obj2id[i_e_, i_b]
2191+
2192+
if equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.CONNECT:
2193+
for i_c_ in ti.static(range(3)):
2194+
i_c = i_c_start + i_c_
2195+
fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b]
2196+
i_c_start = i_c_start + 3
2197+
elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.WELD:
2198+
for i_c_ in ti.static(range(6)):
2199+
i_c = i_c_start + i_c_
2200+
fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b]
2201+
i_c_start = i_c_start + 6
2202+
elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.JOINT:
2203+
fout[i_e, 0] = constraint_state.efc_force[i_c_start, i_b]
2204+
i_c_start = i_c_start + 1

0 commit comments

Comments
 (0)