@@ -148,6 +148,9 @@ def __init__(
148148 solve_freq : float = 20.0 ,
149149 hand_pos_cost : float = 1 ,
150150 hand_ori_cost : float = 0.5 ,
151+ com_cost : float = 0.0 ,
152+ use_mink_posture_task : bool = False ,
153+ initial_qpos_as_posture_target : bool = False ,
151154 ):
152155 self .full_model : mujoco .MjModel = model
153156 self .full_model_data : mujoco .MjData = data
@@ -157,9 +160,11 @@ def __init__(
157160 self .posture_weights = posture_weights
158161 self .hand_pos_cost = hand_pos_cost
159162 self .hand_ori_cost = hand_ori_cost
160-
163+ self .com_cost = com_cost
164+ self .use_mink_posture_task = use_mink_posture_task
161165 self .hand_tasks : List [mink .FrameTask ]
162166 self .posture_task : WeightedPostureTask
167+ self .com_task : mink .ComTask | None = None
163168
164169 if robot_joint_names is None :
165170 robot_joint_names : List [str ] = [
@@ -168,15 +173,40 @@ def __init__(
168173 if self .robot_model .joint (i ).type != 0
169174 ] # Exclude fixed joints
170175
171- self .full_model_dof_ids : List [int ] = np .array ([self .full_model .joint (name ).id for name in robot_joint_names ])
176+ # the order of the index is the same of model.qposadr
177+ self .all_robot_qpos_indexes_in_full_model : List [int ] = []
178+ for i in range (self .robot_model .njnt ):
179+ joint_name = self .robot_model .joint (i ).name
180+ self .all_robot_qpos_indexes_in_full_model .extend (
181+ self .full_model .joint (joint_name ).qposadr [0 ] + np .arange (len (self .full_model .joint (joint_name ).qpos0 ))
182+ )
183+ assert len (self .all_robot_qpos_indexes_in_full_model ) == self .robot_model .nq
184+
185+ # the order of the index is determined by the order of the actuation_part_names
186+ self .controlled_robot_qpos_indexes : List [int ] = []
187+ self .controlled_robot_qpos_indexes_in_full_model : List [int ] = []
188+ for name in robot_joint_names :
189+ self .controlled_robot_qpos_indexes .extend (
190+ self .robot_model .joint (name ).qposadr [0 ] + np .arange (len (self .robot_model .joint (name ).qpos0 ))
191+ )
192+ self .controlled_robot_qpos_indexes_in_full_model .extend (
193+ self .full_model .joint (name ).qposadr [0 ] + np .arange (len (self .full_model .joint (name ).qpos0 ))
194+ )
172195
173- self .robot_model_dof_ids : List [int ] = np .array ([self .robot_model .joint (name ).id for name in robot_joint_names ])
174- self .full_model_dof_ids : List [int ] = np .array ([self .full_model .joint (name ).id for name in robot_joint_names ])
175196 self .site_ids = [self .robot_model .site (site_name ).id for site_name in site_names ]
176-
177197 self .site_names = site_names
198+
199+ # update robot states
200+ self .update_robot_states ()
201+
202+ # setup tasks
178203 self ._setup_tasks ()
179- self .set_posture_target (np .zeros (self .robot_model .nq ))
204+ if initial_qpos_as_posture_target :
205+ self .set_posture_target ()
206+ else :
207+ self .set_posture_target (np .zeros (self .robot_model .nq ))
208+ if self .com_cost > 0.0 :
209+ self .set_com_target ()
180210
181211 self .solver = "quadprog"
182212
@@ -203,12 +233,17 @@ def __repr__(self) -> str:
203233 return "IKSolverMink"
204234
205235 def _setup_tasks (self ):
206- weights = np .ones (self .robot_model .nq )
236+ weights = np .ones (self .robot_model .nv )
237+
207238 for joint_name , posture_weight in self .posture_weights .items ():
208- joint_idx = self .robot_model .joint (joint_name ).id
209- weights [joint_idx ] = posture_weight
239+ joint = self .robot_model .joint (joint_name )
240+ joint_dof_idx = joint .dofadr [0 ] + np .arange (len (joint .jntid ))
241+ weights [joint_dof_idx ] = posture_weight
210242
211- self .posture_task = WeightedPostureTask (self .robot_model , cost = 0.01 , weights = weights , lm_damping = 2 )
243+ if self .use_mink_posture_task :
244+ self .posture_task = mink .PostureTask (self .robot_model , cost = weights * 0.1 , lm_damping = 1.0 )
245+ else :
246+ self .posture_task = WeightedPostureTask (self .robot_model , cost = 0.01 , weights = weights , lm_damping = 2 )
212247
213248 self .tasks = [self .posture_task ]
214249
@@ -217,6 +252,10 @@ def _setup_tasks(self):
217252 )
218253 self .tasks .extend (self .hand_tasks )
219254
255+ if self .com_cost > 0.0 :
256+ self .com_task = mink .ComTask (cost = self .com_cost )
257+ self .tasks .append (self .com_task )
258+
220259 def _create_frame_tasks (self , frame_names : List [str ], position_cost : float , orientation_cost : float ):
221260 return [
222261 mink .FrameTask (
@@ -229,13 +268,33 @@ def _create_frame_tasks(self, frame_names: List[str], position_cost: float, orie
229268 for frame in frame_names
230269 ]
231270
271+ def update_robot_states (self ):
272+ # update the base pose, important for mobile robots such as humanoids
273+ self .configuration .model .body ("robot0_base" ).pos = self .full_model .body ("robot0_base" ).pos
274+ self .configuration .model .body ("robot0_base" ).quat = self .full_model .body ("robot0_base" ).quat
275+
276+ # update the qpos for the robot model
277+ self .configuration .update (
278+ self .full_model_data .qpos [self .controlled_robot_qpos_indexes_in_full_model ],
279+ self .controlled_robot_qpos_indexes ,
280+ )
281+
232282 def set_target_poses (self , target_poses : List [np .ndarray ]):
233283 for task , target in zip (self .hand_tasks , target_poses ):
234284 se3_target = mink .SE3 .from_matrix (target )
235285 task .set_target (se3_target )
236286
237- def set_posture_target (self , posture_target : np .ndarray ):
238- self .posture_task .set_target (posture_target )
287+ def set_posture_target (self , posture_target : np .ndarray | None = None ):
288+ if posture_target is None :
289+ self .posture_task .set_target_from_configuration (self .configuration )
290+ else :
291+ self .posture_task .set_target (posture_target )
292+
293+ def set_com_target (self , com_target : np .ndarray | None = None ):
294+ assert self .com_task is not None , "COM task is not initialized"
295+ if com_target is None :
296+ com_target = self .configuration .data .subtree_com [1 ]
297+ self .com_task .set_target (com_target )
239298
240299 def action_split_indexes (self ) -> Dict [str , Tuple [int , int ]]:
241300 action_split_indexes : Dict [str , Tuple [int , int ]] = {}
@@ -261,9 +320,8 @@ def transform_pose(
261320 if src_frame == dst_frame :
262321 return src_frame_pose
263322
264- self .configuration .model .body ("robot0_base" ).pos = self .full_model .body ("robot0_base" ).pos
265- self .configuration .model .body ("robot0_base" ).quat = self .full_model .body ("robot0_base" ).quat
266- self .configuration .update ()
323+ self .robot_model .body ("robot0_base" ).pos = self .full_model .body ("robot0_base" ).pos
324+ self .robot_model .body ("robot0_base" ).quat = self .full_model .body ("robot0_base" ).quat
267325
268326 X_src_frame_pose = src_frame_pose
269327 # convert src frame pose to world frame pose
@@ -294,13 +352,8 @@ def solve(self, input_action: np.ndarray) -> np.ndarray:
294352 By updating configuration's bose to match the actual base pose (in 'world' frame),
295353 we're requiring our tasks' targets to be in the 'world' frame for mink.solve_ik().
296354 """
297- # update configuration's base to match actual base
298- self .configuration .model .body ("robot0_base" ).pos = self .full_model .body ("robot0_base" ).pos
299- self .configuration .model .body ("robot0_base" ).quat = self .full_model .body ("robot0_base" ).quat
300- # update configuration's qpos to match actual qpos
301- self .configuration .update (
302- self .full_model_data .qpos [self .full_model_dof_ids ], update_idxs = self .robot_model_dof_ids
303- )
355+
356+ self .update_robot_states ()
304357
305358 input_action = input_action .reshape (len (self .site_names ), - 1 )
306359 input_pos = input_action [:, : self .pos_dim ]
@@ -409,7 +462,7 @@ def solve(self, input_action: np.ndarray) -> np.ndarray:
409462 if self .i % 50 :
410463 print (f"Task errors: { task_translation_errors } " )
411464
412- return self .configuration .data .qpos [self .robot_model_dof_ids ]
465+ return self .configuration .data .qpos [self .controlled_robot_qpos_indexes ]
413466
414467 def _get_task_translation_errors (self ) -> List [float ]:
415468 errors = []
@@ -430,6 +483,7 @@ def _get_task_errors(self) -> List[float]:
430483 for task in self .hand_tasks :
431484 error = task .compute_error (self .configuration )
432485 errors .append (np .linalg .norm (error [:3 ]))
486+ errors .append (self .posture_task .compute_error (self .configuration ))
433487 return errors
434488
435489
@@ -518,5 +572,9 @@ def _init_joint_action_policy(self):
518572 posture_weights = self .composite_controller_specific_config .get ("ik_posture_weights" , {}),
519573 hand_pos_cost = self .composite_controller_specific_config .get ("ik_hand_pos_cost" , 1.0 ),
520574 hand_ori_cost = self .composite_controller_specific_config .get ("ik_hand_ori_cost" , 0.5 ),
521- verbose = self .composite_controller_specific_config .get ("ik_verbose" , False ),
575+ use_mink_posture_task = self .composite_controller_specific_config .get ("use_mink_posture_task" , False ),
576+ initial_qpos_as_posture_target = self .composite_controller_specific_config .get (
577+ "initial_qpos_as_posture_target" , False
578+ ),
579+ verbose = self .composite_controller_specific_config .get ("verbose" , False ),
522580 )
0 commit comments