Skip to content

Commit 2b9865a

Browse files
committed
Move inverse kinematic kernel implementation back in RigidEntity.
1 parent df558b5 commit 2b9865a

File tree

4 files changed

+326
-382
lines changed

4 files changed

+326
-382
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 325 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from genesis.options.morphs import Morph
1313
from genesis.options.surfaces import Surface
1414
from genesis.utils import array_class
15+
from genesis.utils import linalg as lu
1516
from genesis.utils import geom as gu
1617
from genesis.utils import mesh as mu
1718
from genesis.utils import mjcf as mju
@@ -1312,7 +1313,8 @@ def inverse_kinematics_multilink(
13121313
links_idx_by_dofs = self._get_idx(links_idx_by_dofs, self.n_links, self._link_start, unsafe=False)
13131314
n_links_by_dofs = len(links_idx_by_dofs)
13141315

1315-
self._solver.rigid_entity_inverse_kinematics(
1316+
kernel_rigid_entity_inverse_kinematics(
1317+
self,
13161318
links_idx,
13171319
poss,
13181320
quats,
@@ -1335,7 +1337,15 @@ def inverse_kinematics_multilink(
13351337
max_step_size,
13361338
respect_joint_limit,
13371339
envs_idx,
1338-
self,
1340+
self._solver.links_state,
1341+
self._solver.links_info,
1342+
self._solver.joints_state,
1343+
self._solver.joints_info,
1344+
self._solver.dofs_state,
1345+
self._solver.dofs_info,
1346+
self._solver.entities_info,
1347+
self._solver._rigid_global_info,
1348+
self._solver._static_rigid_sim_config,
13391349
)
13401350

13411351
qpos = self._IK_qpos_best.to_torch(gs.device).transpose(1, 0)
@@ -3075,3 +3085,316 @@ def _kernel_get_fixed_verts(
30753085
for i_v_, i, i_b in ti.ndrange(n_verts, 3, _B):
30763086
i_v = i_v_ + verts_state_start
30773087
tensor[i_b, fixed_verts_idx_local[i_v_], i] = fixed_verts_state.pos[i_v][i]
3088+
3089+
3090+
# FIXME: RigidEntity is not compatible with fast cache
3091+
@ti.kernel(fastcache=False)
3092+
def kernel_rigid_entity_inverse_kinematics(
3093+
rigid_entity: ti.template(),
3094+
links_idx: ti.types.ndarray(),
3095+
poss: ti.types.ndarray(),
3096+
quats: ti.types.ndarray(),
3097+
n_links: ti.i32,
3098+
dofs_idx: ti.types.ndarray(),
3099+
n_dofs: ti.i32,
3100+
links_idx_by_dofs: ti.types.ndarray(),
3101+
n_links_by_dofs: ti.i32,
3102+
custom_init_qpos: ti.i32,
3103+
init_qpos: ti.types.ndarray(),
3104+
max_samples: ti.i32,
3105+
max_solver_iters: ti.i32,
3106+
damping: ti.f32,
3107+
pos_tol: ti.f32,
3108+
rot_tol: ti.f32,
3109+
pos_mask_: ti.types.ndarray(),
3110+
rot_mask_: ti.types.ndarray(),
3111+
link_pos_mask: ti.types.ndarray(),
3112+
link_rot_mask: ti.types.ndarray(),
3113+
max_step_size: ti.f32,
3114+
respect_joint_limit: ti.i32,
3115+
envs_idx: ti.types.ndarray(),
3116+
links_state: array_class.LinksState,
3117+
links_info: array_class.LinksInfo,
3118+
joints_state: array_class.JointsState,
3119+
joints_info: array_class.JointsInfo,
3120+
dofs_state: array_class.DofsState,
3121+
dofs_info: array_class.DofsInfo,
3122+
entities_info: array_class.EntitiesInfo,
3123+
rigid_global_info: array_class.RigidGlobalInfo,
3124+
static_rigid_sim_config: ti.template(),
3125+
):
3126+
EPS = rigid_global_info.EPS[None]
3127+
3128+
# convert to ti Vector
3129+
pos_mask = ti.Vector([pos_mask_[0], pos_mask_[1], pos_mask_[2]], dt=gs.ti_float)
3130+
rot_mask = ti.Vector([rot_mask_[0], rot_mask_[1], rot_mask_[2]], dt=gs.ti_float)
3131+
n_error_dims = 6 * n_links
3132+
3133+
for i_b_ in range(envs_idx.shape[0]):
3134+
i_b = envs_idx[i_b_]
3135+
3136+
# save original qpos
3137+
for i_q in range(rigid_entity.n_qs):
3138+
rigid_entity._IK_qpos_orig[i_q, i_b] = rigid_global_info.qpos[i_q + rigid_entity._q_start, i_b]
3139+
3140+
if custom_init_qpos:
3141+
for i_q in range(rigid_entity.n_qs):
3142+
rigid_global_info.qpos[i_q + rigid_entity._q_start, i_b] = init_qpos[i_b_, i_q]
3143+
3144+
for i_error in range(n_error_dims):
3145+
rigid_entity._IK_err_pose_best[i_error, i_b] = 1e4
3146+
3147+
solved = False
3148+
for i_sample in range(max_samples):
3149+
for _ in range(max_solver_iters):
3150+
# run FK to update link states using current q
3151+
gs.engine.solvers.rigid.rigid_solver_decomp.func_forward_kinematics_entity(
3152+
rigid_entity._idx_in_solver,
3153+
i_b,
3154+
links_state,
3155+
links_info,
3156+
joints_state,
3157+
joints_info,
3158+
dofs_state,
3159+
dofs_info,
3160+
entities_info,
3161+
rigid_global_info,
3162+
static_rigid_sim_config,
3163+
)
3164+
# compute error
3165+
solved = True
3166+
for i_ee in range(n_links):
3167+
i_l_ee = links_idx[i_ee]
3168+
3169+
tgt_pos_i = ti.Vector([poss[i_ee, i_b_, 0], poss[i_ee, i_b_, 1], poss[i_ee, i_b_, 2]])
3170+
err_pos_i = tgt_pos_i - links_state.pos[i_l_ee, i_b]
3171+
for k in range(3):
3172+
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
3173+
if err_pos_i.norm() > pos_tol:
3174+
solved = False
3175+
3176+
tgt_quat_i = ti.Vector(
3177+
[quats[i_ee, i_b_, 0], quats[i_ee, i_b_, 1], quats[i_ee, i_b_, 2], quats[i_ee, i_b_, 3]]
3178+
)
3179+
err_rot_i = gu.ti_quat_to_rotvec(
3180+
gu.ti_transform_quat_by_quat(gu.ti_inv_quat(links_state.quat[i_l_ee, i_b]), tgt_quat_i), EPS
3181+
)
3182+
for k in range(3):
3183+
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
3184+
if err_rot_i.norm() > rot_tol:
3185+
solved = False
3186+
3187+
# put into multi-link error array
3188+
for k in range(3):
3189+
rigid_entity._IK_err_pose[i_ee * 6 + k, i_b] = err_pos_i[k]
3190+
rigid_entity._IK_err_pose[i_ee * 6 + k + 3, i_b] = err_rot_i[k]
3191+
3192+
if solved:
3193+
break
3194+
3195+
# compute multi-link jacobian
3196+
for i_ee in range(n_links):
3197+
# update jacobian for ee link
3198+
i_l_ee = links_idx[i_ee]
3199+
rigid_entity._func_get_jacobian(
3200+
tgt_link_idx=i_l_ee,
3201+
i_b=i_b,
3202+
p_local=ti.Vector.zero(gs.ti_float, 3),
3203+
pos_mask=pos_mask,
3204+
rot_mask=rot_mask,
3205+
dofs_info=dofs_info,
3206+
joints_info=joints_info,
3207+
links_info=links_info,
3208+
links_state=links_state,
3209+
) # NOTE: we still compute jacobian for all dofs as we haven't found a clean way to implement this
3210+
3211+
# copy to multi-link jacobian (only for the effective n_dofs instead of self.n_dofs)
3212+
for i_dof in range(n_dofs):
3213+
for i_error in ti.static(range(6)):
3214+
i_row = i_ee * 6 + i_error
3215+
i_dof_ = dofs_idx[i_dof]
3216+
rigid_entity._IK_jacobian[i_row, i_dof, i_b] = rigid_entity._jacobian[i_error, i_dof_, i_b]
3217+
3218+
# compute dq = jac.T @ inverse(jac @ jac.T + diag) @ error (only for the effective n_dofs instead of self.n_dofs)
3219+
lu.mat_transpose(rigid_entity._IK_jacobian, rigid_entity._IK_jacobian_T, n_error_dims, n_dofs, i_b)
3220+
lu.mat_mul(
3221+
rigid_entity._IK_jacobian,
3222+
rigid_entity._IK_jacobian_T,
3223+
rigid_entity._IK_mat,
3224+
n_error_dims,
3225+
n_dofs,
3226+
n_error_dims,
3227+
i_b,
3228+
)
3229+
lu.mat_add_eye(rigid_entity._IK_mat, damping**2, n_error_dims, i_b)
3230+
lu.mat_inverse(
3231+
rigid_entity._IK_mat,
3232+
rigid_entity._IK_L,
3233+
rigid_entity._IK_U,
3234+
rigid_entity._IK_y,
3235+
rigid_entity._IK_inv,
3236+
n_error_dims,
3237+
i_b,
3238+
)
3239+
lu.mat_mul_vec(
3240+
rigid_entity._IK_inv,
3241+
rigid_entity._IK_err_pose,
3242+
rigid_entity._IK_vec,
3243+
n_error_dims,
3244+
n_error_dims,
3245+
i_b,
3246+
)
3247+
3248+
for i_d in range(rigid_entity.n_dofs): # IK_delta_qpos = IK_jacobian_T @ IK_vec
3249+
rigid_entity._IK_delta_qpos[i_d, i_b] = 0
3250+
for i_d in range(n_dofs):
3251+
for j in range(n_error_dims):
3252+
# NOTE: IK_delta_qpos uses the original indexing instead of the effective n_dofs
3253+
i_d_ = dofs_idx[i_d]
3254+
rigid_entity._IK_delta_qpos[i_d_, i_b] += (
3255+
rigid_entity._IK_jacobian_T[i_d, j, i_b] * rigid_entity._IK_vec[j, i_b]
3256+
)
3257+
3258+
for i_d in range(rigid_entity.n_dofs):
3259+
rigid_entity._IK_delta_qpos[i_d, i_b] = ti.math.clamp(
3260+
rigid_entity._IK_delta_qpos[i_d, i_b], -max_step_size, max_step_size
3261+
)
3262+
3263+
# update q
3264+
gs.engine.solvers.rigid.rigid_solver_decomp.func_integrate_dq_entity(
3265+
rigid_entity._IK_delta_qpos,
3266+
rigid_entity._idx_in_solver,
3267+
i_b,
3268+
respect_joint_limit,
3269+
links_info,
3270+
joints_info,
3271+
dofs_info,
3272+
entities_info,
3273+
rigid_global_info,
3274+
static_rigid_sim_config,
3275+
)
3276+
3277+
if not solved:
3278+
# re-compute final error if exited not due to solved
3279+
gs.engine.solvers.rigid.rigid_solver_decomp.func_forward_kinematics_entity(
3280+
rigid_entity._idx_in_solver,
3281+
i_b,
3282+
links_state,
3283+
links_info,
3284+
joints_state,
3285+
joints_info,
3286+
dofs_state,
3287+
dofs_info,
3288+
entities_info,
3289+
rigid_global_info,
3290+
static_rigid_sim_config,
3291+
)
3292+
solved = True
3293+
for i_ee in range(n_links):
3294+
i_l_ee = links_idx[i_ee]
3295+
3296+
tgt_pos_i = ti.Vector([poss[i_ee, i_b_, 0], poss[i_ee, i_b_, 1], poss[i_ee, i_b_, 2]])
3297+
err_pos_i = tgt_pos_i - links_state.pos[i_l_ee, i_b]
3298+
for k in range(3):
3299+
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
3300+
if err_pos_i.norm() > pos_tol:
3301+
solved = False
3302+
3303+
tgt_quat_i = ti.Vector(
3304+
[quats[i_ee, i_b_, 0], quats[i_ee, i_b_, 1], quats[i_ee, i_b_, 2], quats[i_ee, i_b_, 3]]
3305+
)
3306+
err_rot_i = gu.ti_quat_to_rotvec(
3307+
gu.ti_transform_quat_by_quat(gu.ti_inv_quat(links_state.quat[i_l_ee, i_b]), tgt_quat_i), EPS
3308+
)
3309+
for k in range(3):
3310+
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
3311+
if err_rot_i.norm() > rot_tol:
3312+
solved = False
3313+
3314+
# put into multi-link error array
3315+
for k in range(3):
3316+
rigid_entity._IK_err_pose[i_ee * 6 + k, i_b] = err_pos_i[k]
3317+
rigid_entity._IK_err_pose[i_ee * 6 + k + 3, i_b] = err_rot_i[k]
3318+
3319+
if solved:
3320+
for i_q in range(rigid_entity.n_qs):
3321+
rigid_entity._IK_qpos_best[i_q, i_b] = rigid_global_info.qpos[i_q + rigid_entity._q_start, i_b]
3322+
for i_error in range(n_error_dims):
3323+
rigid_entity._IK_err_pose_best[i_error, i_b] = rigid_entity._IK_err_pose[i_error, i_b]
3324+
break
3325+
3326+
else:
3327+
# copy to _IK_qpos if this sample is better
3328+
improved = True
3329+
for i_ee in range(n_links):
3330+
error_pos_i = ti.Vector(
3331+
[rigid_entity._IK_err_pose[i_ee * 6 + i_error, i_b] for i_error in range(3)]
3332+
)
3333+
error_rot_i = ti.Vector(
3334+
[rigid_entity._IK_err_pose[i_ee * 6 + i_error, i_b] for i_error in range(3, 6)]
3335+
)
3336+
error_pos_best = ti.Vector(
3337+
[rigid_entity._IK_err_pose_best[i_ee * 6 + i_error, i_b] for i_error in range(3)]
3338+
)
3339+
error_rot_best = ti.Vector(
3340+
[rigid_entity._IK_err_pose_best[i_ee * 6 + i_error, i_b] for i_error in range(3, 6)]
3341+
)
3342+
if error_pos_i.norm() > error_pos_best.norm() or error_rot_i.norm() > error_rot_best.norm():
3343+
improved = False
3344+
break
3345+
3346+
if improved:
3347+
for i_q in range(rigid_entity.n_qs):
3348+
rigid_entity._IK_qpos_best[i_q, i_b] = rigid_global_info.qpos[i_q + rigid_entity._q_start, i_b]
3349+
for i_error in range(n_error_dims):
3350+
rigid_entity._IK_err_pose_best[i_error, i_b] = rigid_entity._IK_err_pose[i_error, i_b]
3351+
3352+
# Resample init q
3353+
if respect_joint_limit and i_sample < max_samples - 1:
3354+
for _i_l in range(n_links_by_dofs):
3355+
i_l = links_idx_by_dofs[_i_l]
3356+
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
3357+
3358+
for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
3359+
I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
3360+
3361+
I_dof_start = (
3362+
[joints_info.dof_start[I_j], i_b]
3363+
if ti.static(static_rigid_sim_config.batch_dofs_info)
3364+
else joints_info.dof_start[I_j]
3365+
)
3366+
q_start = joints_info.q_start[I_j]
3367+
dof_limit = dofs_info.limit[I_dof_start]
3368+
3369+
if joints_info.type[I_j] == gs.JOINT_TYPE.FREE:
3370+
pass
3371+
3372+
elif (
3373+
joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE
3374+
or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC
3375+
):
3376+
if ti.math.isinf(dof_limit[0]) or ti.math.isinf(dof_limit[1]):
3377+
pass
3378+
else:
3379+
rigid_global_info.qpos[q_start, i_b] = dof_limit[0] + ti.random() * (
3380+
dof_limit[1] - dof_limit[0]
3381+
)
3382+
else:
3383+
pass # When respect_joint_limit=False, we can simply continue from the last solution
3384+
3385+
# restore original qpos and link state
3386+
for i_q in range(rigid_entity.n_qs):
3387+
rigid_global_info.qpos[i_q + rigid_entity._q_start, i_b] = rigid_entity._IK_qpos_orig[i_q, i_b]
3388+
gs.engine.solvers.rigid.rigid_solver_decomp.func_forward_kinematics_entity(
3389+
rigid_entity._idx_in_solver,
3390+
i_b,
3391+
links_state,
3392+
links_info,
3393+
joints_state,
3394+
joints_info,
3395+
dofs_state,
3396+
dofs_info,
3397+
entities_info,
3398+
rigid_global_info,
3399+
static_rigid_sim_config,
3400+
)

genesis/engine/solvers/avatar_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import gstaichi as ti
3+
34
import genesis as gs
4-
from genesis.engine.entities import AvatarEntity
55
from genesis.engine.states.solvers import AvatarSolverState
66

77
from .base_solver import Solver

0 commit comments

Comments
 (0)