Skip to content

Commit 5511fce

Browse files
committed
merge
2 parents 81e91e6 + b16f229 commit 5511fce

File tree

17 files changed

+16836
-14891
lines changed

17 files changed

+16836
-14891
lines changed

genesis/engine/entities/drone_entity.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,14 @@ def set_propellels_rpm(self, propellels_rpm):
7575
gs.raise_exception("`propellels_rpm` cannot be negative.")
7676
self._propellers_revs = (self._propellers_revs + propellels_rpm) % (60 / self.solver.dt)
7777

78-
self.solver._kernel_set_drone_rpm(
78+
self.solver.set_drone_rpm(
7979
self._n_propellers,
8080
self._propellers_link_idxs,
8181
propellels_rpm,
8282
self._propellers_spin,
8383
self.KF,
8484
self.KM,
8585
self._model == "RACE",
86-
self.solver.links_state,
8786
)
8887

8988
def update_propeller_vgeoms(self):
@@ -93,7 +92,7 @@ def update_propeller_vgeoms(self):
9392
This method is a no-op if animation is disabled due to missing visual geometry.
9493
"""
9594
if self._animate_propellers:
96-
self.solver._update_drone_propeller_vgeoms(
95+
self.solver.update_drone_propeller_vgeoms(
9796
self._n_propellers, self._propellers_vgeom_idxs, self._propellers_revs, self._propellers_spin
9897
)
9998

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 3 additions & 268 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ def inverse_kinematics_multilink(
12221222
if envs_idx is None:
12231223
envs_idx = torch.zeros(1, dtype=gs.tc_int, device=gs.device)
12241224

1225-
self._kernel_inverse_kinematics(
1225+
self._solver.rigid_entity_inverse_kinematics(
12261226
links_idx,
12271227
poss,
12281228
quats,
@@ -1245,7 +1245,9 @@ def inverse_kinematics_multilink(
12451245
max_step_size,
12461246
respect_joint_limit,
12471247
envs_idx,
1248+
self,
12481249
)
1250+
12491251
qpos = self._IK_qpos_best.to_torch(gs.device).transpose(1, 0)
12501252
if self._solver.n_envs > 0:
12511253
qpos = qpos[envs_idx]
@@ -1261,273 +1263,6 @@ def inverse_kinematics_multilink(
12611263
return qpos, error_pose
12621264
return qpos
12631265

1264-
@ti.kernel
1265-
def _kernel_inverse_kinematics(
1266-
self,
1267-
links_idx: ti.types.ndarray(),
1268-
poss: ti.types.ndarray(),
1269-
quats: ti.types.ndarray(),
1270-
n_links: ti.i32,
1271-
dofs_idx: ti.types.ndarray(),
1272-
n_dofs: ti.i32,
1273-
links_idx_by_dofs: ti.types.ndarray(),
1274-
n_links_by_dofs: ti.i32,
1275-
custom_init_qpos: ti.i32,
1276-
init_qpos: ti.types.ndarray(),
1277-
max_samples: ti.i32,
1278-
max_solver_iters: ti.i32,
1279-
damping: ti.f32,
1280-
pos_tol: ti.f32,
1281-
rot_tol: ti.f32,
1282-
pos_mask_: ti.types.ndarray(),
1283-
rot_mask_: ti.types.ndarray(),
1284-
link_pos_mask: ti.types.ndarray(),
1285-
link_rot_mask: ti.types.ndarray(),
1286-
max_step_size: ti.f32,
1287-
respect_joint_limit: ti.i32,
1288-
envs_idx: ti.types.ndarray(),
1289-
):
1290-
# convert to ti Vector
1291-
pos_mask = ti.Vector([pos_mask_[0], pos_mask_[1], pos_mask_[2]], dt=gs.ti_float)
1292-
rot_mask = ti.Vector([rot_mask_[0], rot_mask_[1], rot_mask_[2]], dt=gs.ti_float)
1293-
n_error_dims = 6 * n_links
1294-
1295-
for i_b in envs_idx:
1296-
# save original qpos
1297-
for i_q in range(self.n_qs):
1298-
self._IK_qpos_orig[i_q, i_b] = self._solver.qpos[i_q + self._q_start, i_b]
1299-
1300-
if custom_init_qpos:
1301-
for i_q in range(self.n_qs):
1302-
self._solver.qpos[i_q + self._q_start, i_b] = init_qpos[i_b, i_q]
1303-
1304-
for i_error in range(n_error_dims):
1305-
self._IK_err_pose_best[i_error, i_b] = 1e4
1306-
1307-
solved = False
1308-
for i_sample in range(max_samples):
1309-
for _ in range(max_solver_iters):
1310-
# run FK to update link states using current q
1311-
self._solver._func_forward_kinematics_entity(
1312-
self._idx_in_solver,
1313-
i_b,
1314-
self._solver.links_state,
1315-
self._solver.links_info,
1316-
self._solver.joints_state,
1317-
self._solver.joints_info,
1318-
self._solver.dofs_state,
1319-
self._solver.dofs_info,
1320-
self._solver.entities_info,
1321-
self._solver._rigid_global_info,
1322-
self._solver._static_rigid_sim_config,
1323-
)
1324-
# compute error
1325-
solved = True
1326-
for i_ee in range(n_links):
1327-
i_l_ee = links_idx[i_ee]
1328-
1329-
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
1330-
err_pos_i = tgt_pos_i - self._solver.links_state.pos[i_l_ee, i_b]
1331-
for k in range(3):
1332-
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
1333-
if err_pos_i.norm() > pos_tol:
1334-
solved = False
1335-
1336-
tgt_quat_i = ti.Vector(
1337-
[quats[i_ee, i_b, 0], quats[i_ee, i_b, 1], quats[i_ee, i_b, 2], quats[i_ee, i_b, 3]]
1338-
)
1339-
err_rot_i = gu.ti_quat_to_rotvec(
1340-
gu.ti_transform_quat_by_quat(
1341-
gu.ti_inv_quat(self._solver.links_state.quat[i_l_ee, i_b]), tgt_quat_i
1342-
)
1343-
)
1344-
for k in range(3):
1345-
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
1346-
if err_rot_i.norm() > rot_tol:
1347-
solved = False
1348-
1349-
# put into multi-link error array
1350-
for k in range(3):
1351-
self._IK_err_pose[i_ee * 6 + k, i_b] = err_pos_i[k]
1352-
self._IK_err_pose[i_ee * 6 + k + 3, i_b] = err_rot_i[k]
1353-
1354-
if solved:
1355-
break
1356-
1357-
# compute multi-link jacobian
1358-
for i_ee in range(n_links):
1359-
# update jacobian for ee link
1360-
i_l_ee = links_idx[i_ee]
1361-
self._func_get_jacobian(
1362-
i_l_ee, i_b, ti.Vector.zero(gs.ti_float, 3), pos_mask, rot_mask
1363-
) # NOTE: we still compute jacobian for all dofs as we haven't found a clean way to implement this
1364-
1365-
# copy to multi-link jacobian (only for the effective n_dofs instead of self.n_dofs)
1366-
for i_dof in range(n_dofs):
1367-
for i_error in ti.static(range(6)):
1368-
i_row = i_ee * 6 + i_error
1369-
i_dof_ = dofs_idx[i_dof]
1370-
self._IK_jacobian[i_row, i_dof, i_b] = self._jacobian[i_error, i_dof_, i_b]
1371-
1372-
# compute dq = jac.T @ inverse(jac @ jac.T + diag) @ error (only for the effective n_dofs instead of self.n_dofs)
1373-
lu.mat_transpose(self._IK_jacobian, self._IK_jacobian_T, n_error_dims, n_dofs, i_b)
1374-
lu.mat_mul(
1375-
self._IK_jacobian,
1376-
self._IK_jacobian_T,
1377-
self._IK_mat,
1378-
n_error_dims,
1379-
n_dofs,
1380-
n_error_dims,
1381-
i_b,
1382-
)
1383-
lu.mat_add_eye(self._IK_mat, damping**2, n_error_dims, i_b)
1384-
lu.mat_inverse(self._IK_mat, self._IK_L, self._IK_U, self._IK_y, self._IK_inv, n_error_dims, i_b)
1385-
lu.mat_mul_vec(self._IK_inv, self._IK_err_pose, self._IK_vec, n_error_dims, n_error_dims, i_b)
1386-
1387-
for i in range(self.n_dofs): # IK_delta_qpos = IK_jacobian_T @ IK_vec
1388-
self._IK_delta_qpos[i, i_b] = 0
1389-
for i in range(n_dofs):
1390-
for j in range(n_error_dims):
1391-
i_ = dofs_idx[
1392-
i
1393-
] # NOTE: IK_delta_qpos uses the original indexing instead of the effective n_dofs
1394-
self._IK_delta_qpos[i_, i_b] += self._IK_jacobian_T[i, j, i_b] * self._IK_vec[j, i_b]
1395-
1396-
for i in range(self.n_dofs):
1397-
self._IK_delta_qpos[i, i_b] = ti.math.clamp(
1398-
self._IK_delta_qpos[i, i_b], -max_step_size, max_step_size
1399-
)
1400-
1401-
# update q
1402-
self._solver._func_integrate_dq_entity(
1403-
self._IK_delta_qpos, self._idx_in_solver, i_b, respect_joint_limit
1404-
)
1405-
1406-
if not solved:
1407-
# re-compute final error if exited not due to solved
1408-
self._solver._func_forward_kinematics_entity(
1409-
self._idx_in_solver,
1410-
i_b,
1411-
self._solver.links_state,
1412-
self._solver.links_info,
1413-
self._solver.joints_state,
1414-
self._solver.joints_info,
1415-
self._solver.dofs_state,
1416-
self._solver.dofs_info,
1417-
self._solver.entities_info,
1418-
self._solver._rigid_global_info,
1419-
self._solver._static_rigid_sim_config,
1420-
)
1421-
solved = True
1422-
for i_ee in range(n_links):
1423-
i_l_ee = links_idx[i_ee]
1424-
1425-
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
1426-
err_pos_i = tgt_pos_i - self._solver.links_state.pos[i_l_ee, i_b]
1427-
for k in range(3):
1428-
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
1429-
if err_pos_i.norm() > pos_tol:
1430-
solved = False
1431-
1432-
tgt_quat_i = ti.Vector(
1433-
[quats[i_ee, i_b, 0], quats[i_ee, i_b, 1], quats[i_ee, i_b, 2], quats[i_ee, i_b, 3]]
1434-
)
1435-
err_rot_i = gu.ti_quat_to_rotvec(
1436-
gu.ti_transform_quat_by_quat(
1437-
gu.ti_inv_quat(self._solver.links_state.quat[i_l_ee, i_b]), tgt_quat_i
1438-
)
1439-
)
1440-
for k in range(3):
1441-
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
1442-
if err_rot_i.norm() > rot_tol:
1443-
solved = False
1444-
1445-
# put into multi-link error array
1446-
for k in range(3):
1447-
self._IK_err_pose[i_ee * 6 + k, i_b] = err_pos_i[k]
1448-
self._IK_err_pose[i_ee * 6 + k + 3, i_b] = err_rot_i[k]
1449-
1450-
if solved:
1451-
for i_q in range(self.n_qs):
1452-
self._IK_qpos_best[i_q, i_b] = self._solver.qpos[i_q + self._q_start, i_b]
1453-
for i_error in range(n_error_dims):
1454-
self._IK_err_pose_best[i_error, i_b] = self._IK_err_pose[i_error, i_b]
1455-
break
1456-
1457-
else:
1458-
# copy to _IK_qpos if this sample is better
1459-
improved = True
1460-
for i_ee in range(n_links):
1461-
error_pos_i = ti.Vector([self._IK_err_pose[i_ee * 6 + i_error, i_b] for i_error in range(3)])
1462-
error_rot_i = ti.Vector([self._IK_err_pose[i_ee * 6 + i_error, i_b] for i_error in range(3, 6)])
1463-
error_pos_best = ti.Vector(
1464-
[self._IK_err_pose_best[i_ee * 6 + i_error, i_b] for i_error in range(3)]
1465-
)
1466-
error_rot_best = ti.Vector(
1467-
[self._IK_err_pose_best[i_ee * 6 + i_error, i_b] for i_error in range(3, 6)]
1468-
)
1469-
if error_pos_i.norm() > error_pos_best.norm() or error_rot_i.norm() > error_rot_best.norm():
1470-
improved = False
1471-
break
1472-
1473-
if improved:
1474-
for i_q in range(self.n_qs):
1475-
self._IK_qpos_best[i_q, i_b] = self._solver.qpos[i_q + self._q_start, i_b]
1476-
for i_error in range(n_error_dims):
1477-
self._IK_err_pose_best[i_error, i_b] = self._IK_err_pose[i_error, i_b]
1478-
1479-
# Resample init q
1480-
if respect_joint_limit and i_sample < max_samples - 1:
1481-
for _i_l in range(n_links_by_dofs):
1482-
i_l = links_idx_by_dofs[_i_l]
1483-
I_l = [i_l, i_b] if ti.static(self.solver._options.batch_links_info) else i_l
1484-
1485-
for i_j in range(
1486-
self._solver.links_info.joint_start[I_l], self._solver.links_info.joint_end[I_l]
1487-
):
1488-
I_j = [i_j, i_b] if ti.static(self.solver._options.batch_joints_info) else i_j
1489-
1490-
I_dof_start = (
1491-
[self._solver.joints_info.dof_start[I_j], i_b]
1492-
if ti.static(self.solver._options.batch_dofs_info)
1493-
else self._solver.joints_info.dof_start[I_j]
1494-
)
1495-
q_start = self._solver.joints_info.q_start[I_j]
1496-
dof_limit = self._solver.dofs_info.limit[I_dof_start]
1497-
1498-
if self._solver.joints_info.type[I_j] == gs.JOINT_TYPE.FREE:
1499-
pass
1500-
1501-
elif (
1502-
self._solver.joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE
1503-
or self._solver.joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC
1504-
):
1505-
if ti.math.isinf(dof_limit[0]) or ti.math.isinf(dof_limit[1]):
1506-
pass
1507-
else:
1508-
self._solver.qpos[q_start, i_b] = dof_limit[0] + ti.random() * (
1509-
dof_limit[1] - dof_limit[0]
1510-
)
1511-
else:
1512-
pass # When respect_joint_limit=False, we can simply continue from the last solution
1513-
1514-
# restore original qpos and link state
1515-
for i_q in range(self.n_qs):
1516-
self._solver.qpos[i_q + self._q_start, i_b] = self._IK_qpos_orig[i_q, i_b]
1517-
self._solver._func_forward_kinematics_entity(
1518-
self._idx_in_solver,
1519-
i_b,
1520-
self._solver.links_state,
1521-
self._solver.links_info,
1522-
self._solver.joints_state,
1523-
self._solver.joints_info,
1524-
self._solver.dofs_state,
1525-
self._solver.dofs_info,
1526-
self._solver.entities_info,
1527-
self._solver._rigid_global_info,
1528-
self._solver._static_rigid_sim_config,
1529-
)
1530-
15311266
@gs.assert_built
15321267
def forward_kinematics(self, qpos, qs_idx_local=None, links_idx_local=None, envs_idx=None):
15331268
"""

genesis/engine/entities/rigid_entity/rigid_geom.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def get_verts(self):
388388
"""
389389
Get the vertices of the geom in world frame.
390390
"""
391+
self._solver.update_verts_for_geom(self._idx)
391392
if self.is_free:
392393
tensor = torch.empty(
393394
self._solver._batch_shape((self.n_verts, 3), True), dtype=gs.tc_float, device=gs.device
@@ -402,17 +403,12 @@ def get_verts(self):
402403

403404
@ti.kernel
404405
def _kernel_get_free_verts(self, tensor: ti.types.ndarray()):
405-
for i_b in range(self._solver._B):
406-
self._solver._func_update_verts_for_geom(self._idx, i_b)
407-
408406
for i_v, j, i_b in ti.ndrange(self.n_verts, 3, self._solver._B):
409407
idx_vert = i_v + self._verts_state_start
410408
tensor[i_b, i_v, j] = self._solver.free_verts_state.pos[idx_vert, i_b][j]
411409

412410
@ti.kernel
413411
def _kernel_get_fixed_verts(self, tensor: ti.types.ndarray()):
414-
self._solver._func_update_verts_for_geom(self._idx, 0)
415-
416412
for i_v, j in ti.ndrange(self.n_verts, 3):
417413
idx_vert = i_v + self._verts_state_start
418414
tensor[i_v, j] = self._solver.fixed_verts_state.pos[idx_vert][j]

genesis/engine/simulator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,7 @@ def step(self, in_backward=False):
273273
self.save_ckpt()
274274

275275
if self.rigid_solver.is_active():
276-
self.rigid_solver._kernel_clear_external_force(
277-
links_state=self.rigid_solver.links_state,
278-
rigid_global_info=self.rigid_solver._rigid_global_info,
279-
static_rigid_sim_config=self.rigid_solver._static_rigid_sim_config,
280-
)
276+
self.rigid_solver.clear_external_force()
281277

282278
def _step_grad(self):
283279
for _ in range(self._substeps - 1, -1, -1):

0 commit comments

Comments
 (0)