diff --git a/robosuite/models/arenas/arena.py b/robosuite/models/arenas/arena.py index ff775eca2a..4aaee3096c 100644 --- a/robosuite/models/arenas/arena.py +++ b/robosuite/models/arenas/arena.py @@ -1,3 +1,5 @@ +from typing import List, Union + import numpy as np from robosuite.models.base import MujocoXML @@ -10,6 +12,7 @@ new_geom, new_joint, recolor_collision_geoms, + scale_mjcf_model, string_to_array, ) @@ -22,6 +25,7 @@ def __init__(self, fname): # Get references to floor and bottom self.bottom_pos = np.zeros(3) self.floor = self.worldbody.find("./geom[@name='floor']") + self.object_scales = {} # Add mocap bodies to self.root for mocap control in mjviewer UI for robot control mocap_body_1 = new_body(name="left_eef_target", pos="0 0 -1", mocap=True) @@ -129,3 +133,69 @@ def _postprocess_arena(self): Runs any necessary post-processing on the imported Arena model """ pass + + def _get_geoms(self, root, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a geom element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is a geom type + """ + return self._get_elements(root, "geom", _parent) + + def _get_elements(self, root, type, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a specific type of element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is of type + """ + # Initialize return array + elem_pairs = [] + # If the parent exists and this is a desired element, we add this current (parent, element) combo to the output + if _parent is not None and root.tag == type: + elem_pairs.append((_parent, root)) + # Loop through all children elements recursively and add to pairs + for child in root: + elem_pairs += self._get_elements(child, type, _parent=root) + + # Return all found pairs + return elem_pairs + + def set_scale(self, scale: Union[float, List[float]], obj_name: str): + """ + Scales each geom, mesh, site, and body under obj_name. + Called during initialization but can also be used externally + + Args: + scale (float or list of floats): Scale factor (1 or 3 dims) + obj_name Name of root object to apply. + """ + obj = self.worldbody.find(f"./body[@name='{obj_name}']") + if obj is None: + bodies = self.worldbody.findall("./body") + body_names = [body.get("name") for body in bodies if body.get("name") is not None] + raise ValueError(f"Object {obj_name} not found in arena; cannot set scale. Available objects: {body_names}") + self.object_scales[obj.get("name")] = scale + + # Use the centralized scaling utility function + scale_mjcf_model( + obj=obj, + asset_root=self.asset, + scale=scale, + get_elements_func=self._get_elements, + get_geoms_func=self._get_geoms, + scale_slide_joints=False, # Arena doesn't handle slide joints + ) diff --git a/robosuite/models/objects/objects.py b/robosuite/models/objects/objects.py index e7681f8f39..cf1df260e3 100644 --- a/robosuite/models/objects/objects.py +++ b/robosuite/models/objects/objects.py @@ -14,6 +14,7 @@ array_to_string, find_elements, new_joint, + scale_mjcf_model, sort_elements, string_to_array, ) @@ -95,6 +96,69 @@ def get_obj(self): assert self._obj is not None, "Object XML tree has not been generated yet!" return self._obj + def set_scale(self, scale, obj=None): + """ + Scales each geom, mesh, site, body, and joint ranges (for slide joints). + Called during initialization but can also be used externally. + Args: + scale (float or list of floats): Scale factor (1 or 3 dims) + obj (ET.Element): Root object to apply scaling to. Defaults to root object of model. + """ + if obj is None: + obj = self._obj + + self._scale = scale + + # Use the centralized scaling utility function + scale_mjcf_model( + obj=obj, + asset_root=self.asset, + scale=scale, + get_elements_func=self._get_elements, + get_geoms_func=self._get_geoms, + scale_slide_joints=True, + ) + + def _get_geoms(self, root, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a geom element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is a geom type + """ + return self._get_elements(root, "geom", _parent) + + def _get_elements(self, root, type, _parent=None): + """ + Helper function to recursively search through element tree starting at @root and returns + a list of (parent, child) tuples where the child is a specific type of element + + Args: + root (ET.Element): Root of xml element tree to start recursively searching through + _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set + during the recursive call + + Returns: + list: array of (parent, child) tuples where the child element is of type + """ + # Initialize return array + elem_pairs = [] + # If the parent exists and this is a desired element, we add this current (parent, element) combo to the output + if _parent is not None and root.tag == type: + elem_pairs.append((_parent, root)) + # Loop through all children elements recursively and add to pairs + for child in root: + elem_pairs += self._get_elements(child, type, _parent=root) + + # Return all found pairs + return elem_pairs + def exclude_from_prefixing(self, inp): """ A function that should take in either an ET.Element or its attribute (str) and return either True or False, @@ -518,91 +582,15 @@ def set_scale(self, scale, obj=None): self._scale = scale - # scale geoms - geom_pairs = self._get_geoms(obj) - for _, (_, element) in enumerate(geom_pairs): - g_pos = element.get("pos") - g_size = element.get("size") - if g_pos is not None: - g_pos = array_to_string(string_to_array(g_pos) * self._scale) - element.set("pos", g_pos) - if g_size is not None: - g_size_np = string_to_array(g_size) - # handle cases where size is not 3 dimensional - if len(g_size_np) == 3: - g_size_np = g_size_np * self._scale - elif len(g_size_np) == 2: - scale = np.array(self._scale).reshape(-1) - if len(scale) == 1: - g_size_np[1] *= scale - elif len(scale) == 3: - # g_size_np[0] *= np.mean(scale[:2]) - g_size_np[0] *= np.mean(scale[:2]) # width - g_size_np[1] *= scale[2] # height - else: - raise ValueError - else: - raise ValueError - g_size = array_to_string(g_size_np) - element.set("size", g_size) - - # scale meshes - meshes = self.asset.findall("mesh") - for elem in meshes: - m_scale = elem.get("scale") - if m_scale is not None: - m_scale = string_to_array(m_scale) - else: - m_scale = np.ones(3) - - m_scale *= self._scale - elem.set("scale", array_to_string(m_scale)) - - # scale bodies - body_pairs = self._get_elements(obj, "body") - for (_, elem) in body_pairs: - b_pos = elem.get("pos") - if b_pos is not None: - b_pos = string_to_array(b_pos) * self._scale - elem.set("pos", array_to_string(b_pos)) - - # scale joints - joint_pairs = self._get_elements(obj, "joint") - for (_, elem) in joint_pairs: - j_pos = elem.get("pos") - if j_pos is not None: - j_pos = string_to_array(j_pos) * self._scale - elem.set("pos", array_to_string(j_pos)) - - # scale sites - site_pairs = self._get_elements(self.worldbody, "site") - for (_, elem) in site_pairs: - s_pos = elem.get("pos") - if s_pos is not None: - s_pos = string_to_array(s_pos) * self._scale - elem.set("pos", array_to_string(s_pos)) - - s_size = elem.get("size") - if s_size is not None: - s_size_np = string_to_array(s_size) - # handle cases where size is not 3 dimensional - if len(s_size_np) == 3: - s_size_np = s_size_np * self._scale - elif len(s_size_np) == 2: - scale = np.array(self._scale).reshape(-1) - if len(scale) == 1: - s_size_np *= scale - elif len(scale) == 3: - s_size_np[0] *= np.mean(scale[:2]) # width - s_size_np[1] *= scale[2] # height - else: - raise ValueError - elif len(s_size_np) == 1: - s_size_np *= np.mean(self._scale) - else: - raise ValueError - s_size = array_to_string(s_size_np) - elem.set("size", s_size) + # Use the centralized scaling utility function + scale_mjcf_model( + obj=obj, + asset_root=self.asset, + scale=scale, + get_elements_func=self._get_elements, + get_geoms_func=self._get_geoms, + scale_slide_joints=False, # MujocoXMLObject doesn't handle slide joints + ) @property def bottom_offset(self): diff --git a/robosuite/utils/mjcf_utils.py b/robosuite/utils/mjcf_utils.py index 6f7207a81f..cb5ff05ad4 100644 --- a/robosuite/utils/mjcf_utils.py +++ b/robosuite/utils/mjcf_utils.py @@ -872,41 +872,224 @@ def save_sim_model(sim, fname): def get_ids(sim, elements, element_type="geom", inplace=False): """ - Grabs the mujoco IDs for each element in @elements, corresponding to the specified @element_type. + Grabs the ids corresponding to @elements. If the inputted elements are already a list of ids, immediately + returns that list. Args: - sim (MjSim): Active mujoco simulation object - elements (str or list or dict): Element(s) to convert into IDs. Note that the return type corresponds to - @elements type, where each element name is replaced with the ID - element_type (str): The type of element to grab ID for. Options are {geom, body, site} - inplace (bool): If False, will create a copy of @elements to prevent overwriting the original data structure + sim (MjSim): Mujoco sim reference + elements (str or list or int): Object(s) to grab ids for. Can be a string (name), a list of strings (names), + or an int / list of ints (ids). Also supported are lists of mixed types + element_type (str): Type of element to grab ids for. + Options are {'body', 'geom', 'site', 'joint', 'actuator', 'sensor', 'tendon', 'camera', 'light'} + inplace (bool): If True, will replace the inputted @elements list in-place Returns: - str or list or dict: IDs corresponding to @elements. + list: id(s) corresponding to @elements """ + if type(elements) is not list: + elements = [elements] if not inplace: - # Copy elements first so we don't write to the underlying object - elements = deepcopy(elements) - # Choose what to do based on elements type - if isinstance(elements, str): - # We simply return the value of this single element - assert element_type in { - "geom", - "body", - "site", - }, f"element_type must be either geom, body, or site. Got: {element_type}" - if element_type == "geom": - elements = sim.model.geom_name2id(elements) - elif element_type == "body": - elements = sim.model.body_name2id(elements) - else: # site - elements = sim.model.site_name2id(elements) - elif isinstance(elements, dict): - # Iterate over each element in dict and recursively repeat - for name, ele in elements: - elements[name] = get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True) - else: # We assume this is an iterable array - assert isinstance(elements, Iterable), "Elements must be iterable for get_id!" - elements = [get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True) for ele in elements] + elements = list(elements) + + # Iterate through all elements and grab their corresponding IDs + for i, element in enumerate(elements): + if type(element) is not int: + element_func = sim.model.__getattribute__("{}_name2id".format(element_type)) + elements[i] = element_func(element) return elements + + +def normalize_scale_array(scale): + """ + Normalizes a scale factor to be a 3-element numpy array. + + Args: + scale (float or array-like): Scale factor (1 or 3 dims) + + Returns: + np.array: 3-element scale array + + Raises: + ValueError: If scale is not scalar or 3-element array + """ + scale_array = np.array(scale).flatten() + if scale_array.size == 1: + scale_array = np.repeat(scale_array, 3) + elif scale_array.size != 3: + raise ValueError("Scale must be a scalar or a 3-element array.") + return scale_array + + +def scale_geom_element(element, scale_array): + """ + Scales a single geom element's position and size. + + Args: + element (ET.Element): Geom element to scale + scale_array (np.array): 3-element scale array + """ + g_pos = element.get("pos") + g_size = element.get("size") + + if g_pos is not None: + g_pos = array_to_string(string_to_array(g_pos) * scale_array) + element.set("pos", g_pos) + + if g_size is not None: + g_size_np = string_to_array(g_size) + # Handle cases where size is not 3-dimensional + if len(g_size_np) == 3: + g_size_np = g_size_np * scale_array + elif len(g_size_np) == 2: + # For 2D size, assume [radius, height] for cylinders + g_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y + g_size_np[1] *= scale_array[2] # Scaling in z + elif len(g_size_np) == 1: + g_size_np *= np.mean(scale_array) + else: + raise ValueError("Unsupported geom size dimensions.") + g_size = array_to_string(g_size_np) + element.set("size", g_size) + + +def scale_mesh_element(element, scale_array): + """ + Scales a single mesh element. + + Args: + element (ET.Element): Mesh element to scale + scale_array (np.array): 3-element scale array + """ + m_scale = element.get("scale") + if m_scale is not None: + m_scale = string_to_array(m_scale) + else: + m_scale = np.ones(3) + m_scale *= scale_array + element.set("scale", array_to_string(m_scale)) + + +def scale_body_element(element, scale_array): + """ + Scales a single body element's position. + + Args: + element (ET.Element): Body element to scale + scale_array (np.array): 3-element scale array + """ + b_pos = element.get("pos") + if b_pos is not None: + b_pos = string_to_array(b_pos) * scale_array + element.set("pos", array_to_string(b_pos)) + + +def scale_joint_element(element, scale_array, scale_slide_joints=True): + """ + Scales a single joint element's position and optionally range for slide joints. + + Args: + element (ET.Element): Joint element to scale + scale_array (np.array): 3-element scale array + scale_slide_joints (bool): Whether to scale ranges for slide joints + """ + j_pos = element.get("pos") + if j_pos is not None: + j_pos = string_to_array(j_pos) * scale_array + element.set("pos", array_to_string(j_pos)) + + # Scale joint ranges for slide joints if requested + if scale_slide_joints: + j_type = element.get("type", "hinge") # Default joint type is 'hinge' if not specified + j_range = element.get("range") + if j_range is not None and j_type == "slide": + # Get joint axis + j_axis = element.get("axis", "1 0 0") # Default axis is [1, 0, 0] + j_axis = string_to_array(j_axis) + axis_norm = np.linalg.norm(j_axis) + if axis_norm > 0: + axis_unit = j_axis / axis_norm + else: + # Avoid division by zero + axis_unit = np.array([1.0, 0.0, 0.0]) + # Compute scaling factor along the joint axis + s = np.linalg.norm(axis_unit * scale_array) + # Scale the range + j_range_vals = string_to_array(j_range) + j_range_vals = j_range_vals * s + element.set("range", array_to_string(j_range_vals)) + + +def scale_site_element(element, scale_array): + """ + Scales a single site element's position and size. + + Args: + element (ET.Element): Site element to scale + scale_array (np.array): 3-element scale array + """ + s_pos = element.get("pos") + if s_pos is not None: + s_pos = string_to_array(s_pos) * scale_array + element.set("pos", array_to_string(s_pos)) + + s_size = element.get("size") + if s_size is not None: + s_size_np = string_to_array(s_size) + if len(s_size_np) == 3: + s_size_np = s_size_np * scale_array + elif len(s_size_np) == 2: + s_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y + s_size_np[1] *= scale_array[2] # Scaling in z + elif len(s_size_np) == 1: + s_size_np *= np.mean(scale_array) + else: + raise ValueError("Unsupported site size dimensions.") + s_size = array_to_string(s_size_np) + element.set("size", s_size) + + +def scale_mjcf_model(obj, asset_root, scale, get_elements_func, get_geoms_func, scale_slide_joints=True): + """ + Scales all elements (geoms, meshes, bodies, joints, sites) in an MJCF model. + + Args: + obj (ET.Element): Root object element to scale + asset_root (ET.Element): Asset root element containing meshes + scale (float or array-like): Scale factor (1 or 3 dims) + get_elements_func (callable): Function to get elements of a specific type from obj + get_geoms_func (callable): Function to get geom elements from obj + scale_slide_joints (bool): Whether to scale ranges for slide joints + + Returns: + np.array: The normalized 3-element scale array that was applied + """ + # Normalize scale to 3-element array + scale_array = normalize_scale_array(scale) + + # Scale geoms + geom_pairs = get_geoms_func(obj) + for _, (_, element) in enumerate(geom_pairs): + scale_geom_element(element, scale_array) + + # Scale meshes + meshes = asset_root.findall("mesh") + for elem in meshes: + scale_mesh_element(elem, scale_array) + + # Scale bodies + body_pairs = get_elements_func(obj, "body") + for (_, elem) in body_pairs: + scale_body_element(elem, scale_array) + + # Scale joints + joint_pairs = get_elements_func(obj, "joint") + for (_, elem) in joint_pairs: + scale_joint_element(elem, scale_array, scale_slide_joints) + + # Scale sites + site_pairs = get_elements_func(obj, "site") + for (_, elem) in site_pairs: + scale_site_element(elem, scale_array) + + return scale_array