Skip to content

Add set_scale to Arena and MujocoObject #643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
145 changes: 145 additions & 0 deletions robosuite/models/arenas/arena.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Union

import numpy as np

from robosuite.models.base import MujocoXML
Expand All @@ -22,6 +24,7 @@ def __init__(self, fname):
# Get references to floor and bottom
self.bottom_pos = np.zeros(3)
self.floor = self.worldbody.find("./geom[@name='floor']")
self.object_scales = {}

# Add mocap bodies to self.root for mocap control in mjviewer UI for robot control
mocap_body_1 = new_body(name="left_eef_target", pos="0 0 -1", mocap=True)
Expand Down Expand Up @@ -129,3 +132,145 @@ def _postprocess_arena(self):
Runs any necessary post-processing on the imported Arena model
"""
pass

def _get_geoms(self, root, _parent=None):
"""
Helper function to recursively search through element tree starting at @root and returns
a list of (parent, child) tuples where the child is a geom element

Args:
root (ET.Element): Root of xml element tree to start recursively searching through
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
during the recursive call

Returns:
list: array of (parent, child) tuples where the child element is a geom type
"""
return self._get_elements(root, "geom", _parent)

def _get_elements(self, root, type, _parent=None):
"""
Helper function to recursively search through element tree starting at @root and returns
a list of (parent, child) tuples where the child is a specific type of element

Args:
root (ET.Element): Root of xml element tree to start recursively searching through
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
during the recursive call

Returns:
list: array of (parent, child) tuples where the child element is of type
"""
# Initialize return array
elem_pairs = []
# If the parent exists and this is a desired element, we add this current (parent, element) combo to the output
if _parent is not None and root.tag == type:
elem_pairs.append((_parent, root))
# Loop through all children elements recursively and add to pairs
for child in root:
elem_pairs += self._get_elements(child, type, _parent=root)

# Return all found pairs
return elem_pairs

def set_scale(self, scale: Union[float, List[float]], obj_name: str):
"""
Scales each geom, mesh, site, and body under obj_name.
Called during initialization but can also be used externally

Args:
scale (float or list of floats): Scale factor (1 or 3 dims)
obj_name Name of root object to apply.
"""
obj = self.worldbody.find(f"./body[@name='{obj_name}']")
if obj is None:
bodies = self.worldbody.findall("./body")
body_names = [body.get("name") for body in bodies if body.get("name") is not None]
raise ValueError(f"Object {obj_name} not found in arena; cannot set scale. Available objects: {body_names}")
self.object_scales[obj.get("name")] = scale

# scale geoms
geom_pairs = self._get_geoms(obj)
for _, (_, element) in enumerate(geom_pairs):
g_pos = element.get("pos")
g_size = element.get("size")
if g_pos is not None:
g_pos = array_to_string(string_to_array(g_pos) * scale)
element.set("pos", g_pos)
if g_size is not None:
g_size_np = string_to_array(g_size)
# handle cases where size is not 3 dimensional
if len(g_size_np) == 3:
g_size_np = g_size_np * scale
elif len(g_size_np) == 2:
scale = np.array(scale).reshape(-1)
if len(scale) == 1:
g_size_np[1] *= scale
elif len(scale) == 3:
# g_size_np[0] *= np.mean(scale[:2])
g_size_np[0] *= np.mean(scale[:2]) # width
g_size_np[1] *= scale[2] # height
else:
raise ValueError
else:
raise ValueError
g_size = array_to_string(g_size_np)
element.set("size", g_size)

# scale meshes
meshes = self.asset.findall("mesh")
for elem in meshes:
m_scale = elem.get("scale")
if m_scale is not None:
m_scale = string_to_array(m_scale)
else:
m_scale = np.ones(3)

m_scale *= scale
elem.set("scale", array_to_string(m_scale))

# scale bodies
body_pairs = self._get_elements(obj, "body")
for (_, elem) in body_pairs:
b_pos = elem.get("pos")
if b_pos is not None:
b_pos = string_to_array(b_pos) * scale
elem.set("pos", array_to_string(b_pos))

# scale joints
joint_pairs = self._get_elements(obj, "joint")
for (_, elem) in joint_pairs:
j_pos = elem.get("pos")
if j_pos is not None:
j_pos = string_to_array(j_pos) * scale
elem.set("pos", array_to_string(j_pos))

# scale sites
site_pairs = self._get_elements(self.worldbody, "site")
for (_, elem) in site_pairs:
s_pos = elem.get("pos")
if s_pos is not None:
s_pos = string_to_array(s_pos) * scale
elem.set("pos", array_to_string(s_pos))

s_size = elem.get("size")
if s_size is not None:
s_size_np = string_to_array(s_size)
# handle cases where size is not 3 dimensional
if len(s_size_np) == 3:
s_size_np = s_size_np * scale
elif len(s_size_np) == 2:
scale = np.array(scale).reshape(-1)
if len(scale) == 1:
s_size_np *= scale
elif len(scale) == 3:
s_size_np[0] *= np.mean(scale[:2]) # width
s_size_np[1] *= scale[2] # height
else:
raise ValueError
elif len(s_size_np) == 1:
s_size_np *= np.mean(scale)
else:
raise ValueError
s_size = array_to_string(s_size_np)
elem.set("size", s_size)
151 changes: 151 additions & 0 deletions robosuite/models/objects/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,157 @@ def get_obj(self):
assert self._obj is not None, "Object XML tree has not been generated yet!"
return self._obj

def set_scale(self, scale, obj=None):
"""
Scales each geom, mesh, site, body, and joint ranges (for slide joints).
Called during initialization but can also be used externally.
Args:
scale (float or list of floats): Scale factor (1 or 3 dims)
obj (ET.Element): Root object to apply scaling to. Defaults to root object of model.
"""
if obj is None:
obj = self._obj

self._scale = scale

# Ensure scale is an array of length 3
scale_array = np.array(self._scale).flatten()
if scale_array.size == 1:
scale_array = np.repeat(scale_array, 3)
elif scale_array.size != 3:
raise ValueError("Scale must be a scalar or a 3-element array.")

geom_pairs = self._get_geoms(obj)
for _, (_, element) in enumerate(geom_pairs):
g_pos = element.get("pos")
g_size = element.get("size")
if g_pos is not None:
g_pos = array_to_string(string_to_array(g_pos) * scale_array)
element.set("pos", g_pos)
if g_size is not None:
g_size_np = string_to_array(g_size)
# Handle cases where size is not 3-dimensional
if len(g_size_np) == 3:
g_size_np = g_size_np * scale_array
elif len(g_size_np) == 2:
# For 2D size, assume [radius, height] for cylinders
g_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y
g_size_np[1] *= scale_array[2] # Scaling in z
else:
raise ValueError("Unsupported geom size dimensions.")
g_size = array_to_string(g_size_np)
element.set("size", g_size)

# Scale meshes
meshes = self.asset.findall("mesh")
for elem in meshes:
m_scale = elem.get("scale")
if m_scale is not None:
m_scale = string_to_array(m_scale)
else:
m_scale = np.ones(3)
m_scale *= scale_array
elem.set("scale", array_to_string(m_scale))

# Scale bodies
body_pairs = self._get_elements(obj, "body")
for (_, elem) in body_pairs:
b_pos = elem.get("pos")
if b_pos is not None:
b_pos = string_to_array(b_pos) * scale_array
elem.set("pos", array_to_string(b_pos))

# Scale joints
joint_pairs = self._get_elements(obj, "joint")
for (_, elem) in joint_pairs:
j_pos = elem.get("pos")
if j_pos is not None:
j_pos = string_to_array(j_pos) * scale_array
elem.set("pos", array_to_string(j_pos))

# Scale joint ranges for slide joints
j_type = elem.get("type", "hinge") # Default joint type is 'hinge' if not specified
j_range = elem.get("range")
if j_range is not None and j_type == "slide":
# Get joint axis
j_axis = elem.get("axis", "1 0 0") # Default axis is [1, 0, 0]
j_axis = string_to_array(j_axis)
axis_norm = np.linalg.norm(j_axis)
if axis_norm > 0:
axis_unit = j_axis / axis_norm
else:
# Avoid division by zero
axis_unit = np.array([1.0, 0.0, 0.0])
# Compute scaling factor along the joint axis
s = np.linalg.norm(axis_unit * scale_array)
# Scale the range
j_range_vals = string_to_array(j_range)
j_range_vals = j_range_vals * s
elem.set("range", array_to_string(j_range_vals))

# Scale sites
site_pairs = self._get_elements(obj, "site")
for (_, elem) in site_pairs:
s_pos = elem.get("pos")
if s_pos is not None:
s_pos = string_to_array(s_pos) * scale_array
elem.set("pos", array_to_string(s_pos))

s_size = elem.get("size")
if s_size is not None:
s_size_np = string_to_array(s_size)
if len(s_size_np) == 3:
s_size_np = s_size_np * scale_array
elif len(s_size_np) == 2:
s_size_np[0] *= np.mean(scale_array[:2]) # Average scaling in x and y
s_size_np[1] *= scale_array[2] # Scaling in z
elif len(s_size_np) == 1:
s_size_np *= np.mean(scale_array)
else:
raise ValueError("Unsupported site size dimensions.")
s_size = array_to_string(s_size_np)
elem.set("size", s_size)

def _get_geoms(self, root, _parent=None):
"""
Helper function to recursively search through element tree starting at @root and returns
a list of (parent, child) tuples where the child is a geom element

Args:
root (ET.Element): Root of xml element tree to start recursively searching through
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
during the recursive call

Returns:
list: array of (parent, child) tuples where the child element is a geom type
"""
return self._get_elements(root, "geom", _parent)

def _get_elements(self, root, type, _parent=None):
"""
Helper function to recursively search through element tree starting at @root and returns
a list of (parent, child) tuples where the child is a specific type of element

Args:
root (ET.Element): Root of xml element tree to start recursively searching through
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
during the recursive call

Returns:
list: array of (parent, child) tuples where the child element is of type
"""
# Initialize return array
elem_pairs = []
# If the parent exists and this is a desired element, we add this current (parent, element) combo to the output
if _parent is not None and root.tag == type:
elem_pairs.append((_parent, root))
# Loop through all children elements recursively and add to pairs
for child in root:
elem_pairs += self._get_elements(child, type, _parent=root)

# Return all found pairs
return elem_pairs

def exclude_from_prefixing(self, inp):
"""
A function that should take in either an ET.Element or its attribute (str) and return either True or False,
Expand Down
Loading