Skip to content

Commit f59852a

Browse files
committed
updated get_weld_constraint api after merging main
1 parent 8bd9fce commit f59852a

File tree

1 file changed

+66
-36
lines changed

1 file changed

+66
-36
lines changed

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,53 +2296,48 @@ def update_verts_for_geom(self, i_g):
22962296
self.fixed_verts_state,
22972297
)
22982298

2299-
@gs.assert_built
23002299
def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True):
2301-
n_welds = tuple(self.constraint_solver.ti_n_equalities.to_numpy())
2302-
n_envs = len(n_welds)
2303-
n_welds_max = max(n_welds) if n_welds else 0
2304-
out_size = n_welds_max * n_envs
2300+
n_eqs = tuple(self.constraint_solver.constraint_state.ti_n_equalities.to_numpy())
2301+
n_envs = len(n_eqs)
2302+
n_max = max(n_eqs) if n_eqs else 0
2303+
2304+
if as_tensor:
2305+
out_size = n_envs * n_max
2306+
splits = None
2307+
else:
2308+
cumsum = np.cumsum(n_eqs, dtype=np.int32)
2309+
splits = list(cumsum[:-1])
2310+
out_size = int(cumsum[-1]) if n_envs else 0
23052311

23062312
if to_torch:
23072313
buf = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device)
23082314
else:
23092315
buf = np.full((out_size, 3), -1, dtype=np.int32)
23102316

2311-
if n_welds_max > 0:
2312-
self._kernel_collect_welds(buf)
2313-
2314-
if to_torch:
2315-
buf_view = buf.view(n_envs, n_welds_max, 3)
2316-
else:
2317-
buf_view = buf.reshape(n_envs, n_welds_max, 3)
2318-
env_idx = buf_view[..., 0]
2319-
obj_a = buf_view[..., 1]
2320-
obj_b = buf_view[..., 2]
2317+
if n_max > 0:
2318+
kernel_collect_welds(
2319+
as_tensor,
2320+
buf,
2321+
self.constraint_solver.constraint_state,
2322+
self.equalities_info,
2323+
self._static_rigid_sim_config,
2324+
)
23212325

23222326
if as_tensor:
2327+
if n_envs > 0:
2328+
buf = buf.view(n_envs, n_max, 3) if to_torch else buf.reshape(n_envs, n_max, 3)
2329+
env_idx, obj_a, obj_b = buf[..., 0], buf[..., 1], buf[..., 2]
23232330
return {"env": env_idx, "obj_a": obj_a, "obj_b": obj_b}
2324-
result_a = []
2325-
result_b = []
2326-
for e, count in enumerate(n_welds):
2327-
result_a.append(obj_a[e, :count].copy())
2328-
result_b.append(obj_b[e, :count].copy())
2331+
2332+
if n_envs == 0:
2333+
return {"obj_a": [], "obj_b": []}
2334+
2335+
parts = torch.split(buf, n_eqs) if to_torch else np.split(buf, splits)
2336+
a = [p[:, 1] for p in parts]
2337+
b = [p[:, 2] for p in parts]
23292338
if n_envs == 1:
2330-
return {"obj_a": result_a[0], "obj_b": result_b[0]}
2331-
return {"obj_a": result_a, "obj_b": result_b}
2332-
2333-
@ti.kernel
2334-
def _kernel_collect_welds(self, buf: ti.types.ndarray()):
2335-
for env in range(self.n_envs):
2336-
base = env * self.n_equalities_candidate
2337-
out = 0
2338-
n_eq = self.constraint_solver.ti_n_equalities[env]
2339-
for j in range(n_eq):
2340-
rec = self.equalities_info[j, env]
2341-
if rec.eq_type == gs.EQUALITY_TYPE.WELD and out < self.n_equalities_candidate:
2342-
buf[base + out, 0] = env
2343-
buf[base + out, 1] = rec.eq_obj1id
2344-
buf[base + out, 2] = rec.eq_obj2id
2345-
out += 1
2339+
return {"obj_a": a[0], "obj_b": b[0]}
2340+
return {"obj_a": a, "obj_b": b}
23462341

23472342
# ------------------------------------------------------------------------------------
23482343
# ----------------------------------- properties -------------------------------------
@@ -6765,3 +6760,38 @@ def kernel_delete_weld_constraint(
67656760
constraint_state.ti_n_equalities[i_b] - 1, i_b
67666761
]
67676762
constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1
6763+
6764+
6765+
@ti.kernel
6766+
def kernel_collect_welds(
6767+
is_padded: ti.template(),
6768+
buf: ti.types.ndarray(),
6769+
constraint_state: array_class.ConstraintState,
6770+
equalities_info: array_class.EqualitiesInfo,
6771+
static_rigid_sim_config: ti.template(),
6772+
):
6773+
B = constraint_state.ti_n_equalities.shape[0]
6774+
max_eq = 0
6775+
for e in range(B):
6776+
n = constraint_state.ti_n_equalities[e]
6777+
if n > max_eq:
6778+
max_eq = n
6779+
6780+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
6781+
for e in range(B):
6782+
base = 0
6783+
if ti.static(is_padded):
6784+
base = e * max_eq
6785+
else:
6786+
for pe in range(e):
6787+
base += constraint_state.ti_n_equalities[pe]
6788+
6789+
out = 0
6790+
n = constraint_state.ti_n_equalities[e]
6791+
for i in range(n):
6792+
if equalities_info.eq_type[i, e] == gs.EQUALITY_TYPE.WELD and out < max_eq:
6793+
idx = base + out
6794+
buf[idx, 0] = e
6795+
buf[idx, 1] = equalities_info.eq_obj1id[i, e]
6796+
buf[idx, 2] = equalities_info.eq_obj2id[i, e]
6797+
out += 1

0 commit comments

Comments
 (0)