diff --git a/robosuite/models/arenas/arena.py b/robosuite/models/arenas/arena.py index ff775eca2a..615317e9da 100644 --- a/robosuite/models/arenas/arena.py +++ b/robosuite/models/arenas/arena.py @@ -1,3 +1,5 @@ +from typing import List, Union + import numpy as np from robosuite.models.base import MujocoXML @@ -22,6 +24,7 @@ def __init__(self, fname): # Get references to floor and bottom self.bottom_pos = np.zeros(3) self.floor = self.worldbody.find("./geom[@name='floor']") + self.object_scales = {} # Add mocap bodies to self.root for mocap control in mjviewer UI for robot control mocap_body_1 = new_body(name="left_eef_target", pos="0 0 -1", mocap=True) @@ -129,3 +132,145 @@ def _postprocess_arena(self): Runs any necessary post-processing on the imported Arena model """ pass + + def _get_geoms(self, root, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a geom element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is a geom type + """ + return self._get_elements(root, "geom", _parent) + + def _get_elements(self, root, type, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a specific type of element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is of type + """ + # Initialize return array + elem_pairs = [] + # If the parent exists and this is a desired element, we add this current (parent, element) combo to the output + if _parent is not None and root.tag == type: + elem_pairs.append((_parent, root)) + # Loop through all children elements recursively and add to pairs + for child in root: + elem_pairs += self._get_elements(child, type, _parent=root) + + # Return all found pairs + return elem_pairs + + def set_scale(self, scale: Union[float, List[float]], obj_name: str): + """ + Scales each geom, mesh, site, and body under obj_name. + Called during initialization but can also be used externally + + Args: + scale (float or list of floats): Scale factor (1 or 3 dims) + obj_name Name of root object to apply. + """ + obj = self.worldbody.find(f"./body[@name='{obj_name}']") + if obj is None: + bodies = self.worldbody.findall("./body") + body_names = [body.get("name") for body in bodies if body.get("name") is not None] + raise ValueError(f"Object {obj_name} not found in arena; cannot set scale. Available objects: {body_names}") + self.object_scales[obj.get("name")] = scale + + # scale geoms + geom_pairs = self._get_geoms(obj) + for _, (_, element) in enumerate(geom_pairs): + g_pos = element.get("pos") + g_size = element.get("size") + if g_pos is not None: + g_pos = array_to_string(string_to_array(g_pos) * scale) + element.set("pos", g_pos) + if g_size is not None: + g_size_np = string_to_array(g_size) + # handle cases where size is not 3 dimensional + if len(g_size_np) == 3: + g_size_np = g_size_np * scale + elif len(g_size_np) == 2: + scale = np.array(scale).reshape(-1) + if len(scale) == 1: + g_size_np[1] *= scale + elif len(scale) == 3: + # g_size_np[0] *= np.mean(scale[:2]) + g_size_np[0] *= np.mean(scale[:2]) # width + g_size_np[1] *= scale[2] # height + else: + raise ValueError + else: + raise ValueError + g_size = array_to_string(g_size_np) + element.set("size", g_size) + + # scale meshes + meshes = self.asset.findall("mesh") + for elem in meshes: + m_scale = elem.get("scale") + if m_scale is not None: + m_scale = string_to_array(m_scale) + else: + m_scale = np.ones(3) + + m_scale *= scale + elem.set("scale", array_to_string(m_scale)) + + # scale bodies + body_pairs = self._get_elements(obj, "body") + for (_, elem) in body_pairs: + b_pos = elem.get("pos") + if b_pos is not None: + b_pos = string_to_array(b_pos) * scale + elem.set("pos", array_to_string(b_pos)) + + # scale joints + joint_pairs = self._get_elements(obj, "joint") + for (_, elem) in joint_pairs: + j_pos = elem.get("pos") + if j_pos is not None: + j_pos = string_to_array(j_pos) * scale + elem.set("pos", array_to_string(j_pos)) + + # scale sites + site_pairs = self._get_elements(self.worldbody, "site") + for (_, elem) in site_pairs: + s_pos = elem.get("pos") + if s_pos is not None: + s_pos = string_to_array(s_pos) * scale + elem.set("pos", array_to_string(s_pos)) + + s_size = elem.get("size") + if s_size is not None: + s_size_np = string_to_array(s_size) + # handle cases where size is not 3 dimensional + if len(s_size_np) == 3: + s_size_np = s_size_np * scale + elif len(s_size_np) == 2: + scale = np.array(scale).reshape(-1) + if len(scale) == 1: + s_size_np *= scale + elif len(scale) == 3: + s_size_np[0] *= np.mean(scale[:2]) # width + s_size_np[1] *= scale[2] # height + else: + raise ValueError + elif len(s_size_np) == 1: + s_size_np *= np.mean(scale) + else: + raise ValueError + s_size = array_to_string(s_size_np) + elem.set("size", s_size) diff --git a/robosuite/models/objects/objects.py b/robosuite/models/objects/objects.py index e7681f8f39..639186658e 100644 --- a/robosuite/models/objects/objects.py +++ b/robosuite/models/objects/objects.py @@ -95,6 +95,157 @@ def get_obj(self): assert self._obj is not None, "Object XML tree has not been generated yet!" return self._obj + def set_scale(self, scale, obj=None): + """ + Scales each geom, mesh, site, body, and joint ranges (for slide joints). + Called during initialization but can also be used externally. + Args: + scale (float or list of floats): Scale factor (1 or 3 dims) + obj (ET.Element): Root object to apply scaling to. Defaults to root object of model. + """ + if obj is None: + obj = self._obj + + self._scale = scale + + # Ensure scale is an array of length 3 + scale_array = np.array(self._scale).flatten() + if scale_array.size == 1: + scale_array = np.repeat(scale_array, 3) + elif scale_array.size != 3: + raise ValueError("Scale must be a scalar or a 3-element array.") + + geom_pairs = self._get_geoms(obj) + for _, (_, element) in enumerate(geom_pairs): + g_pos = element.get("pos") + g_size = element.get("size") + if g_pos is not None: + g_pos = array_to_string(string_to_array(g_pos) * scale_array) + element.set("pos", g_pos) + if g_size is not None: + g_size_np = string_to_array(g_size) + # Handle cases where size is not 3-dimensional + if len(g_size_np) == 3: + g_size_np = g_size_np * scale_array + elif len(g_size_np) == 2: + # For 2D size, assume [radius, height] for cylinders + g_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y + g_size_np[1] *= scale_array[2] # Scaling in z + else: + raise ValueError("Unsupported geom size dimensions.") + g_size = array_to_string(g_size_np) + element.set("size", g_size) + + # Scale meshes + meshes = self.asset.findall("mesh") + for elem in meshes: + m_scale = elem.get("scale") + if m_scale is not None: + m_scale = string_to_array(m_scale) + else: + m_scale = np.ones(3) + m_scale *= scale_array + elem.set("scale", array_to_string(m_scale)) + + # Scale bodies + body_pairs = self._get_elements(obj, "body") + for (_, elem) in body_pairs: + b_pos = elem.get("pos") + if b_pos is not None: + b_pos = string_to_array(b_pos) * scale_array + elem.set("pos", array_to_string(b_pos)) + + # Scale joints + joint_pairs = self._get_elements(obj, "joint") + for (_, elem) in joint_pairs: + j_pos = elem.get("pos") + if j_pos is not None: + j_pos = string_to_array(j_pos) * scale_array + elem.set("pos", array_to_string(j_pos)) + + # Scale joint ranges for slide joints + j_type = elem.get("type", "hinge") # Default joint type is 'hinge' if not specified + j_range = elem.get("range") + if j_range is not None and j_type == "slide": + # Get joint axis + j_axis = elem.get("axis", "1 0 0") # Default axis is [1, 0, 0] + j_axis = string_to_array(j_axis) + axis_norm = np.linalg.norm(j_axis) + if axis_norm > 0: + axis_unit = j_axis / axis_norm + else: + # Avoid division by zero + axis_unit = np.array([1.0, 0.0, 0.0]) + # Compute scaling factor along the joint axis + s = np.linalg.norm(axis_unit * scale_array) + # Scale the range + j_range_vals = string_to_array(j_range) + j_range_vals = j_range_vals * s + elem.set("range", array_to_string(j_range_vals)) + + # Scale sites + site_pairs = self._get_elements(obj, "site") + for (_, elem) in site_pairs: + s_pos = elem.get("pos") + if s_pos is not None: + s_pos = string_to_array(s_pos) * scale_array + elem.set("pos", array_to_string(s_pos)) + + s_size = elem.get("size") + if s_size is not None: + s_size_np = string_to_array(s_size) + if len(s_size_np) == 3: + s_size_np = s_size_np * scale_array + elif len(s_size_np) == 2: + s_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y + s_size_np[1] *= scale_array[2] # Scaling in z + elif len(s_size_np) == 1: + s_size_np *= np.mean(scale_array) + else: + raise ValueError("Unsupported site size dimensions.") + s_size = array_to_string(s_size_np) + elem.set("size", s_size) + + def _get_geoms(self, root, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a geom element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is a geom type + """ + return self._get_elements(root, "geom", _parent) + + def _get_elements(self, root, type, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a specific type of element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is of type + """ + # Initialize return array + elem_pairs = [] + # If the parent exists and this is a desired element, we add this current (parent, element) combo to the output + if _parent is not None and root.tag == type: + elem_pairs.append((_parent, root)) + # Loop through all children elements recursively and add to pairs + for child in root: + elem_pairs += self._get_elements(child, type, _parent=root) + + # Return all found pairs + return elem_pairs + def exclude_from_prefixing(self, inp): """ A function that should take in either an ET.Element or its attribute (str) and return either True or False,