1212from genesis .options .morphs import Morph
1313from genesis .options .surfaces import Surface
1414from genesis .utils import array_class
15+ from genesis .utils import linalg as lu
1516from genesis .utils import geom as gu
1617from genesis .utils import mesh as mu
1718from genesis .utils import mjcf as mju
@@ -1312,7 +1313,8 @@ def inverse_kinematics_multilink(
13121313 links_idx_by_dofs = self ._get_idx (links_idx_by_dofs , self .n_links , self ._link_start , unsafe = False )
13131314 n_links_by_dofs = len (links_idx_by_dofs )
13141315
1315- self ._solver .rigid_entity_inverse_kinematics (
1316+ kernel_rigid_entity_inverse_kinematics (
1317+ self ,
13161318 links_idx ,
13171319 poss ,
13181320 quats ,
@@ -1335,7 +1337,15 @@ def inverse_kinematics_multilink(
13351337 max_step_size ,
13361338 respect_joint_limit ,
13371339 envs_idx ,
1338- self ,
1340+ self ._solver .links_state ,
1341+ self ._solver .links_info ,
1342+ self ._solver .joints_state ,
1343+ self ._solver .joints_info ,
1344+ self ._solver .dofs_state ,
1345+ self ._solver .dofs_info ,
1346+ self ._solver .entities_info ,
1347+ self ._solver ._rigid_global_info ,
1348+ self ._solver ._static_rigid_sim_config ,
13391349 )
13401350
13411351 qpos = self ._IK_qpos_best .to_torch (gs .device ).transpose (1 , 0 )
@@ -3075,3 +3085,316 @@ def _kernel_get_fixed_verts(
30753085 for i_v_ , i , i_b in ti .ndrange (n_verts , 3 , _B ):
30763086 i_v = i_v_ + verts_state_start
30773087 tensor [i_b , fixed_verts_idx_local [i_v_ ], i ] = fixed_verts_state .pos [i_v ][i ]
3088+
3089+
3090+ # FIXME: RigidEntity is not compatible with fast cache
3091+ @ti .kernel (fastcache = False )
3092+ def kernel_rigid_entity_inverse_kinematics (
3093+ rigid_entity : ti .template (),
3094+ links_idx : ti .types .ndarray (),
3095+ poss : ti .types .ndarray (),
3096+ quats : ti .types .ndarray (),
3097+ n_links : ti .i32 ,
3098+ dofs_idx : ti .types .ndarray (),
3099+ n_dofs : ti .i32 ,
3100+ links_idx_by_dofs : ti .types .ndarray (),
3101+ n_links_by_dofs : ti .i32 ,
3102+ custom_init_qpos : ti .i32 ,
3103+ init_qpos : ti .types .ndarray (),
3104+ max_samples : ti .i32 ,
3105+ max_solver_iters : ti .i32 ,
3106+ damping : ti .f32 ,
3107+ pos_tol : ti .f32 ,
3108+ rot_tol : ti .f32 ,
3109+ pos_mask_ : ti .types .ndarray (),
3110+ rot_mask_ : ti .types .ndarray (),
3111+ link_pos_mask : ti .types .ndarray (),
3112+ link_rot_mask : ti .types .ndarray (),
3113+ max_step_size : ti .f32 ,
3114+ respect_joint_limit : ti .i32 ,
3115+ envs_idx : ti .types .ndarray (),
3116+ links_state : array_class .LinksState ,
3117+ links_info : array_class .LinksInfo ,
3118+ joints_state : array_class .JointsState ,
3119+ joints_info : array_class .JointsInfo ,
3120+ dofs_state : array_class .DofsState ,
3121+ dofs_info : array_class .DofsInfo ,
3122+ entities_info : array_class .EntitiesInfo ,
3123+ rigid_global_info : array_class .RigidGlobalInfo ,
3124+ static_rigid_sim_config : ti .template (),
3125+ ):
3126+ EPS = rigid_global_info .EPS [None ]
3127+
3128+ # convert to ti Vector
3129+ pos_mask = ti .Vector ([pos_mask_ [0 ], pos_mask_ [1 ], pos_mask_ [2 ]], dt = gs .ti_float )
3130+ rot_mask = ti .Vector ([rot_mask_ [0 ], rot_mask_ [1 ], rot_mask_ [2 ]], dt = gs .ti_float )
3131+ n_error_dims = 6 * n_links
3132+
3133+ for i_b_ in range (envs_idx .shape [0 ]):
3134+ i_b = envs_idx [i_b_ ]
3135+
3136+ # save original qpos
3137+ for i_q in range (rigid_entity .n_qs ):
3138+ rigid_entity ._IK_qpos_orig [i_q , i_b ] = rigid_global_info .qpos [i_q + rigid_entity ._q_start , i_b ]
3139+
3140+ if custom_init_qpos :
3141+ for i_q in range (rigid_entity .n_qs ):
3142+ rigid_global_info .qpos [i_q + rigid_entity ._q_start , i_b ] = init_qpos [i_b_ , i_q ]
3143+
3144+ for i_error in range (n_error_dims ):
3145+ rigid_entity ._IK_err_pose_best [i_error , i_b ] = 1e4
3146+
3147+ solved = False
3148+ for i_sample in range (max_samples ):
3149+ for _ in range (max_solver_iters ):
3150+ # run FK to update link states using current q
3151+ gs .engine .solvers .rigid .rigid_solver_decomp .func_forward_kinematics_entity (
3152+ rigid_entity ._idx_in_solver ,
3153+ i_b ,
3154+ links_state ,
3155+ links_info ,
3156+ joints_state ,
3157+ joints_info ,
3158+ dofs_state ,
3159+ dofs_info ,
3160+ entities_info ,
3161+ rigid_global_info ,
3162+ static_rigid_sim_config ,
3163+ )
3164+ # compute error
3165+ solved = True
3166+ for i_ee in range (n_links ):
3167+ i_l_ee = links_idx [i_ee ]
3168+
3169+ tgt_pos_i = ti .Vector ([poss [i_ee , i_b_ , 0 ], poss [i_ee , i_b_ , 1 ], poss [i_ee , i_b_ , 2 ]])
3170+ err_pos_i = tgt_pos_i - links_state .pos [i_l_ee , i_b ]
3171+ for k in range (3 ):
3172+ err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [i_ee ]
3173+ if err_pos_i .norm () > pos_tol :
3174+ solved = False
3175+
3176+ tgt_quat_i = ti .Vector (
3177+ [quats [i_ee , i_b_ , 0 ], quats [i_ee , i_b_ , 1 ], quats [i_ee , i_b_ , 2 ], quats [i_ee , i_b_ , 3 ]]
3178+ )
3179+ err_rot_i = gu .ti_quat_to_rotvec (
3180+ gu .ti_transform_quat_by_quat (gu .ti_inv_quat (links_state .quat [i_l_ee , i_b ]), tgt_quat_i ), EPS
3181+ )
3182+ for k in range (3 ):
3183+ err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [i_ee ]
3184+ if err_rot_i .norm () > rot_tol :
3185+ solved = False
3186+
3187+ # put into multi-link error array
3188+ for k in range (3 ):
3189+ rigid_entity ._IK_err_pose [i_ee * 6 + k , i_b ] = err_pos_i [k ]
3190+ rigid_entity ._IK_err_pose [i_ee * 6 + k + 3 , i_b ] = err_rot_i [k ]
3191+
3192+ if solved :
3193+ break
3194+
3195+ # compute multi-link jacobian
3196+ for i_ee in range (n_links ):
3197+ # update jacobian for ee link
3198+ i_l_ee = links_idx [i_ee ]
3199+ rigid_entity ._func_get_jacobian (
3200+ tgt_link_idx = i_l_ee ,
3201+ i_b = i_b ,
3202+ p_local = ti .Vector .zero (gs .ti_float , 3 ),
3203+ pos_mask = pos_mask ,
3204+ rot_mask = rot_mask ,
3205+ dofs_info = dofs_info ,
3206+ joints_info = joints_info ,
3207+ links_info = links_info ,
3208+ links_state = links_state ,
3209+ ) # NOTE: we still compute jacobian for all dofs as we haven't found a clean way to implement this
3210+
3211+ # copy to multi-link jacobian (only for the effective n_dofs instead of self.n_dofs)
3212+ for i_dof in range (n_dofs ):
3213+ for i_error in ti .static (range (6 )):
3214+ i_row = i_ee * 6 + i_error
3215+ i_dof_ = dofs_idx [i_dof ]
3216+ rigid_entity ._IK_jacobian [i_row , i_dof , i_b ] = rigid_entity ._jacobian [i_error , i_dof_ , i_b ]
3217+
3218+ # compute dq = jac.T @ inverse(jac @ jac.T + diag) @ error (only for the effective n_dofs instead of self.n_dofs)
3219+ lu .mat_transpose (rigid_entity ._IK_jacobian , rigid_entity ._IK_jacobian_T , n_error_dims , n_dofs , i_b )
3220+ lu .mat_mul (
3221+ rigid_entity ._IK_jacobian ,
3222+ rigid_entity ._IK_jacobian_T ,
3223+ rigid_entity ._IK_mat ,
3224+ n_error_dims ,
3225+ n_dofs ,
3226+ n_error_dims ,
3227+ i_b ,
3228+ )
3229+ lu .mat_add_eye (rigid_entity ._IK_mat , damping ** 2 , n_error_dims , i_b )
3230+ lu .mat_inverse (
3231+ rigid_entity ._IK_mat ,
3232+ rigid_entity ._IK_L ,
3233+ rigid_entity ._IK_U ,
3234+ rigid_entity ._IK_y ,
3235+ rigid_entity ._IK_inv ,
3236+ n_error_dims ,
3237+ i_b ,
3238+ )
3239+ lu .mat_mul_vec (
3240+ rigid_entity ._IK_inv ,
3241+ rigid_entity ._IK_err_pose ,
3242+ rigid_entity ._IK_vec ,
3243+ n_error_dims ,
3244+ n_error_dims ,
3245+ i_b ,
3246+ )
3247+
3248+ for i_d in range (rigid_entity .n_dofs ): # IK_delta_qpos = IK_jacobian_T @ IK_vec
3249+ rigid_entity ._IK_delta_qpos [i_d , i_b ] = 0
3250+ for i_d in range (n_dofs ):
3251+ for j in range (n_error_dims ):
3252+ # NOTE: IK_delta_qpos uses the original indexing instead of the effective n_dofs
3253+ i_d_ = dofs_idx [i_d ]
3254+ rigid_entity ._IK_delta_qpos [i_d_ , i_b ] += (
3255+ rigid_entity ._IK_jacobian_T [i_d , j , i_b ] * rigid_entity ._IK_vec [j , i_b ]
3256+ )
3257+
3258+ for i_d in range (rigid_entity .n_dofs ):
3259+ rigid_entity ._IK_delta_qpos [i_d , i_b ] = ti .math .clamp (
3260+ rigid_entity ._IK_delta_qpos [i_d , i_b ], - max_step_size , max_step_size
3261+ )
3262+
3263+ # update q
3264+ gs .engine .solvers .rigid .rigid_solver_decomp .func_integrate_dq_entity (
3265+ rigid_entity ._IK_delta_qpos ,
3266+ rigid_entity ._idx_in_solver ,
3267+ i_b ,
3268+ respect_joint_limit ,
3269+ links_info ,
3270+ joints_info ,
3271+ dofs_info ,
3272+ entities_info ,
3273+ rigid_global_info ,
3274+ static_rigid_sim_config ,
3275+ )
3276+
3277+ if not solved :
3278+ # re-compute final error if exited not due to solved
3279+ gs .engine .solvers .rigid .rigid_solver_decomp .func_forward_kinematics_entity (
3280+ rigid_entity ._idx_in_solver ,
3281+ i_b ,
3282+ links_state ,
3283+ links_info ,
3284+ joints_state ,
3285+ joints_info ,
3286+ dofs_state ,
3287+ dofs_info ,
3288+ entities_info ,
3289+ rigid_global_info ,
3290+ static_rigid_sim_config ,
3291+ )
3292+ solved = True
3293+ for i_ee in range (n_links ):
3294+ i_l_ee = links_idx [i_ee ]
3295+
3296+ tgt_pos_i = ti .Vector ([poss [i_ee , i_b_ , 0 ], poss [i_ee , i_b_ , 1 ], poss [i_ee , i_b_ , 2 ]])
3297+ err_pos_i = tgt_pos_i - links_state .pos [i_l_ee , i_b ]
3298+ for k in range (3 ):
3299+ err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [i_ee ]
3300+ if err_pos_i .norm () > pos_tol :
3301+ solved = False
3302+
3303+ tgt_quat_i = ti .Vector (
3304+ [quats [i_ee , i_b_ , 0 ], quats [i_ee , i_b_ , 1 ], quats [i_ee , i_b_ , 2 ], quats [i_ee , i_b_ , 3 ]]
3305+ )
3306+ err_rot_i = gu .ti_quat_to_rotvec (
3307+ gu .ti_transform_quat_by_quat (gu .ti_inv_quat (links_state .quat [i_l_ee , i_b ]), tgt_quat_i ), EPS
3308+ )
3309+ for k in range (3 ):
3310+ err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [i_ee ]
3311+ if err_rot_i .norm () > rot_tol :
3312+ solved = False
3313+
3314+ # put into multi-link error array
3315+ for k in range (3 ):
3316+ rigid_entity ._IK_err_pose [i_ee * 6 + k , i_b ] = err_pos_i [k ]
3317+ rigid_entity ._IK_err_pose [i_ee * 6 + k + 3 , i_b ] = err_rot_i [k ]
3318+
3319+ if solved :
3320+ for i_q in range (rigid_entity .n_qs ):
3321+ rigid_entity ._IK_qpos_best [i_q , i_b ] = rigid_global_info .qpos [i_q + rigid_entity ._q_start , i_b ]
3322+ for i_error in range (n_error_dims ):
3323+ rigid_entity ._IK_err_pose_best [i_error , i_b ] = rigid_entity ._IK_err_pose [i_error , i_b ]
3324+ break
3325+
3326+ else :
3327+ # copy to _IK_qpos if this sample is better
3328+ improved = True
3329+ for i_ee in range (n_links ):
3330+ error_pos_i = ti .Vector (
3331+ [rigid_entity ._IK_err_pose [i_ee * 6 + i_error , i_b ] for i_error in range (3 )]
3332+ )
3333+ error_rot_i = ti .Vector (
3334+ [rigid_entity ._IK_err_pose [i_ee * 6 + i_error , i_b ] for i_error in range (3 , 6 )]
3335+ )
3336+ error_pos_best = ti .Vector (
3337+ [rigid_entity ._IK_err_pose_best [i_ee * 6 + i_error , i_b ] for i_error in range (3 )]
3338+ )
3339+ error_rot_best = ti .Vector (
3340+ [rigid_entity ._IK_err_pose_best [i_ee * 6 + i_error , i_b ] for i_error in range (3 , 6 )]
3341+ )
3342+ if error_pos_i .norm () > error_pos_best .norm () or error_rot_i .norm () > error_rot_best .norm ():
3343+ improved = False
3344+ break
3345+
3346+ if improved :
3347+ for i_q in range (rigid_entity .n_qs ):
3348+ rigid_entity ._IK_qpos_best [i_q , i_b ] = rigid_global_info .qpos [i_q + rigid_entity ._q_start , i_b ]
3349+ for i_error in range (n_error_dims ):
3350+ rigid_entity ._IK_err_pose_best [i_error , i_b ] = rigid_entity ._IK_err_pose [i_error , i_b ]
3351+
3352+ # Resample init q
3353+ if respect_joint_limit and i_sample < max_samples - 1 :
3354+ for _i_l in range (n_links_by_dofs ):
3355+ i_l = links_idx_by_dofs [_i_l ]
3356+ I_l = [i_l , i_b ] if ti .static (static_rigid_sim_config .batch_links_info ) else i_l
3357+
3358+ for i_j in range (links_info .joint_start [I_l ], links_info .joint_end [I_l ]):
3359+ I_j = [i_j , i_b ] if ti .static (static_rigid_sim_config .batch_joints_info ) else i_j
3360+
3361+ I_dof_start = (
3362+ [joints_info .dof_start [I_j ], i_b ]
3363+ if ti .static (static_rigid_sim_config .batch_dofs_info )
3364+ else joints_info .dof_start [I_j ]
3365+ )
3366+ q_start = joints_info .q_start [I_j ]
3367+ dof_limit = dofs_info .limit [I_dof_start ]
3368+
3369+ if joints_info .type [I_j ] == gs .JOINT_TYPE .FREE :
3370+ pass
3371+
3372+ elif (
3373+ joints_info .type [I_j ] == gs .JOINT_TYPE .REVOLUTE
3374+ or joints_info .type [I_j ] == gs .JOINT_TYPE .PRISMATIC
3375+ ):
3376+ if ti .math .isinf (dof_limit [0 ]) or ti .math .isinf (dof_limit [1 ]):
3377+ pass
3378+ else :
3379+ rigid_global_info .qpos [q_start , i_b ] = dof_limit [0 ] + ti .random () * (
3380+ dof_limit [1 ] - dof_limit [0 ]
3381+ )
3382+ else :
3383+ pass # When respect_joint_limit=False, we can simply continue from the last solution
3384+
3385+ # restore original qpos and link state
3386+ for i_q in range (rigid_entity .n_qs ):
3387+ rigid_global_info .qpos [i_q + rigid_entity ._q_start , i_b ] = rigid_entity ._IK_qpos_orig [i_q , i_b ]
3388+ gs .engine .solvers .rigid .rigid_solver_decomp .func_forward_kinematics_entity (
3389+ rigid_entity ._idx_in_solver ,
3390+ i_b ,
3391+ links_state ,
3392+ links_info ,
3393+ joints_state ,
3394+ joints_info ,
3395+ dofs_state ,
3396+ dofs_info ,
3397+ entities_info ,
3398+ rigid_global_info ,
3399+ static_rigid_sim_config ,
3400+ )
0 commit comments