Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions robosuite/controllers/parts/controller_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def mobile_base_controller_factory(name, params):
interpolator = None
if name == "JOINT_VELOCITY":
return mobile_base_controllers.MobileBaseJointVelocityController(interpolator=interpolator, **params)
elif name == "JOINT_VELOCITY_LEGACY":
return mobile_base_controllers.LegacyMobileBaseJointVelocityController(interpolator=interpolator, **params)
elif name == "JOINT_POSITION":
raise NotImplementedError
raise ValueError("Unknown controller name: {}".format(name))
Expand Down
2 changes: 1 addition & 1 deletion robosuite/controllers/parts/mobile_base/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .joint_vel import MobileBaseJointVelocityController
from .joint_vel import MobileBaseJointVelocityController, LegacyMobileBaseJointVelocityController
82 changes: 82 additions & 0 deletions robosuite/controllers/parts/mobile_base/joint_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,85 @@ def control_limits(self):
@property
def name(self):
return "JOINT_VELOCITY"


class LegacyMobileBaseJointVelocityController(MobileBaseJointVelocityController):
"""
Legacy version of MobileBaseJointVelocityController, created to address
the recent change in the axis of the forward joint in the mobile base xml.
This controller is identical to the original MobileBaseJointVelocityController,
except that it dynamically checks the axis of the forward joint and reorders
the input action accordingly if the forward axis is the y axis instead of the x axis.
This allows for backwards compatibility with previously collected datasets
that were generated using older versions of the mobile base xml.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _check_forward_joint_reversed(self):
# Detect the axis for the forward joint and dynamically reorder action accordingly.
# This is needed because previous versions of the mobile base xml had different forward
# axis definitions. In order to maintain backwards compatibility with previous datasets
# we dynamically detect the forward joint axis.
forward_jnt = None
forward_jnt_axis = None
for jnt in self.joint_names:
if "joint_mobile_forward" in jnt:
forward_jnt = jnt
forward_jnt_axis = self.sim.model.jnt_axis[self.sim.model.joint_name2id(jnt)]
break
return forward_jnt is not None and (forward_jnt_axis == np.array([0, 1, 0])).all()

def set_goal(self, action, set_qpos=None):
# Update state
self.update()

# Parse action based on the impedance mode, and update kp / kd as necessary
jnt_dim = len(self.qpos_index)
if self.impedance_mode == "variable":
damping_ratio, kp, delta = action[:jnt_dim], action[jnt_dim : 2 * jnt_dim], action[2 * jnt_dim :]
self.kp = np.clip(kp, self.kp_min, self.kp_max)
self.kd = 2 * np.sqrt(self.kp) * np.clip(damping_ratio, self.damping_ratio_min, self.damping_ratio_max)
elif self.impedance_mode == "variable_kp":
kp, delta = action[:jnt_dim], action[jnt_dim:]
self.kp = np.clip(kp, self.kp_min, self.kp_max)
self.kd = 2 * np.sqrt(self.kp) # critically damped
else: # This is case "fixed"
delta = action

# Check to make sure delta is size self.joint_dim
assert len(delta) == jnt_dim, "Delta qpos must be equal to the robot's joint dimension space!"

if delta is not None:
scaled_delta = self.scale_action(delta)
else:
scaled_delta = None

curr_pos, curr_ori = self.get_base_pose()

# transform the action relative to initial base orientation
init_theta = T.mat2euler(self.init_ori)[2] # np.arctan2(self.init_pos[1], self.init_pos[0])
curr_theta = T.mat2euler(curr_ori)[2] # np.arctan2(curr_pos[1], curr_pos[0])
theta = curr_theta - init_theta

# reorder action if forward axis is y axis
if self._check_forward_joint_reversed():
action = np.copy([action[i] for i in [1, 0, 2]])

x, y = action[0:2]
# do the reverse of theta rotation
action[0] = x * np.cos(theta) + y * np.sin(theta)
action[1] = -x * np.sin(theta) + y * np.cos(theta)
else:
# input raw base action is delta relative to current pose of base
# controller expects deltas relative to initial pose of base at start of episode
# transform deltas from current base pose coordinates to initial base pose coordinates
action = action.copy()
x, y = action[0:2]
action[0] = x * np.cos(theta) - y * np.sin(theta)
action[1] = x * np.sin(theta) + y * np.cos(theta)

self.goal_qvel = action
if self.interpolator is not None:
self.interpolator.set_goal(self.goal_qvel)
19 changes: 17 additions & 2 deletions robosuite/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def __init__(self, env):
using this device.
"""
self.env = env
self.all_robot_arms = [robot.arms for robot in self.env.robots]
self.num_robots = len(self.all_robot_arms)
self._all_robot_arms = None

def _reset_internal_state(self):
"""
Expand All @@ -36,6 +35,22 @@ def _reset_internal_state(self):
self._prev_target = {arm: None for arm in self.all_robot_arms[self.active_robot]}
self._prev_torso_target = None

@property
def all_robot_arms(self):
robots = getattr(self.env, "robots", None)
assert robots is not None and all(r is not None for r in robots), (
"Environment has not robots to control. "
"Please make sure to initialize the environment and call "
"reset() before using the device."
)
if self._all_robot_arms is None:
self._all_robot_arms = [robot.arms for robot in self.env.robots]
return self._all_robot_arms

@property
def num_robots(self):
return len(self.all_robot_arms)

@property
def active_arm(self):
return self.all_robot_arms[self.active_robot][self.active_arm_index]
Expand Down
1 change: 0 additions & 1 deletion robosuite/devices/keyboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, env, pos_sensitivity=1.0, rot_sensitivity=1.0):
super().__init__(env)

self._display_controls()
self._reset_internal_state()

self._reset_state = 0
self._enabled = False
Expand Down
86 changes: 63 additions & 23 deletions robosuite/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class MujocoEnv(metaclass=EnvMeta):
ignore_done (bool): True if never terminating the environment (ignore @horizon).
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
only calls sim.reset and resets all robosuite-internal variables
load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
else, initialize these components in the first call to reset()
renderer (str): string for the renderer to use
renderer_config (dict): dictionary for the renderer configurations
seed (int): environment seed. Default is None, where environment is unseeded, ie. random
Expand All @@ -102,6 +104,7 @@ def __init__(
horizon=1000,
ignore_done=False,
hard_reset=True,
load_model_on_init=True,
renderer="mjviewer",
renderer_config=None,
seed=None,
Expand Down Expand Up @@ -143,31 +146,44 @@ def __init__(

self._ep_meta = {}

# Load the model
self._load_model()
self.load_model_on_init = load_model_on_init

# Initialize the simulation
self._initialize_sim()
# variable to keep track of whether the env has been fully initialized
self._env_is_initialized = False

# initializes the rendering
self.initialize_renderer()
if self.load_model_on_init:
# Load the model
self._load_model()

# the variables will be set later.
# need to set to None, in case these variables are referenced before being set
self.viewer = None
self.viewer_get_obs = None
# Initialize the simulation
self._initialize_sim()

# Run all further internal (re-)initialization required
self._reset_internal()
# initializes the rendering
self.initialize_renderer()

# Load observables
if hasattr(self.viewer, "_setup_observables"):
self._observables = self.viewer._setup_observables()
else:
self._observables = self._setup_observables()
# the variables will be set later.
# need to set to None, in case these variables are referenced before being set
self.viewer = None
self.viewer_get_obs = None

# check if viewer has get observations method and set a flag for future use.
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
# Run all further internal (re-)initialization required
self._reset_internal()

# Load observables
if hasattr(self.viewer, "_setup_observables"):
self._observables = self.viewer._setup_observables()
else:
self._observables = self._setup_observables()

# check if viewer has get observations method and set a flag for future use.
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
self._env_is_initialized = True
else:
# the variables will be set later.
# need to set to None, in case these variables are referenced before being set
self.sim = None
self.viewer = None
self.viewer_get_obs = None

def initialize_renderer(self):
self.renderer = self.renderer.lower()
Expand Down Expand Up @@ -271,7 +287,7 @@ def reset(self):
if self.renderer == "mjviewer":
self._destroy_viewer()

if self.hard_reset and not self.deterministic_reset:
if (self.sim is None) or (self.hard_reset and not self.deterministic_reset):
if self.renderer == "mujoco":
self._destroy_viewer()
self._destroy_sim()
Expand All @@ -281,9 +297,33 @@ def reset(self):
else:
self.sim.reset()

# Reset necessary robosuite-centric variables
self._reset_internal()
self.sim.forward()
if self._env_is_initialized is True:
# Reset necessary robosuite-centric variables
self._reset_internal()
self.sim.forward()
else:
# initializes the rendering
self.initialize_renderer()

# the variables will be set later.
# need to set to None, in case these variables are referenced before being set
self.viewer = None
self.viewer_get_obs = None

# Run all further internal (re-)initialization required
self._reset_internal()
self.sim.forward()

# Load observables
if hasattr(self.viewer, "_setup_observables"):
self._observables = self.viewer._setup_observables()
else:
self._observables = self._setup_observables()

# check if viewer has get observations method and set a flag for future use.
self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
self._env_is_initialized = True

# Setup observables, reloading if
self._obs_cache = {}
self._reset_observables()
Expand Down
5 changes: 5 additions & 0 deletions robosuite/environments/manipulation/manipulation_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class ManipulationEnv(RobotEnv):
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
only calls sim.reset and resets all robosuite-internal variables

load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
else, initialize these components in the first call to reset()

camera_names (str or list of str): name of camera to be rendered. Should either be single str if
same name is to be used for all cameras' rendering or else it should be a list of cameras to render.

Expand Down Expand Up @@ -144,6 +147,7 @@ def __init__(
horizon=1000,
ignore_done=False,
hard_reset=True,
load_model_on_init=True,
camera_names="agentview",
camera_heights=256,
camera_widths=256,
Expand Down Expand Up @@ -187,6 +191,7 @@ def __init__(
horizon=horizon,
ignore_done=ignore_done,
hard_reset=hard_reset,
load_model_on_init=load_model_on_init,
camera_names=camera_names,
camera_heights=camera_heights,
camera_widths=camera_widths,
Expand Down
5 changes: 5 additions & 0 deletions robosuite/environments/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class RobotEnv(MujocoEnv):
hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
only calls sim.reset and resets all robosuite-internal variables

load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
else, initialize these components in the first call to reset()

camera_names (str or list of str): name of camera to be rendered. Should either be single str if
same name is to be used for all cameras' rendering or else it should be a list of cameras to render.

Expand Down Expand Up @@ -139,6 +142,7 @@ def __init__(
horizon=1000,
ignore_done=False,
hard_reset=True,
load_model_on_init=True,
camera_names="agentview",
camera_heights=256,
camera_widths=256,
Expand Down Expand Up @@ -230,6 +234,7 @@ def __init__(
horizon=horizon,
ignore_done=ignore_done,
hard_reset=hard_reset,
load_model_on_init=load_model_on_init,
renderer=renderer,
renderer_config=renderer_config,
seed=seed,
Expand Down
6 changes: 5 additions & 1 deletion robosuite/models/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class Task(MujocoWorldBase):

mujoco_objects (None or MujocoObject or list of MujocoObject): a list of MJCF models of physical objects

enable_multiccd (bool) whether to set the multiccd flag in MuJoCo. False by default

Raises:
AssertionError: [Invalid input object type]
"""
Expand All @@ -60,8 +62,10 @@ def __init__(
mujoco_arena,
mujoco_robots,
mujoco_objects=None,
enable_multiccd=False,
enable_sleeping_islands=False,
):
super().__init__()
super().__init__(enable_multiccd=enable_multiccd, enable_sleeping_islands=enable_sleeping_islands)

# Store references to all models
self.mujoco_arena = mujoco_arena
Expand Down
14 changes: 13 additions & 1 deletion robosuite/models/world.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import xml.etree.ElementTree as ET

import robosuite.macros as macros
from robosuite.models.base import MujocoXML
from robosuite.utils.mjcf_utils import convert_to_string, find_elements, xml_path_completion
Expand All @@ -6,8 +8,18 @@
class MujocoWorldBase(MujocoXML):
"""Base class to inherit all mujoco worlds from."""

def __init__(self):
def __init__(self, enable_multiccd=False, enable_sleeping_islands=False):
super().__init__(xml_path_completion("base.xml"))
# Modify the simulation timestep to be the requested value
options = find_elements(root=self.root, tags="option", attribs=None, return_first=True)
options.set("timestep", convert_to_string(macros.SIMULATION_TIMESTEP))
self.enable_multiccd = enable_multiccd
self.enable_sleeping_islands = enable_sleeping_islands
if self.enable_multiccd:
multiccd_elem = ET.fromstring("""<option> <flag multiccd="enable"/> </option>""")
mujoco_elem = find_elements(self.root, "mujoco")
mujoco_elem.insert(0, multiccd_elem)
if self.enable_sleeping_islands:
sleeping_elem = ET.fromstring("""<option> <flag sleep="enable"/> </option>""")
mujoco_elem = find_elements(self.root, "mujoco")
mujoco_elem.insert(0, sleeping_elem)
Loading