|
1 | 1 | from typing import TYPE_CHECKING |
| 2 | + |
2 | 3 | import numpy as np |
3 | | -import taichi as ti |
4 | 4 | import numpy.typing as npt |
| 5 | +import taichi as ti |
| 6 | +import torch |
5 | 7 |
|
6 | 8 | import genesis as gs |
7 | 9 | import genesis.utils.geom as gu |
@@ -37,6 +39,8 @@ def __init__(self, rigid_solver: "RigidSolver"): |
37 | 39 |
|
38 | 40 | self.constraint_state = array_class.get_constraint_state(self, self._solver) |
39 | 41 |
|
| 42 | + self._eq_const_info_cache = {} |
| 43 | + |
40 | 44 | # self.ti_n_equalities = ti.field(gs.ti_int, shape=self._solver._batch_shape()) |
41 | 45 | # self.ti_n_equalities.from_numpy(np.full((self._solver._B,), self._solver.n_equalities, dtype=gs.np_int)) |
42 | 46 |
|
@@ -157,11 +161,13 @@ def __init__(self, rigid_solver: "RigidSolver"): |
157 | 161 | self.reset() |
158 | 162 |
|
159 | 163 | def clear(self, envs_idx: npt.NDArray[np.int32] | None = None): |
| 164 | + self._eq_const_info_cache.clear() |
160 | 165 | if envs_idx is None: |
161 | 166 | envs_idx = self._solver._scene._envs_idx |
162 | 167 | constraint_solver_kernel_clear(envs_idx, self._solver._static_rigid_sim_config, self.constraint_state) |
163 | 168 |
|
164 | 169 | def reset(self, envs_idx=None): |
| 170 | + self._eq_const_info_cache.clear() |
165 | 171 | if envs_idx is None: |
166 | 172 | envs_idx = self._solver._scene._envs_idx |
167 | 173 | constraint_solver_kernel_reset( |
@@ -253,6 +259,137 @@ def resolve(self): |
253 | 259 | ) |
254 | 260 | # timer.stamp("compute force") |
255 | 261 |
|
| 262 | + def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True): |
| 263 | + # Early return if already pre-computed |
| 264 | + eq_const_info = self._eq_const_info_cache.get((as_tensor, to_torch)) |
| 265 | + if eq_const_info is not None: |
| 266 | + return eq_const_info.copy() |
| 267 | + |
| 268 | + n_eqs = tuple(self.constraint_state.ti_n_equalities.to_numpy()) |
| 269 | + n_envs = len(n_eqs) |
| 270 | + n_eqs_max = max(n_eqs) |
| 271 | + |
| 272 | + if as_tensor: |
| 273 | + out_size = n_envs * n_eqs_max |
| 274 | + else: |
| 275 | + *n_eqs_starts, out_size = np.cumsum(n_eqs) |
| 276 | + |
| 277 | + if to_torch: |
| 278 | + iout = torch.full((out_size, 3), -1, dtype=gs.tc_int, device=gs.device) |
| 279 | + fout = torch.zeros((out_size, 6), dtype=gs.tc_float, device=gs.device) |
| 280 | + else: |
| 281 | + iout = np.full((out_size, 3), -1, dtype=gs.np_int) |
| 282 | + fout = np.zeros((out_size, 6), dtype=gs.np_float) |
| 283 | + |
| 284 | + if n_eqs_max > 0: |
| 285 | + kernel_get_equality_constraints( |
| 286 | + as_tensor, |
| 287 | + iout, |
| 288 | + fout, |
| 289 | + self.constraint_state, |
| 290 | + self._solver.equalities_info, |
| 291 | + self._solver._static_rigid_sim_config, |
| 292 | + ) |
| 293 | + |
| 294 | + if as_tensor: |
| 295 | + iout = iout.reshape((n_envs, n_eqs_max, 3)) |
| 296 | + eq_type, obj_a, obj_b = (iout[..., i] for i in range(3)) |
| 297 | + efc_force = fout.reshape((n_envs, n_eqs_max, 6)) |
| 298 | + values = (eq_type, obj_a, obj_b, fout) |
| 299 | + else: |
| 300 | + if to_torch: |
| 301 | + iout_chunks = torch.split(iout, n_eqs) |
| 302 | + efc_force = torch.split(fout, n_eqs) |
| 303 | + else: |
| 304 | + iout_chunks = np.split(iout, n_eqs_starts) |
| 305 | + efc_force = np.split(fout, n_eqs_starts) |
| 306 | + eq_type, obj_a, obj_b = tuple(zip(*([data[..., i] for i in range(3)] for data in iout_chunks))) |
| 307 | + |
| 308 | + values = (eq_type, obj_a, obj_b, efc_force) |
| 309 | + eq_const_info = dict(zip(("type", "obj_a", "obj_b", "force"), values)) |
| 310 | + |
| 311 | + # Cache equality constraint information before returning |
| 312 | + self._eq_const_info_cache[(as_tensor, to_torch)] = eq_const_info |
| 313 | + |
| 314 | + return eq_const_info.copy() |
| 315 | + |
| 316 | + def get_weld_constraints(self, as_tensor: bool = True, to_torch: bool = True): |
| 317 | + eq_const_info = self.get_equality_constraints(as_tensor, to_torch) |
| 318 | + eq_type = eq_const_info.pop("type") |
| 319 | + |
| 320 | + weld_const_info = {} |
| 321 | + if as_tensor: |
| 322 | + weld_mask = eq_type == gs.EQUALITY_TYPE.WELD |
| 323 | + n_envs = len(weld_mask) |
| 324 | + n_welds = weld_mask.sum(dim=-1) if to_torch else np.sum(weld_mask, axis=-1) |
| 325 | + n_welds_max = max(n_welds) |
| 326 | + for key, value in eq_const_info.items(): |
| 327 | + shape = (n_envs, n_welds_max, *value.shape[2:]) |
| 328 | + if to_torch: |
| 329 | + if torch.is_floating_point(value): |
| 330 | + weld_const_info[key] = torch.zeros(shape, dtype=value.dtype, device=value.device) |
| 331 | + else: |
| 332 | + weld_const_info[key] = torch.full(shape, -1, dtype=value.dtype, device=value.device) |
| 333 | + else: |
| 334 | + if np.issubdtype(value.dtype, np.floating): |
| 335 | + weld_const_info[key] = np.zeros(shape, dtype=value.dtype) |
| 336 | + else: |
| 337 | + weld_const_info[key] = np.full(shape, -1, dtype=value.dtype) |
| 338 | + for i_b, (n_welds_i, weld_mask_i) in enumerate(zip(n_welds, weld_mask)): |
| 339 | + for eq_value, weld_value in zip(eq_const_info.values(), weld_const_info.values()): |
| 340 | + weld_value[i_b, :n_welds_i] = eq_value[i_b, weld_mask_i] |
| 341 | + else: |
| 342 | + weld_mask_chunks = tuple(eq_type_i == gs.EQUALITY_TYPE.WELD for eq_type_i in eq_type) |
| 343 | + for key, value in eq_const_info.items(): |
| 344 | + weld_const_info[key] = tuple(data[weld_mask] for weld_mask, data in zip(weld_mask_chunks, value)) |
| 345 | + |
| 346 | + weld_const_info["link_a"] = weld_const_info.pop("obj_a") |
| 347 | + weld_const_info["link_b"] = weld_const_info.pop("obj_b") |
| 348 | + |
| 349 | + return weld_const_info |
| 350 | + |
| 351 | + def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): |
| 352 | + envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) |
| 353 | + link1_idx, link2_idx = int(link1_idx), int(link2_idx) |
| 354 | + |
| 355 | + if not unsafe: |
| 356 | + assert link1_idx >= 0 and link2_idx >= 0 |
| 357 | + weld_const_info = self.get_weld_constraints(as_tensor=True, to_torch=True) |
| 358 | + link_a = weld_const_info["link_a"] |
| 359 | + link_b = weld_const_info["link_b"] |
| 360 | + assert not ( |
| 361 | + ((link_a == link1_idx) | (link_b == link1_idx)) & ((link_a == link2_idx) | (link_b == link2_idx)) |
| 362 | + ).any() |
| 363 | + |
| 364 | + self._eq_const_info_cache.clear() |
| 365 | + overflow = kernel_add_weld_constraint( |
| 366 | + link1_idx, |
| 367 | + link2_idx, |
| 368 | + envs_idx, |
| 369 | + self._solver.equalities_info, |
| 370 | + self.constraint_state, |
| 371 | + self._solver.links_state, |
| 372 | + self._solver._static_rigid_sim_config, |
| 373 | + ) |
| 374 | + if overflow: |
| 375 | + gs.logger.warning( |
| 376 | + "Ignoring dynamically registered weld constraint to avoid exceeding max number of equality constraints" |
| 377 | + f"({self._static_rigid_sim_config.n_equalities_candidate}). Please increase the value of " |
| 378 | + "RigidSolver's option 'max_dynamic_constraints'." |
| 379 | + ) |
| 380 | + |
| 381 | + def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=False): |
| 382 | + envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) |
| 383 | + self._eq_const_info_cache.clear() |
| 384 | + kernel_delete_weld_constraint( |
| 385 | + int(link1_idx), |
| 386 | + int(link2_idx), |
| 387 | + envs_idx, |
| 388 | + self._solver.equalities_info, |
| 389 | + self.constraint_state, |
| 390 | + self._solver._static_rigid_sim_config, |
| 391 | + ) |
| 392 | + |
256 | 393 |
|
257 | 394 | @ti.kernel |
258 | 395 | def constraint_solver_kernel_clear( |
@@ -486,11 +623,11 @@ def func_equality_connect( |
486 | 623 |
|
487 | 624 | imp, aref = gu.imp_aref(sol_params, -penetration, jac_qvel, pos_diff[i_3]) |
488 | 625 |
|
489 | | - diag = ti.max(invweight * (1 - imp) / imp, gs.EPS) |
| 626 | + diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS) |
490 | 627 |
|
491 | 628 | constraint_state.diag[n_con, i_b] = diag |
492 | 629 | constraint_state.aref[n_con, i_b] = aref |
493 | | - constraint_state.efc_D[n_con, i_b] = 1 / diag |
| 630 | + constraint_state.efc_D[n_con, i_b] = 1.0 / diag |
494 | 631 |
|
495 | 632 |
|
496 | 633 | @ti.func |
@@ -564,11 +701,11 @@ def func_equality_joint( |
564 | 701 |
|
565 | 702 | imp, aref = gu.imp_aref(sol_params, -ti.abs(pos), jac_qvel, pos) |
566 | 703 |
|
567 | | - diag = ti.max(invweight * (1 - imp) / imp, gs.EPS) |
| 704 | + diag = ti.max(invweight * (1.0 - imp) / imp, gs.EPS) |
568 | 705 |
|
569 | 706 | constraint_state.diag[n_con, i_b] = diag |
570 | 707 | constraint_state.aref[n_con, i_b] = aref |
571 | | - constraint_state.efc_D[n_con, i_b] = 1 / diag |
| 708 | + constraint_state.efc_D[n_con, i_b] = 1.0 / diag |
572 | 709 |
|
573 | 710 |
|
574 | 711 | @ti.kernel |
@@ -1939,3 +2076,129 @@ def func_init_solver( |
1939 | 2076 | ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) |
1940 | 2077 | for i_d, i_b in ti.ndrange(n_dofs, _B): |
1941 | 2078 | constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] |
| 2079 | + |
| 2080 | + |
| 2081 | +@ti.kernel |
| 2082 | +def kernel_add_weld_constraint( |
| 2083 | + link1_idx: ti.i32, |
| 2084 | + link2_idx: ti.i32, |
| 2085 | + envs_idx: ti.types.ndarray(), |
| 2086 | + equalities_info: array_class.EqualitiesInfo, |
| 2087 | + constraint_state: array_class.ConstraintState, |
| 2088 | + links_state: array_class.LinksState, |
| 2089 | + static_rigid_sim_config: ti.template(), |
| 2090 | +) -> ti.i32: |
| 2091 | + overflow = gs.ti_bool(False) |
| 2092 | + |
| 2093 | + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) |
| 2094 | + for i_b_ in ti.ndrange(envs_idx.shape[0]): |
| 2095 | + i_b = envs_idx[i_b_] |
| 2096 | + i_e = constraint_state.ti_n_equalities[i_b] |
| 2097 | + if i_e == static_rigid_sim_config.n_equalities_candidate: |
| 2098 | + overflow = True |
| 2099 | + else: |
| 2100 | + shared_pos = links_state.pos[link1_idx, i_b] |
| 2101 | + pos1 = gu.ti_inv_transform_by_trans_quat( |
| 2102 | + shared_pos, links_state.pos[link1_idx, i_b], links_state.quat[link1_idx, i_b] |
| 2103 | + ) |
| 2104 | + pos2 = gu.ti_inv_transform_by_trans_quat( |
| 2105 | + shared_pos, links_state.pos[link2_idx, i_b], links_state.quat[link2_idx, i_b] |
| 2106 | + ) |
| 2107 | + |
| 2108 | + equalities_info.eq_type[i_e, i_b] = gs.ti_int(gs.EQUALITY_TYPE.WELD) |
| 2109 | + equalities_info.eq_obj1id[i_e, i_b] = link1_idx |
| 2110 | + equalities_info.eq_obj2id[i_e, i_b] = link2_idx |
| 2111 | + |
| 2112 | + for i_3 in ti.static(range(3)): |
| 2113 | + equalities_info.eq_data[i_e, i_b][i_3 + 3] = pos1[i_3] |
| 2114 | + equalities_info.eq_data[i_e, i_b][i_3] = pos2[i_3] |
| 2115 | + |
| 2116 | + relpose = gu.ti_quat_mul(gu.ti_inv_quat(links_state.quat[link1_idx, i_b]), links_state.quat[link2_idx, i_b]) |
| 2117 | + |
| 2118 | + equalities_info.eq_data[i_e, i_b][6] = relpose[0] |
| 2119 | + equalities_info.eq_data[i_e, i_b][7] = relpose[1] |
| 2120 | + equalities_info.eq_data[i_e, i_b][8] = relpose[2] |
| 2121 | + equalities_info.eq_data[i_e, i_b][9] = relpose[3] |
| 2122 | + |
| 2123 | + equalities_info.eq_data[i_e, i_b][10] = 1.0 |
| 2124 | + equalities_info.sol_params[i_e, i_b] = ti.Vector( |
| 2125 | + [2 * static_rigid_sim_config.substep_dt, 1.0e00, 9.0e-01, 9.5e-01, 1.0e-03, 5.0e-01, 2.0e00] |
| 2126 | + ) |
| 2127 | + |
| 2128 | + constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] + 1 |
| 2129 | + return overflow |
| 2130 | + |
| 2131 | + |
| 2132 | +@ti.kernel |
| 2133 | +def kernel_delete_weld_constraint( |
| 2134 | + link1_idx: ti.i32, |
| 2135 | + link2_idx: ti.i32, |
| 2136 | + envs_idx: ti.types.ndarray(), |
| 2137 | + equalities_info: array_class.EqualitiesInfo, |
| 2138 | + constraint_state: array_class.ConstraintState, |
| 2139 | + static_rigid_sim_config: ti.template(), |
| 2140 | +): |
| 2141 | + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) |
| 2142 | + for i_b_ in ti.ndrange(envs_idx.shape[0]): |
| 2143 | + i_b = envs_idx[i_b_] |
| 2144 | + for i_e in range(static_rigid_sim_config.n_equalities, constraint_state.ti_n_equalities[i_b]): |
| 2145 | + if ( |
| 2146 | + equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD |
| 2147 | + and equalities_info.eq_obj1id[i_e, i_b] == link1_idx |
| 2148 | + and equalities_info.eq_obj2id[i_e, i_b] == link2_idx |
| 2149 | + ): |
| 2150 | + if i_e < constraint_state.ti_n_equalities[i_b] - 1: |
| 2151 | + equalities_info.eq_type[i_e, i_b] = equalities_info.eq_type[ |
| 2152 | + constraint_state.ti_n_equalities[i_b] - 1, i_b |
| 2153 | + ] |
| 2154 | + constraint_state.ti_n_equalities[i_b] = constraint_state.ti_n_equalities[i_b] - 1 |
| 2155 | + |
| 2156 | + |
| 2157 | +@ti.kernel |
| 2158 | +def kernel_get_equality_constraints( |
| 2159 | + is_padded: ti.template(), |
| 2160 | + iout: ti.types.ndarray(), |
| 2161 | + fout: ti.types.ndarray(), |
| 2162 | + constraint_state: array_class.ConstraintState, |
| 2163 | + equalities_info: array_class.EqualitiesInfo, |
| 2164 | + static_rigid_sim_config: ti.template(), |
| 2165 | +): |
| 2166 | + _B = constraint_state.ti_n_equalities.shape[0] |
| 2167 | + n_eqs_max = gs.ti_int(0) |
| 2168 | + |
| 2169 | + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) |
| 2170 | + for i_b in range(_B): |
| 2171 | + n_eqs = constraint_state.ti_n_equalities[i_b] |
| 2172 | + if n_eqs > n_eqs_max: |
| 2173 | + n_eqs_max = n_eqs |
| 2174 | + |
| 2175 | + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) |
| 2176 | + for i_b in range(_B): |
| 2177 | + i_c_start = gs.ti_int(0) |
| 2178 | + i_e_start = gs.ti_int(0) |
| 2179 | + if ti.static(is_padded): |
| 2180 | + i_e_start = i_b * n_eqs_max |
| 2181 | + else: |
| 2182 | + for j_b in range(i_b): |
| 2183 | + i_e_start = i_e_start + constraint_state.ti_n_equalities[j_b] |
| 2184 | + |
| 2185 | + for i_e_ in range(constraint_state.ti_n_equalities[i_b]): |
| 2186 | + i_e = i_e_start + i_e_ |
| 2187 | + |
| 2188 | + iout[i_e, 0] = equalities_info.eq_type[i_e_, i_b] |
| 2189 | + iout[i_e, 1] = equalities_info.eq_obj1id[i_e_, i_b] |
| 2190 | + iout[i_e, 2] = equalities_info.eq_obj2id[i_e_, i_b] |
| 2191 | + |
| 2192 | + if equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.CONNECT: |
| 2193 | + for i_c_ in ti.static(range(3)): |
| 2194 | + i_c = i_c_start + i_c_ |
| 2195 | + fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b] |
| 2196 | + i_c_start = i_c_start + 3 |
| 2197 | + elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.WELD: |
| 2198 | + for i_c_ in ti.static(range(6)): |
| 2199 | + i_c = i_c_start + i_c_ |
| 2200 | + fout[i_e, i_c_] = constraint_state.efc_force[i_c, i_b] |
| 2201 | + i_c_start = i_c_start + 6 |
| 2202 | + elif equalities_info.eq_type[i_e_, i_b] == gs.EQUALITY_TYPE.JOINT: |
| 2203 | + fout[i_e, 0] = constraint_state.efc_force[i_c_start, i_b] |
| 2204 | + i_c_start = i_c_start + 1 |
0 commit comments