Skip to content

Commit 6c10ef2

Browse files
Robocasa dev (#792)
* add option to remove unnecessary resets from env init functions * add option to enable multiccd flag programatically * change dc wrapper to use env.model.get_xml() instead of env.sim.model.get_xml() * add flag for sleeping * run precomit * add sleeping flag * dynamically decide how actions are interpreted * pre commit * change naming * update code which applies rotation incorrectly * add flag for dcwrapper reset and add comments in joint vel * undo conditional reset * move joint detection to outside function * fixes from bug report * move axis checking code to subclass * move axis checking code to subclass --------- Co-authored-by: Soroush Nasiriany <sornasir324@gmail.com> Co-authored-by: snasiriany <snasiriany@gmail.com>
1 parent e941195 commit 6c10ef2

File tree

11 files changed

+206
-33
lines changed

11 files changed

+206
-33
lines changed

robosuite/controllers/parts/controller_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def mobile_base_controller_factory(name, params):
172172
interpolator = None
173173
if name == "JOINT_VELOCITY":
174174
return mobile_base_controllers.MobileBaseJointVelocityController(interpolator=interpolator, **params)
175+
elif name == "JOINT_VELOCITY_LEGACY":
176+
return mobile_base_controllers.LegacyMobileBaseJointVelocityController(interpolator=interpolator, **params)
175177
elif name == "JOINT_POSITION":
176178
raise NotImplementedError
177179
raise ValueError("Unknown controller name: {}".format(name))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .joint_vel import MobileBaseJointVelocityController
1+
from .joint_vel import MobileBaseJointVelocityController, LegacyMobileBaseJointVelocityController

robosuite/controllers/parts/mobile_base/joint_vel.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,85 @@ def control_limits(self):
299299
@property
300300
def name(self):
301301
return "JOINT_VELOCITY"
302+
303+
304+
class LegacyMobileBaseJointVelocityController(MobileBaseJointVelocityController):
305+
"""
306+
Legacy version of MobileBaseJointVelocityController, created to address
307+
the recent change in the axis of the forward joint in the mobile base xml.
308+
This controller is identical to the original MobileBaseJointVelocityController,
309+
except that it dynamically checks the axis of the forward joint and reorders
310+
the input action accordingly if the forward axis is the y axis instead of the x axis.
311+
This allows for backwards compatibility with previously collected datasets
312+
that were generated using older versions of the mobile base xml.
313+
"""
314+
315+
def __init__(self, *args, **kwargs):
316+
super().__init__(*args, **kwargs)
317+
318+
def _check_forward_joint_reversed(self):
319+
# Detect the axis for the forward joint and dynamically reorder action accordingly.
320+
# This is needed because previous versions of the mobile base xml had different forward
321+
# axis definitions. In order to maintain backwards compatibility with previous datasets
322+
# we dynamically detect the forward joint axis.
323+
forward_jnt = None
324+
forward_jnt_axis = None
325+
for jnt in self.joint_names:
326+
if "joint_mobile_forward" in jnt:
327+
forward_jnt = jnt
328+
forward_jnt_axis = self.sim.model.jnt_axis[self.sim.model.joint_name2id(jnt)]
329+
break
330+
return forward_jnt is not None and (forward_jnt_axis == np.array([0, 1, 0])).all()
331+
332+
def set_goal(self, action, set_qpos=None):
333+
# Update state
334+
self.update()
335+
336+
# Parse action based on the impedance mode, and update kp / kd as necessary
337+
jnt_dim = len(self.qpos_index)
338+
if self.impedance_mode == "variable":
339+
damping_ratio, kp, delta = action[:jnt_dim], action[jnt_dim : 2 * jnt_dim], action[2 * jnt_dim :]
340+
self.kp = np.clip(kp, self.kp_min, self.kp_max)
341+
self.kd = 2 * np.sqrt(self.kp) * np.clip(damping_ratio, self.damping_ratio_min, self.damping_ratio_max)
342+
elif self.impedance_mode == "variable_kp":
343+
kp, delta = action[:jnt_dim], action[jnt_dim:]
344+
self.kp = np.clip(kp, self.kp_min, self.kp_max)
345+
self.kd = 2 * np.sqrt(self.kp) # critically damped
346+
else: # This is case "fixed"
347+
delta = action
348+
349+
# Check to make sure delta is size self.joint_dim
350+
assert len(delta) == jnt_dim, "Delta qpos must be equal to the robot's joint dimension space!"
351+
352+
if delta is not None:
353+
scaled_delta = self.scale_action(delta)
354+
else:
355+
scaled_delta = None
356+
357+
curr_pos, curr_ori = self.get_base_pose()
358+
359+
# transform the action relative to initial base orientation
360+
init_theta = T.mat2euler(self.init_ori)[2] # np.arctan2(self.init_pos[1], self.init_pos[0])
361+
curr_theta = T.mat2euler(curr_ori)[2] # np.arctan2(curr_pos[1], curr_pos[0])
362+
theta = curr_theta - init_theta
363+
364+
# reorder action if forward axis is y axis
365+
if self._check_forward_joint_reversed():
366+
action = np.copy([action[i] for i in [1, 0, 2]])
367+
368+
x, y = action[0:2]
369+
# do the reverse of theta rotation
370+
action[0] = x * np.cos(theta) + y * np.sin(theta)
371+
action[1] = -x * np.sin(theta) + y * np.cos(theta)
372+
else:
373+
# input raw base action is delta relative to current pose of base
374+
# controller expects deltas relative to initial pose of base at start of episode
375+
# transform deltas from current base pose coordinates to initial base pose coordinates
376+
action = action.copy()
377+
x, y = action[0:2]
378+
action[0] = x * np.cos(theta) - y * np.sin(theta)
379+
action[1] = x * np.sin(theta) + y * np.cos(theta)
380+
381+
self.goal_qvel = action
382+
if self.interpolator is not None:
383+
self.interpolator.set_goal(self.goal_qvel)

robosuite/devices/device.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def __init__(self, env):
2121
using this device.
2222
"""
2323
self.env = env
24-
self.all_robot_arms = [robot.arms for robot in self.env.robots]
25-
self.num_robots = len(self.all_robot_arms)
24+
self._all_robot_arms = None
2625

2726
def _reset_internal_state(self):
2827
"""
@@ -36,6 +35,22 @@ def _reset_internal_state(self):
3635
self._prev_target = {arm: None for arm in self.all_robot_arms[self.active_robot]}
3736
self._prev_torso_target = None
3837

38+
@property
39+
def all_robot_arms(self):
40+
robots = getattr(self.env, "robots", None)
41+
assert robots is not None and all(r is not None for r in robots), (
42+
"Environment has not robots to control. "
43+
"Please make sure to initialize the environment and call "
44+
"reset() before using the device."
45+
)
46+
if self._all_robot_arms is None:
47+
self._all_robot_arms = [robot.arms for robot in self.env.robots]
48+
return self._all_robot_arms
49+
50+
@property
51+
def num_robots(self):
52+
return len(self.all_robot_arms)
53+
3954
@property
4055
def active_arm(self):
4156
return self.all_robot_arms[self.active_robot][self.active_arm_index]

robosuite/devices/keyboard.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(self, env, pos_sensitivity=1.0, rot_sensitivity=1.0):
2323
super().__init__(env)
2424

2525
self._display_controls()
26-
self._reset_internal_state()
2726

2827
self._reset_state = 0
2928
self._enabled = False

robosuite/environments/base.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class MujocoEnv(metaclass=EnvMeta):
8282
ignore_done (bool): True if never terminating the environment (ignore @horizon).
8383
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
8484
only calls sim.reset and resets all robosuite-internal variables
85+
load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
86+
else, initialize these components in the first call to reset()
8587
renderer (str): string for the renderer to use
8688
renderer_config (dict): dictionary for the renderer configurations
8789
seed (int): environment seed. Default is None, where environment is unseeded, ie. random
@@ -102,6 +104,7 @@ def __init__(
102104
horizon=1000,
103105
ignore_done=False,
104106
hard_reset=True,
107+
load_model_on_init=True,
105108
renderer="mjviewer",
106109
renderer_config=None,
107110
seed=None,
@@ -143,31 +146,44 @@ def __init__(
143146

144147
self._ep_meta = {}
145148

146-
# Load the model
147-
self._load_model()
149+
self.load_model_on_init = load_model_on_init
148150

149-
# Initialize the simulation
150-
self._initialize_sim()
151+
# variable to keep track of whether the env has been fully initialized
152+
self._env_is_initialized = False
151153

152-
# initializes the rendering
153-
self.initialize_renderer()
154+
if self.load_model_on_init:
155+
# Load the model
156+
self._load_model()
154157

155-
# the variables will be set later.
156-
# need to set to None, in case these variables are referenced before being set
157-
self.viewer = None
158-
self.viewer_get_obs = None
158+
# Initialize the simulation
159+
self._initialize_sim()
159160

160-
# Run all further internal (re-)initialization required
161-
self._reset_internal()
161+
# initializes the rendering
162+
self.initialize_renderer()
162163

163-
# Load observables
164-
if hasattr(self.viewer, "_setup_observables"):
165-
self._observables = self.viewer._setup_observables()
166-
else:
167-
self._observables = self._setup_observables()
164+
# the variables will be set later.
165+
# need to set to None, in case these variables are referenced before being set
166+
self.viewer = None
167+
self.viewer_get_obs = None
168168

169-
# check if viewer has get observations method and set a flag for future use.
170-
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
169+
# Run all further internal (re-)initialization required
170+
self._reset_internal()
171+
172+
# Load observables
173+
if hasattr(self.viewer, "_setup_observables"):
174+
self._observables = self.viewer._setup_observables()
175+
else:
176+
self._observables = self._setup_observables()
177+
178+
# check if viewer has get observations method and set a flag for future use.
179+
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
180+
self._env_is_initialized = True
181+
else:
182+
# the variables will be set later.
183+
# need to set to None, in case these variables are referenced before being set
184+
self.sim = None
185+
self.viewer = None
186+
self.viewer_get_obs = None
171187

172188
def initialize_renderer(self):
173189
self.renderer = self.renderer.lower()
@@ -271,7 +287,7 @@ def reset(self):
271287
if self.renderer == "mjviewer":
272288
self._destroy_viewer()
273289

274-
if self.hard_reset and not self.deterministic_reset:
290+
if (self.sim is None) or (self.hard_reset and not self.deterministic_reset):
275291
if self.renderer == "mujoco":
276292
self._destroy_viewer()
277293
self._destroy_sim()
@@ -281,9 +297,33 @@ def reset(self):
281297
else:
282298
self.sim.reset()
283299

284-
# Reset necessary robosuite-centric variables
285-
self._reset_internal()
286-
self.sim.forward()
300+
if self._env_is_initialized is True:
301+
# Reset necessary robosuite-centric variables
302+
self._reset_internal()
303+
self.sim.forward()
304+
else:
305+
# initializes the rendering
306+
self.initialize_renderer()
307+
308+
# the variables will be set later.
309+
# need to set to None, in case these variables are referenced before being set
310+
self.viewer = None
311+
self.viewer_get_obs = None
312+
313+
# Run all further internal (re-)initialization required
314+
self._reset_internal()
315+
self.sim.forward()
316+
317+
# Load observables
318+
if hasattr(self.viewer, "_setup_observables"):
319+
self._observables = self.viewer._setup_observables()
320+
else:
321+
self._observables = self._setup_observables()
322+
323+
# check if viewer has get observations method and set a flag for future use.
324+
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
325+
self._env_is_initialized = True
326+
287327
# Setup observables, reloading if
288328
self._obs_cache = {}
289329
self._reset_observables()

robosuite/environments/manipulation/manipulation_env.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class ManipulationEnv(RobotEnv):
8484
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
8585
only calls sim.reset and resets all robosuite-internal variables
8686
87+
load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
88+
else, initialize these components in the first call to reset()
89+
8790
camera_names (str or list of str): name of camera to be rendered. Should either be single str if
8891
same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
8992
@@ -144,6 +147,7 @@ def __init__(
144147
horizon=1000,
145148
ignore_done=False,
146149
hard_reset=True,
150+
load_model_on_init=True,
147151
camera_names="agentview",
148152
camera_heights=256,
149153
camera_widths=256,
@@ -187,6 +191,7 @@ def __init__(
187191
horizon=horizon,
188192
ignore_done=ignore_done,
189193
hard_reset=hard_reset,
194+
load_model_on_init=load_model_on_init,
190195
camera_names=camera_names,
191196
camera_heights=camera_heights,
192197
camera_widths=camera_widths,

robosuite/environments/robot_env.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class RobotEnv(MujocoEnv):
7878
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
7979
only calls sim.reset and resets all robosuite-internal variables
8080
81+
load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
82+
else, initialize these components in the first call to reset()
83+
8184
camera_names (str or list of str): name of camera to be rendered. Should either be single str if
8285
same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
8386
@@ -139,6 +142,7 @@ def __init__(
139142
horizon=1000,
140143
ignore_done=False,
141144
hard_reset=True,
145+
load_model_on_init=True,
142146
camera_names="agentview",
143147
camera_heights=256,
144148
camera_widths=256,
@@ -230,6 +234,7 @@ def __init__(
230234
horizon=horizon,
231235
ignore_done=ignore_done,
232236
hard_reset=hard_reset,
237+
load_model_on_init=load_model_on_init,
233238
renderer=renderer,
234239
renderer_config=renderer_config,
235240
seed=seed,

robosuite/models/tasks/task.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class Task(MujocoWorldBase):
5151
5252
mujoco_objects (None or MujocoObject or list of MujocoObject): a list of MJCF models of physical objects
5353
54+
enable_multiccd (bool) whether to set the multiccd flag in MuJoCo. False by default
55+
5456
Raises:
5557
AssertionError: [Invalid input object type]
5658
"""
@@ -60,8 +62,10 @@ def __init__(
6062
mujoco_arena,
6163
mujoco_robots,
6264
mujoco_objects=None,
65+
enable_multiccd=False,
66+
enable_sleeping_islands=False,
6367
):
64-
super().__init__()
68+
super().__init__(enable_multiccd=enable_multiccd, enable_sleeping_islands=enable_sleeping_islands)
6569

6670
# Store references to all models
6771
self.mujoco_arena = mujoco_arena

robosuite/models/world.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import xml.etree.ElementTree as ET
2+
13
import robosuite.macros as macros
24
from robosuite.models.base import MujocoXML
35
from robosuite.utils.mjcf_utils import convert_to_string, find_elements, xml_path_completion
@@ -6,8 +8,18 @@
68
class MujocoWorldBase(MujocoXML):
79
"""Base class to inherit all mujoco worlds from."""
810

9-
def __init__(self):
11+
def __init__(self, enable_multiccd=False, enable_sleeping_islands=False):
1012
super().__init__(xml_path_completion("base.xml"))
1113
# Modify the simulation timestep to be the requested value
1214
options = find_elements(root=self.root, tags="option", attribs=None, return_first=True)
1315
options.set("timestep", convert_to_string(macros.SIMULATION_TIMESTEP))
16+
self.enable_multiccd = enable_multiccd
17+
self.enable_sleeping_islands = enable_sleeping_islands
18+
if self.enable_multiccd:
19+
multiccd_elem = ET.fromstring("""<option> <flag multiccd="enable"/> </option>""")
20+
mujoco_elem = find_elements(self.root, "mujoco")
21+
mujoco_elem.insert(0, multiccd_elem)
22+
if self.enable_sleeping_islands:
23+
sleeping_elem = ET.fromstring("""<option> <flag sleep="enable"/> </option>""")
24+
mujoco_elem = find_elements(self.root, "mujoco")
25+
mujoco_elem.insert(0, sleeping_elem)

0 commit comments

Comments
 (0)