@@ -1222,7 +1222,7 @@ def inverse_kinematics_multilink(
12221222 if envs_idx is None :
12231223 envs_idx = torch .zeros (1 , dtype = gs .tc_int , device = gs .device )
12241224
1225- self ._kernel_inverse_kinematics (
1225+ self ._solver . rigid_entity_inverse_kinematics (
12261226 links_idx ,
12271227 poss ,
12281228 quats ,
@@ -1245,7 +1245,9 @@ def inverse_kinematics_multilink(
12451245 max_step_size ,
12461246 respect_joint_limit ,
12471247 envs_idx ,
1248+ self ,
12481249 )
1250+
12491251 qpos = self ._IK_qpos_best .to_torch (gs .device ).transpose (1 , 0 )
12501252 if self ._solver .n_envs > 0 :
12511253 qpos = qpos [envs_idx ]
@@ -1261,273 +1263,6 @@ def inverse_kinematics_multilink(
12611263 return qpos , error_pose
12621264 return qpos
12631265
1264- @ti .kernel
1265- def _kernel_inverse_kinematics (
1266- self ,
1267- links_idx : ti .types .ndarray (),
1268- poss : ti .types .ndarray (),
1269- quats : ti .types .ndarray (),
1270- n_links : ti .i32 ,
1271- dofs_idx : ti .types .ndarray (),
1272- n_dofs : ti .i32 ,
1273- links_idx_by_dofs : ti .types .ndarray (),
1274- n_links_by_dofs : ti .i32 ,
1275- custom_init_qpos : ti .i32 ,
1276- init_qpos : ti .types .ndarray (),
1277- max_samples : ti .i32 ,
1278- max_solver_iters : ti .i32 ,
1279- damping : ti .f32 ,
1280- pos_tol : ti .f32 ,
1281- rot_tol : ti .f32 ,
1282- pos_mask_ : ti .types .ndarray (),
1283- rot_mask_ : ti .types .ndarray (),
1284- link_pos_mask : ti .types .ndarray (),
1285- link_rot_mask : ti .types .ndarray (),
1286- max_step_size : ti .f32 ,
1287- respect_joint_limit : ti .i32 ,
1288- envs_idx : ti .types .ndarray (),
1289- ):
1290- # convert to ti Vector
1291- pos_mask = ti .Vector ([pos_mask_ [0 ], pos_mask_ [1 ], pos_mask_ [2 ]], dt = gs .ti_float )
1292- rot_mask = ti .Vector ([rot_mask_ [0 ], rot_mask_ [1 ], rot_mask_ [2 ]], dt = gs .ti_float )
1293- n_error_dims = 6 * n_links
1294-
1295- for i_b in envs_idx :
1296- # save original qpos
1297- for i_q in range (self .n_qs ):
1298- self ._IK_qpos_orig [i_q , i_b ] = self ._solver .qpos [i_q + self ._q_start , i_b ]
1299-
1300- if custom_init_qpos :
1301- for i_q in range (self .n_qs ):
1302- self ._solver .qpos [i_q + self ._q_start , i_b ] = init_qpos [i_b , i_q ]
1303-
1304- for i_error in range (n_error_dims ):
1305- self ._IK_err_pose_best [i_error , i_b ] = 1e4
1306-
1307- solved = False
1308- for i_sample in range (max_samples ):
1309- for _ in range (max_solver_iters ):
1310- # run FK to update link states using current q
1311- self ._solver ._func_forward_kinematics_entity (
1312- self ._idx_in_solver ,
1313- i_b ,
1314- self ._solver .links_state ,
1315- self ._solver .links_info ,
1316- self ._solver .joints_state ,
1317- self ._solver .joints_info ,
1318- self ._solver .dofs_state ,
1319- self ._solver .dofs_info ,
1320- self ._solver .entities_info ,
1321- self ._solver ._rigid_global_info ,
1322- self ._solver ._static_rigid_sim_config ,
1323- )
1324- # compute error
1325- solved = True
1326- for i_ee in range (n_links ):
1327- i_l_ee = links_idx [i_ee ]
1328-
1329- tgt_pos_i = ti .Vector ([poss [i_ee , i_b , 0 ], poss [i_ee , i_b , 1 ], poss [i_ee , i_b , 2 ]])
1330- err_pos_i = tgt_pos_i - self ._solver .links_state .pos [i_l_ee , i_b ]
1331- for k in range (3 ):
1332- err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [i_ee ]
1333- if err_pos_i .norm () > pos_tol :
1334- solved = False
1335-
1336- tgt_quat_i = ti .Vector (
1337- [quats [i_ee , i_b , 0 ], quats [i_ee , i_b , 1 ], quats [i_ee , i_b , 2 ], quats [i_ee , i_b , 3 ]]
1338- )
1339- err_rot_i = gu .ti_quat_to_rotvec (
1340- gu .ti_transform_quat_by_quat (
1341- gu .ti_inv_quat (self ._solver .links_state .quat [i_l_ee , i_b ]), tgt_quat_i
1342- )
1343- )
1344- for k in range (3 ):
1345- err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [i_ee ]
1346- if err_rot_i .norm () > rot_tol :
1347- solved = False
1348-
1349- # put into multi-link error array
1350- for k in range (3 ):
1351- self ._IK_err_pose [i_ee * 6 + k , i_b ] = err_pos_i [k ]
1352- self ._IK_err_pose [i_ee * 6 + k + 3 , i_b ] = err_rot_i [k ]
1353-
1354- if solved :
1355- break
1356-
1357- # compute multi-link jacobian
1358- for i_ee in range (n_links ):
1359- # update jacobian for ee link
1360- i_l_ee = links_idx [i_ee ]
1361- self ._func_get_jacobian (
1362- i_l_ee , i_b , ti .Vector .zero (gs .ti_float , 3 ), pos_mask , rot_mask
1363- ) # NOTE: we still compute jacobian for all dofs as we haven't found a clean way to implement this
1364-
1365- # copy to multi-link jacobian (only for the effective n_dofs instead of self.n_dofs)
1366- for i_dof in range (n_dofs ):
1367- for i_error in ti .static (range (6 )):
1368- i_row = i_ee * 6 + i_error
1369- i_dof_ = dofs_idx [i_dof ]
1370- self ._IK_jacobian [i_row , i_dof , i_b ] = self ._jacobian [i_error , i_dof_ , i_b ]
1371-
1372- # compute dq = jac.T @ inverse(jac @ jac.T + diag) @ error (only for the effective n_dofs instead of self.n_dofs)
1373- lu .mat_transpose (self ._IK_jacobian , self ._IK_jacobian_T , n_error_dims , n_dofs , i_b )
1374- lu .mat_mul (
1375- self ._IK_jacobian ,
1376- self ._IK_jacobian_T ,
1377- self ._IK_mat ,
1378- n_error_dims ,
1379- n_dofs ,
1380- n_error_dims ,
1381- i_b ,
1382- )
1383- lu .mat_add_eye (self ._IK_mat , damping ** 2 , n_error_dims , i_b )
1384- lu .mat_inverse (self ._IK_mat , self ._IK_L , self ._IK_U , self ._IK_y , self ._IK_inv , n_error_dims , i_b )
1385- lu .mat_mul_vec (self ._IK_inv , self ._IK_err_pose , self ._IK_vec , n_error_dims , n_error_dims , i_b )
1386-
1387- for i in range (self .n_dofs ): # IK_delta_qpos = IK_jacobian_T @ IK_vec
1388- self ._IK_delta_qpos [i , i_b ] = 0
1389- for i in range (n_dofs ):
1390- for j in range (n_error_dims ):
1391- i_ = dofs_idx [
1392- i
1393- ] # NOTE: IK_delta_qpos uses the original indexing instead of the effective n_dofs
1394- self ._IK_delta_qpos [i_ , i_b ] += self ._IK_jacobian_T [i , j , i_b ] * self ._IK_vec [j , i_b ]
1395-
1396- for i in range (self .n_dofs ):
1397- self ._IK_delta_qpos [i , i_b ] = ti .math .clamp (
1398- self ._IK_delta_qpos [i , i_b ], - max_step_size , max_step_size
1399- )
1400-
1401- # update q
1402- self ._solver ._func_integrate_dq_entity (
1403- self ._IK_delta_qpos , self ._idx_in_solver , i_b , respect_joint_limit
1404- )
1405-
1406- if not solved :
1407- # re-compute final error if exited not due to solved
1408- self ._solver ._func_forward_kinematics_entity (
1409- self ._idx_in_solver ,
1410- i_b ,
1411- self ._solver .links_state ,
1412- self ._solver .links_info ,
1413- self ._solver .joints_state ,
1414- self ._solver .joints_info ,
1415- self ._solver .dofs_state ,
1416- self ._solver .dofs_info ,
1417- self ._solver .entities_info ,
1418- self ._solver ._rigid_global_info ,
1419- self ._solver ._static_rigid_sim_config ,
1420- )
1421- solved = True
1422- for i_ee in range (n_links ):
1423- i_l_ee = links_idx [i_ee ]
1424-
1425- tgt_pos_i = ti .Vector ([poss [i_ee , i_b , 0 ], poss [i_ee , i_b , 1 ], poss [i_ee , i_b , 2 ]])
1426- err_pos_i = tgt_pos_i - self ._solver .links_state .pos [i_l_ee , i_b ]
1427- for k in range (3 ):
1428- err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [i_ee ]
1429- if err_pos_i .norm () > pos_tol :
1430- solved = False
1431-
1432- tgt_quat_i = ti .Vector (
1433- [quats [i_ee , i_b , 0 ], quats [i_ee , i_b , 1 ], quats [i_ee , i_b , 2 ], quats [i_ee , i_b , 3 ]]
1434- )
1435- err_rot_i = gu .ti_quat_to_rotvec (
1436- gu .ti_transform_quat_by_quat (
1437- gu .ti_inv_quat (self ._solver .links_state .quat [i_l_ee , i_b ]), tgt_quat_i
1438- )
1439- )
1440- for k in range (3 ):
1441- err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [i_ee ]
1442- if err_rot_i .norm () > rot_tol :
1443- solved = False
1444-
1445- # put into multi-link error array
1446- for k in range (3 ):
1447- self ._IK_err_pose [i_ee * 6 + k , i_b ] = err_pos_i [k ]
1448- self ._IK_err_pose [i_ee * 6 + k + 3 , i_b ] = err_rot_i [k ]
1449-
1450- if solved :
1451- for i_q in range (self .n_qs ):
1452- self ._IK_qpos_best [i_q , i_b ] = self ._solver .qpos [i_q + self ._q_start , i_b ]
1453- for i_error in range (n_error_dims ):
1454- self ._IK_err_pose_best [i_error , i_b ] = self ._IK_err_pose [i_error , i_b ]
1455- break
1456-
1457- else :
1458- # copy to _IK_qpos if this sample is better
1459- improved = True
1460- for i_ee in range (n_links ):
1461- error_pos_i = ti .Vector ([self ._IK_err_pose [i_ee * 6 + i_error , i_b ] for i_error in range (3 )])
1462- error_rot_i = ti .Vector ([self ._IK_err_pose [i_ee * 6 + i_error , i_b ] for i_error in range (3 , 6 )])
1463- error_pos_best = ti .Vector (
1464- [self ._IK_err_pose_best [i_ee * 6 + i_error , i_b ] for i_error in range (3 )]
1465- )
1466- error_rot_best = ti .Vector (
1467- [self ._IK_err_pose_best [i_ee * 6 + i_error , i_b ] for i_error in range (3 , 6 )]
1468- )
1469- if error_pos_i .norm () > error_pos_best .norm () or error_rot_i .norm () > error_rot_best .norm ():
1470- improved = False
1471- break
1472-
1473- if improved :
1474- for i_q in range (self .n_qs ):
1475- self ._IK_qpos_best [i_q , i_b ] = self ._solver .qpos [i_q + self ._q_start , i_b ]
1476- for i_error in range (n_error_dims ):
1477- self ._IK_err_pose_best [i_error , i_b ] = self ._IK_err_pose [i_error , i_b ]
1478-
1479- # Resample init q
1480- if respect_joint_limit and i_sample < max_samples - 1 :
1481- for _i_l in range (n_links_by_dofs ):
1482- i_l = links_idx_by_dofs [_i_l ]
1483- I_l = [i_l , i_b ] if ti .static (self .solver ._options .batch_links_info ) else i_l
1484-
1485- for i_j in range (
1486- self ._solver .links_info .joint_start [I_l ], self ._solver .links_info .joint_end [I_l ]
1487- ):
1488- I_j = [i_j , i_b ] if ti .static (self .solver ._options .batch_joints_info ) else i_j
1489-
1490- I_dof_start = (
1491- [self ._solver .joints_info .dof_start [I_j ], i_b ]
1492- if ti .static (self .solver ._options .batch_dofs_info )
1493- else self ._solver .joints_info .dof_start [I_j ]
1494- )
1495- q_start = self ._solver .joints_info .q_start [I_j ]
1496- dof_limit = self ._solver .dofs_info .limit [I_dof_start ]
1497-
1498- if self ._solver .joints_info .type [I_j ] == gs .JOINT_TYPE .FREE :
1499- pass
1500-
1501- elif (
1502- self ._solver .joints_info .type [I_j ] == gs .JOINT_TYPE .REVOLUTE
1503- or self ._solver .joints_info .type [I_j ] == gs .JOINT_TYPE .PRISMATIC
1504- ):
1505- if ti .math .isinf (dof_limit [0 ]) or ti .math .isinf (dof_limit [1 ]):
1506- pass
1507- else :
1508- self ._solver .qpos [q_start , i_b ] = dof_limit [0 ] + ti .random () * (
1509- dof_limit [1 ] - dof_limit [0 ]
1510- )
1511- else :
1512- pass # When respect_joint_limit=False, we can simply continue from the last solution
1513-
1514- # restore original qpos and link state
1515- for i_q in range (self .n_qs ):
1516- self ._solver .qpos [i_q + self ._q_start , i_b ] = self ._IK_qpos_orig [i_q , i_b ]
1517- self ._solver ._func_forward_kinematics_entity (
1518- self ._idx_in_solver ,
1519- i_b ,
1520- self ._solver .links_state ,
1521- self ._solver .links_info ,
1522- self ._solver .joints_state ,
1523- self ._solver .joints_info ,
1524- self ._solver .dofs_state ,
1525- self ._solver .dofs_info ,
1526- self ._solver .entities_info ,
1527- self ._solver ._rigid_global_info ,
1528- self ._solver ._static_rigid_sim_config ,
1529- )
1530-
15311266 @gs .assert_built
15321267 def forward_kinematics (self , qpos , qs_idx_local = None , links_idx_local = None , envs_idx = None ):
15331268 """
0 commit comments