Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions genesis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def init(
if _initialized:
raise_exception("Genesis already initialized.")

# Make sure evertything is properly destroyed, just in case initialization failed previously
destroy()

# genesis._theme
global _theme
is_theme_valid = theme in ("dark", "light", "dumb")
Expand Down
26 changes: 0 additions & 26 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,32 +1837,6 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def get_weld_constraints(self, with_entity=None, exclude_self_contact=False):
welds = self._solver.get_weld_constraints(as_tensor=True, to_torch=True)
obj_a = welds["obj_a"]
obj_b = welds["obj_b"]

# Create mask for filtering welds involving this entity
mask = (obj_a == self.idx) | (obj_b == self.idx)

# Additional filtering if with_entity is specified
if with_entity is not None:
if self.idx == with_entity.idx:
if exclude_self_contact:
gs.raise_exception("`with_entity` is self but `exclude_self_contact` is True.")
# For self-contact, keep only self-welds
mask = mask & ((obj_a == self.idx) & (obj_b == self.idx))
else:
# For cross-entity, keep welds between this entity and with_entity
mask = mask & ((obj_a == with_entity.idx) | (obj_b == with_entity.idx))

# Apply filtering
for k in ("obj_a", "obj_b"):
welds[k] = welds[k][mask]

return welds

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
"""
Expand Down
2 changes: 2 additions & 0 deletions genesis/engine/solvers/rigid/collider_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ def collider_kernel_get_contacts(
):
_B = collider_state.active_buffer.shape[1]
n_contacts_max = gs.ti_int(0)

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b in range(_B):
n_contacts = collider_state.n_contacts[i_b]
if n_contacts > n_contacts_max:
Expand Down
273 changes: 268 additions & 5 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING

import numpy as np
import taichi as ti
import numpy.typing as npt
import taichi as ti
import torch

import genesis as gs
import genesis.utils.geom as gu
Expand Down Expand Up @@ -37,6 +39,8 @@ def __init__(self, rigid_solver: "RigidSolver"):

self.constraint_state = array_class.get_constraint_state(self, self._solver)

self._eq_const_info_cache = {}

# self.ti_n_equalities = ti.field(gs.ti_int, shape=self._solver._batch_shape())
# self.ti_n_equalities.from_numpy(np.full((self._solver._B,), self._solver.n_equalities, dtype=gs.np_int))

Expand Down Expand Up @@ -157,11 +161,13 @@ def __init__(self, rigid_solver: "RigidSolver"):
self.reset()

def clear(self, envs_idx: npt.NDArray[np.int32] | None = None):
self._eq_const_info_cache.clear()
if envs_idx is None:
envs_idx = self._solver._scene._envs_idx
constraint_solver_kernel_clear(envs_idx, self._solver._static_rigid_sim_config, self.constraint_state)

def reset(self, envs_idx=None):
self._eq_const_info_cache.clear()
if envs_idx is None:
envs_idx = self._solver._scene._envs_idx
constraint_solver_kernel_reset(
Expand Down Expand Up @@ -253,6 +259,137 @@ def resolve(self):
)
# timer.stamp("compute force")

def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True):
# Early return if already pre-computed
eq_const_info = self._eq_const_info_cache.get((as_tensor, to_torch))
if eq_const_info is not None:
return eq_const_info.copy()

n_eqs = tuple(self.constraint_state.ti_n_equalities.to_numpy())
n_envs = len(n_eqs)
n_eqs_max = max(n_eqs)

if as_tensor:
out_size = n_envs * n_eqs_max
else:
*n_eqs_starts, out_size = np.cumsum(n_eqs)

if to_torch:
iout = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device)
fout = torch.zeros((out_size, 6), dtype=gs.tc_float, device=gs.device)
else:
iout = np.full((out_size, 3), -1, dtype=gs.np_int)
fout = np.zeros((out_size, 6), dtype=gs.np_float)

if n_eqs_max > 0:
kernel_get_equality_constraints(
as_tensor,
iout,
fout,
self.constraint_state,
self._solver.equalities_info,
self._solver._static_rigid_sim_config,
)

if as_tensor:
iout = iout.reshape((n_envs, n_eqs_max, 3))
eq_type, obj_a, obj_b = (iout[..., i] for i in range(3))
efc_force = fout.reshape((n_envs, n_eqs_max, 6))
values = (eq_type, obj_a, obj_b, fout)
else:
if to_torch:
iout_chunks = torch.split(iout, n_eqs)
efc_force = torch.split(fout, n_eqs)
else:
iout_chunks = np.split(iout, n_eqs_starts)
efc_force = np.split(fout, n_eqs_starts)
eq_type, obj_a, obj_b = tuple(zip(*([data[..., i] for i in range(3)] for data in iout_chunks)))

values = (eq_type, obj_a, obj_b, efc_force)
eq_const_info = dict(zip(("type", "obj_a", "obj_b", "force"), values))

# Cache equality constraint information before returning
self._eq_const_info_cache[(as_tensor, to_torch)] = eq_const_info

return eq_const_info.copy()

def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True):
eq_const_info = self.get_equality_constraints(as_tensor, to_torch)
eq_type = eq_const_info.pop("type")

weld_const_info = {}
if as_tensor:
weld_mask = eq_type == gs.EQUALITY_TYPE.WELD
n_envs = len(weld_mask)
n_welds = weld_mask.sum(dim=-1) if to_torch else np.sum(weld_mask, axis=-1)
n_welds_max = max(n_welds)
for key, value in eq_const_info.items():
shape = (n_envs, n_welds_max, *value.shape[2:])
if to_torch:
if torch.is_floating_point(value):
weld_const_info[key] = torch.zeros(shape, dtype=value.dtype, device=value.device)
else:
weld_const_info[key] = torch.full(shape, -1, dtype=value.dtype, device=value.device)
else:
if np.issubdtype(value.dtype, np.floating):
weld_const_info[key] = np.zeros(shape, dtype=value.dtype)
else:
weld_const_info[key] = np.full(shape, -1, dtype=value.dtype)
for i_b, (n_welds_i, weld_mask_i) in enumerate(zip(n_welds, weld_mask)):
for eq_value, weld_value in zip(eq_const_info.values(), weld_const_info.values()):
weld_value[i_b, :n_welds_i] = eq_value[i_b, weld_mask_i]
else:
weld_mask_chunks = tuple(eq_type_i == gs.EQUALITY_TYPE.WELD for eq_type_i in eq_type)
for key, value in eq_const_info.items():
weld_const_info[key] = tuple(data[weld_mask] for weld_mask, data in zip(weld_mask_chunks, value))

weld_const_info["link_a"] = weld_const_info.pop("obj_a")
weld_const_info["link_b"] = weld_const_info.pop("obj_b")

return weld_const_info

def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False):
envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
link1_idx, link2_idx = int(link1_idx), int(link2_idx)

if not unsafe:
assert link1_idx >= 0 and link2_idx >= 0
weld_const_info = self.get_weld_constraints(as_tensor=True, to_torch=True)
link_a = weld_const_info["link_a"]
link_b = weld_const_info["link_b"]
assert not (
((link_a == link1_idx) | (link_b == link1_idx)) & ((link_a == link2_idx) | (link_b == link2_idx))
).any()

self._eq_const_info_cache.clear()
overflow = kernel_add_weld_constraint(
link1_idx,
link2_idx,
envs_idx,
self._solver.equalities_info,
self.constraint_state,
self._solver.links_state,
self._solver._static_rigid_sim_config,
)
if overflow:
gs.logger.warning(
"Ignoring dynamically registered weld constraint to avoid exceeding max number of equality constraints"
f"({self._static_rigid_sim_config.n_equalities_candidate}). Please increase the value of "
"RigidSolver's option 'max_dynamic_constraints'."
)

def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False):
envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
self._eq_const_info_cache.clear()
kernel_delete_weld_constraint(
int(link1_idx),
int(link2_idx),
envs_idx,
self._solver.equalities_info,
self.constraint_state,
self._solver._static_rigid_sim_config,
)


@ti.kernel
def constraint_solver_kernel_clear(
Expand Down Expand Up @@ -486,11 +623,11 @@ def func_equality_connect(

imp, aref = gu.imp_aref(sol_params, -penetration, jac_qvel, pos_diff[i_3])

diag = ti.max(invweight * (1 - imp) / imp, gs.EPS)
diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS)

constraint_state.diag[n_con, i_b] = diag
constraint_state.aref[n_con, i_b] = aref
constraint_state.efc_D[n_con, i_b] = 1 / diag
constraint_state.efc_D[n_con, i_b] = 1.0 / diag


@ti.func
Expand Down Expand Up @@ -564,11 +701,11 @@ def func_equality_joint(

imp, aref = gu.imp_aref(sol_params, -ti.abs(pos), jac_qvel, pos)

diag = ti.max(invweight * (1 - imp) / imp, gs.EPS)
diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS)

constraint_state.diag[n_con, i_b] = diag
constraint_state.aref[n_con, i_b] = aref
constraint_state.efc_D[n_con, i_b] = 1 / diag
constraint_state.efc_D[n_con, i_b] = 1.0 / diag


@ti.kernel
Expand Down Expand Up @@ -1939,3 +2076,129 @@ def func_init_solver(
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in ti.ndrange(n_dofs, _B):
constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]


@ti.kernel
def kernel_add_weld_constraint(
link1_idx: ti.i32,
link2_idx: ti.i32,
envs_idx: ti.types.ndarray(),
equalities_info: array_class.EqualitiesInfo,
constraint_state: array_class.ConstraintState,
links_state: array_class.LinksState,
static_rigid_sim_config: ti.template(),
) -> ti.i32:
overflow = gs.ti_bool(False)

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
for i_b_ in ti.ndrange(envs_idx.shape[0]):
i_b = envs_idx[i_b_]
i_e = constraint_state.ti_n_equalities[i_b]
if i_e == static_rigid_sim_config.n_equalities_candidate:
overflow = True
else:
shared_pos = links_state.pos[link1_idx, i_b]
pos1 = gu.ti_inv_transform_by_trans_quat(
shared_pos, links_state.pos[link1_idx, i_b], links_state.quat[link1_idx, i_b]
)
pos2 = gu.ti_inv_transform_by_trans_quat(
shared_pos, links_state.pos[link2_idx, i_b], links_state.quat[link2_idx, i_b]
)

equalities_info.eq_type[i_e, i_b] = gs.ti_int(gs.EQUALITY_TYPE.WELD)
equalities_info.eq_obj1id[i_e, i_b] = link1_idx
equalities_info.eq_obj2id[i_e, i_b] = link2_idx

for i_3 in ti.static(range(3)):
equalities_info.eq_data[i_e, i_b][i_3 + 3] = pos1[i_3]
equalities_info.eq_data[i_e, i_b][i_3] = pos2[i_3]

relpose = gu.ti_quat_mul(gu.ti_inv_quat(links_state.quat[link1_idx, i_b]), links_state.quat[link2_idx, i_b])

equalities_info.eq_data[i_e, i_b][6] = relpose[0]
equalities_info.eq_data[i_e, i_b][7] = relpose[1]
equalities_info.eq_data[i_e, i_b][8] = relpose[2]
equalities_info.eq_data[i_e, i_b][9] = relpose[3]

equalities_info.eq_data[i_e, i_b][10] = 1.0
equalities_info.sol_params[i_e, i_b] = ti.Vector(
[2 * static_rigid_sim_config.substep_dt, 1.0e00, 9.0e-01, 9.5e-01, 1.0e-03, 5.0e-01, 2.0e00]
)

constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] + 1
return overflow


@ti.kernel
def kernel_delete_weld_constraint(
link1_idx: ti.i32,
link2_idx: ti.i32,
envs_idx: ti.types.ndarray(),
equalities_info: array_class.EqualitiesInfo,
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
):
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
for i_b_ in ti.ndrange(envs_idx.shape[0]):
i_b = envs_idx[i_b_]
for i_e in range(static_rigid_sim_config.n_equalities, constraint_state.ti_n_equalities[i_b]):
if (
equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD
and equalities_info.eq_obj1id[i_e, i_b] == link1_idx
and equalities_info.eq_obj2id[i_e, i_b] == link2_idx
):
if i_e < constraint_state.ti_n_equalities[i_b] - 1:
equalities_info.eq_type[i_e, i_b] = equalities_info.eq_type[
constraint_state.ti_n_equalities[i_b] - 1, i_b
]
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1


@ti.kernel
def kernel_get_equality_constraints(
is_padded: ti.template(),
iout: ti.types.ndarray(),
fout: ti.types.ndarray(),
constraint_state: array_class.ConstraintState,
equalities_info: array_class.EqualitiesInfo,
static_rigid_sim_config: ti.template(),
):
_B = constraint_state.ti_n_equalities.shape[0]
n_eqs_max = gs.ti_int(0)

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b in range(_B):
n_eqs = constraint_state.ti_n_equalities[i_b]
if n_eqs > n_eqs_max:
n_eqs_max = n_eqs

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b in range(_B):
i_c_start = gs.ti_int(0)
i_e_start = gs.ti_int(0)
if ti.static(is_padded):
i_e_start = i_b * n_eqs_max
else:
for j_b in range(i_b):
i_e_start = i_e_start + constraint_state.ti_n_equalities[j_b]

for i_e_ in range(constraint_state.ti_n_equalities[i_b]):
i_e = i_e_start + i_e_

iout[i_e, 0] = equalities_info.eq_type[i_e_, i_b]
iout[i_e, 1] = equalities_info.eq_obj1id[i_e_, i_b]
iout[i_e, 2] = equalities_info.eq_obj2id[i_e_, i_b]

if equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.CONNECT:
for i_c_ in ti.static(range(3)):
i_c = i_c_start + i_c_
fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b]
i_c_start = i_c_start + 3
elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.WELD:
for i_c_ in ti.static(range(6)):
i_c = i_c_start + i_c_
fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b]
i_c_start = i_c_start + 6
elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.JOINT:
fout[i_e, 0] = constraint_state.efc_force[i_c_start, i_b]
i_c_start = i_c_start + 1
Loading