Skip to content

Commit f4b2574

Browse files
authored
Update of robot controller and robot-mounted base (#679)
1 parent 2219758 commit f4b2574

File tree

12 files changed

+187
-30
lines changed

12 files changed

+187
-30
lines changed

robosuite/controllers/parts/generic/joint_pos.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ def __init__(
175175
# initialize
176176
self.goal_qpos = None
177177

178+
self.use_torque_compensation = kwargs.get("use_torque_compensation", True)
179+
178180
def set_goal(self, action, set_qpos=None):
179181
"""
180182
Sets goal based on input @action. If self.impedance_mode is not "fixed", then the input will be parsed into the
@@ -260,7 +262,10 @@ def run_controller(self):
260262
desired_torque = np.multiply(np.array(position_error), np.array(self.kp)) + np.multiply(vel_pos_error, self.kd)
261263

262264
# Return desired torques plus gravity compensations
263-
self.torques = np.dot(self.mass_matrix, desired_torque) + self.torque_compensation
265+
if self.use_torque_compensation:
266+
self.torques = np.dot(self.mass_matrix, desired_torque) + self.torque_compensation
267+
else:
268+
self.torques = desired_torque
264269

265270
# Always run superclass call for any cleanups at the end
266271
super().run_controller()

robosuite/controllers/parts/generic/joint_tor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(
106106
self.current_torque = np.zeros(self.control_dim) # Current torques being outputted, pre-compensation
107107
self.torques = None # Torques returned every time run_controller is called
108108

109+
self.use_torque_compensation = kwargs.get("use_torque_compensation", True)
110+
109111
def set_goal(self, torques):
110112
"""
111113
Sets goal based on input @torques.
@@ -153,7 +155,10 @@ def run_controller(self):
153155
self.current_torque = np.array(self.goal_torque)
154156

155157
# Add gravity compensation
156-
self.torques = self.current_torque + self.torque_compensation
158+
if self.use_torque_compensation:
159+
self.torques = self.current_torque + self.torque_compensation
160+
else:
161+
self.torques = self.current_torque
157162

158163
# Always run superclass call for any cleanups at the end
159164
super().run_controller()

robosuite/controllers/parts/generic/joint_vel.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def __init__(
124124
self.current_vel = np.zeros(self.joint_dim) # Current velocity setpoint, pre-compensation
125125
self.torques = None # Torques returned every time run_controller is called
126126

127+
self.torque_compensation = kwargs.get("use_torque_compensation", True)
128+
127129
def set_goal(self, velocities):
128130
"""
129131
Sets goal based on input @velocities.
@@ -187,7 +189,12 @@ def run_controller(self):
187189
self.summed_err += err
188190

189191
# Compute command torques via PID velocity controller plus gravity compensation torques
190-
torques = self.kp * err + self.ki * self.summed_err + self.kd * self.derr_buf.average + self.torque_compensation
192+
if self.torque_compensation:
193+
torques = (
194+
self.kp * err + self.ki * self.summed_err + self.kd * self.derr_buf.average + self.torque_compensation
195+
)
196+
else:
197+
torques = self.kp * err + self.ki * self.summed_err + self.kd * self.derr_buf.average
191198

192199
# Clip torques
193200
self.torques = self.clip_torques(torques)

robosuite/devices/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
194194
"delta": norm_delta,
195195
"abs": abs_action,
196196
}
197-
elif robot.composite_controller_config["type"] in ["WHOLE_BODY_MINK_IK"]:
197+
elif robot.composite_controller_config["type"] in ["WHOLE_BODY_MINK_IK", "HYBRID_WHOLE_BODY_MINK_IK"]:
198198
ref_frame = self.env.robots[0].composite_controller.composite_controller_specific_config.get(
199199
"ik_input_ref_frame", "world"
200200
)

robosuite/examples/third_party_controller/mink_controller.py

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def __init__(
148148
solve_freq: float = 20.0,
149149
hand_pos_cost: float = 1,
150150
hand_ori_cost: float = 0.5,
151+
com_cost: float = 0.0,
152+
use_mink_posture_task: bool = False,
153+
initial_qpos_as_posture_target: bool = False,
151154
):
152155
self.full_model: mujoco.MjModel = model
153156
self.full_model_data: mujoco.MjData = data
@@ -157,9 +160,11 @@ def __init__(
157160
self.posture_weights = posture_weights
158161
self.hand_pos_cost = hand_pos_cost
159162
self.hand_ori_cost = hand_ori_cost
160-
163+
self.com_cost = com_cost
164+
self.use_mink_posture_task = use_mink_posture_task
161165
self.hand_tasks: List[mink.FrameTask]
162166
self.posture_task: WeightedPostureTask
167+
self.com_task: mink.ComTask | None = None
163168

164169
if robot_joint_names is None:
165170
robot_joint_names: List[str] = [
@@ -168,15 +173,40 @@ def __init__(
168173
if self.robot_model.joint(i).type != 0
169174
] # Exclude fixed joints
170175

171-
self.full_model_dof_ids: List[int] = np.array([self.full_model.joint(name).id for name in robot_joint_names])
176+
# the order of the index is the same of model.qposadr
177+
self.all_robot_qpos_indexes_in_full_model: List[int] = []
178+
for i in range(self.robot_model.njnt):
179+
joint_name = self.robot_model.joint(i).name
180+
self.all_robot_qpos_indexes_in_full_model.extend(
181+
self.full_model.joint(joint_name).qposadr[0] + np.arange(len(self.full_model.joint(joint_name).qpos0))
182+
)
183+
assert len(self.all_robot_qpos_indexes_in_full_model) == self.robot_model.nq
184+
185+
# the order of the index is determined by the order of the actuation_part_names
186+
self.controlled_robot_qpos_indexes: List[int] = []
187+
self.controlled_robot_qpos_indexes_in_full_model: List[int] = []
188+
for name in robot_joint_names:
189+
self.controlled_robot_qpos_indexes.extend(
190+
self.robot_model.joint(name).qposadr[0] + np.arange(len(self.robot_model.joint(name).qpos0))
191+
)
192+
self.controlled_robot_qpos_indexes_in_full_model.extend(
193+
self.full_model.joint(name).qposadr[0] + np.arange(len(self.full_model.joint(name).qpos0))
194+
)
172195

173-
self.robot_model_dof_ids: List[int] = np.array([self.robot_model.joint(name).id for name in robot_joint_names])
174-
self.full_model_dof_ids: List[int] = np.array([self.full_model.joint(name).id for name in robot_joint_names])
175196
self.site_ids = [self.robot_model.site(site_name).id for site_name in site_names]
176-
177197
self.site_names = site_names
198+
199+
# update robot states
200+
self.update_robot_states()
201+
202+
# setup tasks
178203
self._setup_tasks()
179-
self.set_posture_target(np.zeros(self.robot_model.nq))
204+
if initial_qpos_as_posture_target:
205+
self.set_posture_target()
206+
else:
207+
self.set_posture_target(np.zeros(self.robot_model.nq))
208+
if self.com_cost > 0.0:
209+
self.set_com_target()
180210

181211
self.solver = "quadprog"
182212

@@ -203,12 +233,17 @@ def __repr__(self) -> str:
203233
return "IKSolverMink"
204234

205235
def _setup_tasks(self):
206-
weights = np.ones(self.robot_model.nq)
236+
weights = np.ones(self.robot_model.nv)
237+
207238
for joint_name, posture_weight in self.posture_weights.items():
208-
joint_idx = self.robot_model.joint(joint_name).id
209-
weights[joint_idx] = posture_weight
239+
joint = self.robot_model.joint(joint_name)
240+
joint_dof_idx = joint.dofadr[0] + np.arange(len(joint.jntid))
241+
weights[joint_dof_idx] = posture_weight
210242

211-
self.posture_task = WeightedPostureTask(self.robot_model, cost=0.01, weights=weights, lm_damping=2)
243+
if self.use_mink_posture_task:
244+
self.posture_task = mink.PostureTask(self.robot_model, cost=weights * 0.1, lm_damping=1.0)
245+
else:
246+
self.posture_task = WeightedPostureTask(self.robot_model, cost=0.01, weights=weights, lm_damping=2)
212247

213248
self.tasks = [self.posture_task]
214249

@@ -217,6 +252,10 @@ def _setup_tasks(self):
217252
)
218253
self.tasks.extend(self.hand_tasks)
219254

255+
if self.com_cost > 0.0:
256+
self.com_task = mink.ComTask(cost=self.com_cost)
257+
self.tasks.append(self.com_task)
258+
220259
def _create_frame_tasks(self, frame_names: List[str], position_cost: float, orientation_cost: float):
221260
return [
222261
mink.FrameTask(
@@ -229,13 +268,33 @@ def _create_frame_tasks(self, frame_names: List[str], position_cost: float, orie
229268
for frame in frame_names
230269
]
231270

271+
def update_robot_states(self):
272+
# update the base pose, important for mobile robots such as humanoids
273+
self.configuration.model.body("robot0_base").pos = self.full_model.body("robot0_base").pos
274+
self.configuration.model.body("robot0_base").quat = self.full_model.body("robot0_base").quat
275+
276+
# update the qpos for the robot model
277+
self.configuration.update(
278+
self.full_model_data.qpos[self.controlled_robot_qpos_indexes_in_full_model],
279+
self.controlled_robot_qpos_indexes,
280+
)
281+
232282
def set_target_poses(self, target_poses: List[np.ndarray]):
233283
for task, target in zip(self.hand_tasks, target_poses):
234284
se3_target = mink.SE3.from_matrix(target)
235285
task.set_target(se3_target)
236286

237-
def set_posture_target(self, posture_target: np.ndarray):
238-
self.posture_task.set_target(posture_target)
287+
def set_posture_target(self, posture_target: np.ndarray | None = None):
288+
if posture_target is None:
289+
self.posture_task.set_target_from_configuration(self.configuration)
290+
else:
291+
self.posture_task.set_target(posture_target)
292+
293+
def set_com_target(self, com_target: np.ndarray | None = None):
294+
assert self.com_task is not None, "COM task is not initialized"
295+
if com_target is None:
296+
com_target = self.configuration.data.subtree_com[1]
297+
self.com_task.set_target(com_target)
239298

240299
def action_split_indexes(self) -> Dict[str, Tuple[int, int]]:
241300
action_split_indexes: Dict[str, Tuple[int, int]] = {}
@@ -261,9 +320,8 @@ def transform_pose(
261320
if src_frame == dst_frame:
262321
return src_frame_pose
263322

264-
self.configuration.model.body("robot0_base").pos = self.full_model.body("robot0_base").pos
265-
self.configuration.model.body("robot0_base").quat = self.full_model.body("robot0_base").quat
266-
self.configuration.update()
323+
self.robot_model.body("robot0_base").pos = self.full_model.body("robot0_base").pos
324+
self.robot_model.body("robot0_base").quat = self.full_model.body("robot0_base").quat
267325

268326
X_src_frame_pose = src_frame_pose
269327
# convert src frame pose to world frame pose
@@ -294,13 +352,8 @@ def solve(self, input_action: np.ndarray) -> np.ndarray:
294352
By updating configuration's bose to match the actual base pose (in 'world' frame),
295353
we're requiring our tasks' targets to be in the 'world' frame for mink.solve_ik().
296354
"""
297-
# update configuration's base to match actual base
298-
self.configuration.model.body("robot0_base").pos = self.full_model.body("robot0_base").pos
299-
self.configuration.model.body("robot0_base").quat = self.full_model.body("robot0_base").quat
300-
# update configuration's qpos to match actual qpos
301-
self.configuration.update(
302-
self.full_model_data.qpos[self.full_model_dof_ids], update_idxs=self.robot_model_dof_ids
303-
)
355+
356+
self.update_robot_states()
304357

305358
input_action = input_action.reshape(len(self.site_names), -1)
306359
input_pos = input_action[:, : self.pos_dim]
@@ -409,7 +462,7 @@ def solve(self, input_action: np.ndarray) -> np.ndarray:
409462
if self.i % 50:
410463
print(f"Task errors: {task_translation_errors}")
411464

412-
return self.configuration.data.qpos[self.robot_model_dof_ids]
465+
return self.configuration.data.qpos[self.controlled_robot_qpos_indexes]
413466

414467
def _get_task_translation_errors(self) -> List[float]:
415468
errors = []
@@ -430,6 +483,7 @@ def _get_task_errors(self) -> List[float]:
430483
for task in self.hand_tasks:
431484
error = task.compute_error(self.configuration)
432485
errors.append(np.linalg.norm(error[:3]))
486+
errors.append(self.posture_task.compute_error(self.configuration))
433487
return errors
434488

435489

@@ -518,5 +572,9 @@ def _init_joint_action_policy(self):
518572
posture_weights=self.composite_controller_specific_config.get("ik_posture_weights", {}),
519573
hand_pos_cost=self.composite_controller_specific_config.get("ik_hand_pos_cost", 1.0),
520574
hand_ori_cost=self.composite_controller_specific_config.get("ik_hand_ori_cost", 0.5),
521-
verbose=self.composite_controller_specific_config.get("ik_verbose", False),
575+
use_mink_posture_task=self.composite_controller_specific_config.get("use_mink_posture_task", False),
576+
initial_qpos_as_posture_target=self.composite_controller_specific_config.get(
577+
"initial_qpos_as_posture_target", False
578+
),
579+
verbose=self.composite_controller_specific_config.get("verbose", False),
522580
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<mujoco model="null_mount">
2+
<worldbody>
3+
<body name="support" pos="0 0 0">
4+
<site name="center" type="sphere" pos="0 0 0" size="0.01" group="1" rgba="0 0 0 0"/>
5+
</body>
6+
</worldbody>
7+
</mujoco>

robosuite/models/bases/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .robot_base_factory import robot_base_factory
44
from .mobile_base_model import MobileBaseModel
55
from .leg_base_model import LegBaseModel
6+
from .null_base_model import NullBaseModel
67

78
from .rethink_mount import RethinkMount
89
from .rethink_minimal_mount import RethinkMinimalMount
@@ -12,7 +13,7 @@
1213
from .null_mobile_base import NullMobileBase
1314
from .no_actuation_base import NoActuationBase
1415
from .floating_legged_base import FloatingLeggedBase
15-
16+
from .null_base import NullBase
1617
from .spot_base import Spot, SpotFloating
1718

1819
BASE_MAPPING = {
@@ -25,6 +26,7 @@
2526
"FloatingLeggedBase": FloatingLeggedBase,
2627
"Spot": Spot,
2728
"SpotFloating": SpotFloating,
29+
"NullBase": NullBase,
2830
}
2931

3032
ALL_BASES = BASE_MAPPING.keys()

robosuite/models/bases/leg_base_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ def _remove_joint_actuation(self, part_name):
3434
parent_body = find_parent(self.actuator, motor)
3535
parent_body.remove(motor)
3636
self._actuators.remove(motor.get("name").replace(self.naming_prefix, ""))
37+
for sensor in self.root.findall(".//jointpos"):
38+
if part_name in sensor.get("joint"):
39+
find_parent(self.root, sensor).remove(sensor)
40+
for sensor in self.root.findall(".//jointvel"):
41+
if part_name in sensor.get("joint"):
42+
find_parent(self.root, sensor).remove(sensor)
43+
for sensor in self.root.findall(".//jointactuatorfrc"):
44+
if part_name in sensor.get("joint"):
45+
find_parent(self.root, sensor).remove(sensor)
3746

3847
def _remove_free_joint(self):
3948
"""Remove all freejoints from the model."""
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
Rethink's Generic Mount (Officially used on Sawyer).
3+
"""
4+
import numpy as np
5+
6+
from robosuite.models.bases.null_base_model import NullBaseModel
7+
from robosuite.utils.mjcf_utils import xml_path_completion
8+
9+
10+
class NullBase(NullBaseModel):
11+
"""
12+
Dummy mobile base to signify no mount.
13+
14+
Args:
15+
idn (int or str): Number or some other unique identification string for this mount instance
16+
"""
17+
18+
def __init__(self, idn=0):
19+
super().__init__(xml_path_completion("bases/null_base.xml"), idn=idn)
20+
21+
@property
22+
def top_offset(self):
23+
return np.array((0, 0, 0))
24+
25+
@property
26+
def horizontal_radius(self):
27+
return 0
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Defines the null base model
3+
"""
4+
5+
from robosuite.models.bases.robot_base_model import RobotBaseModel
6+
7+
8+
class NullBaseModel(RobotBaseModel):
9+
@property
10+
def naming_prefix(self):
11+
return "nullbase{}_".format(self.idn)

0 commit comments

Comments
 (0)