Skip to content

Commit 4062cb4

Browse files
Add set_scale to Arena and MujocoObject (#643)
* Add scale setting and saving in arena * Require obj param for set_scale in arena * Update interface of set_scale in arena * Update set_scale object indexing * Format * Improve logging if invalid object to scale * Add option to set_scale of MujocoObjects * Format * Refactor scale setting code * format * refactor * re-add assert to get_ids * Format * Format
1 parent fe4b3c5 commit 4062cb4

File tree

3 files changed

+313
-156
lines changed

3 files changed

+313
-156
lines changed

robosuite/models/arenas/arena.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
from typing import List, Union
2+
13
import numpy as np
24

35
from robosuite.models.base import MujocoXML
46
from robosuite.utils.mjcf_utils import (
57
ENVIRONMENT_COLLISION_COLOR,
68
array_to_string,
79
find_elements,
10+
get_elements,
811
new_body,
912
new_element,
1013
new_geom,
1114
new_joint,
1215
recolor_collision_geoms,
16+
scale_mjcf_model,
1317
string_to_array,
1418
)
1519

@@ -22,6 +26,7 @@ def __init__(self, fname):
2226
# Get references to floor and bottom
2327
self.bottom_pos = np.zeros(3)
2428
self.floor = self.worldbody.find("./geom[@name='floor']")
29+
self.object_scales = {}
2530

2631
# Add mocap bodies to self.root for mocap control in mjviewer UI for robot control
2732
mocap_body_1 = new_body(name="left_eef_target", pos="0 0 -1", mocap=True)
@@ -129,3 +134,28 @@ def _postprocess_arena(self):
129134
Runs any necessary post-processing on the imported Arena model
130135
"""
131136
pass
137+
138+
def set_scale(self, scale: Union[float, List[float]], obj_name: str):
139+
"""
140+
Scales each geom, mesh, site, and body under obj_name.
141+
Called during initialization but can also be used externally
142+
143+
Args:
144+
scale (float or list of floats): Scale factor (1 or 3 dims)
145+
obj_name Name of root object to apply.
146+
"""
147+
obj = self.worldbody.find(f"./body[@name='{obj_name}']")
148+
if obj is None:
149+
bodies = self.worldbody.findall("./body")
150+
body_names = [body.get("name") for body in bodies if body.get("name") is not None]
151+
raise ValueError(f"Object {obj_name} not found in arena; cannot set scale. Available objects: {body_names}")
152+
self.object_scales[obj.get("name")] = scale
153+
154+
# Use the centralized scaling utility function
155+
scale_mjcf_model(
156+
obj=obj,
157+
asset_root=self.asset,
158+
scale=scale,
159+
get_elements_func=get_elements,
160+
scale_slide_joints=False, # Arena doesn't handle slide joints
161+
)

robosuite/models/objects/objects.py

Lines changed: 33 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
add_prefix,
1414
array_to_string,
1515
find_elements,
16+
get_elements,
1617
new_joint,
18+
scale_mjcf_model,
1719
sort_elements,
1820
string_to_array,
1921
)
@@ -95,6 +97,28 @@ def get_obj(self):
9597
assert self._obj is not None, "Object XML tree has not been generated yet!"
9698
return self._obj
9799

100+
def set_scale(self, scale, obj=None):
101+
"""
102+
Scales each geom, mesh, site, body, and joint ranges (for slide joints).
103+
Called during initialization but can also be used externally.
104+
Args:
105+
scale (float or list of floats): Scale factor (1 or 3 dims)
106+
obj (ET.Element): Root object to apply scaling to. Defaults to root object of model.
107+
"""
108+
if obj is None:
109+
obj = self._obj
110+
111+
self._scale = scale
112+
113+
# Use the centralized scaling utility function
114+
scale_mjcf_model(
115+
obj=obj,
116+
asset_root=self.asset,
117+
scale=scale,
118+
get_elements_func=get_elements,
119+
scale_slide_joints=True,
120+
)
121+
98122
def exclude_from_prefixing(self, inp):
99123
"""
100124
A function that should take in either an ET.Element or its attribute (str) and return either True or False,
@@ -371,7 +395,7 @@ def _get_object_subtree(self):
371395
# Rename this top level object body (will have self.naming_prefix added later)
372396
obj.attrib["name"] = "main"
373397
# Get all geom_pairs in this tree
374-
geom_pairs = self._get_geoms(obj)
398+
geom_pairs = get_elements(obj, "geom")
375399

376400
# Define a temp function so we don't duplicate so much code
377401
obj_type = self.obj_type
@@ -441,46 +465,6 @@ def _duplicate_visual_from_collision(element):
441465
vis_element.set("name", vis_element.get("name") + "_visual")
442466
return vis_element
443467

444-
def _get_geoms(self, root, _parent=None):
445-
"""
446-
Helper function to recursively search through element tree starting at @root and returns
447-
a list of (parent, child) tuples where the child is a geom element
448-
449-
Args:
450-
root (ET.Element): Root of xml element tree to start recursively searching through
451-
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
452-
during the recursive call
453-
454-
Returns:
455-
list: array of (parent, child) tuples where the child element is a geom type
456-
"""
457-
return self._get_elements(root, "geom", _parent)
458-
459-
def _get_elements(self, root, type, _parent=None):
460-
"""
461-
Helper function to recursively search through element tree starting at @root and returns
462-
a list of (parent, child) tuples where the child is a specific type of element
463-
464-
Args:
465-
root (ET.Element): Root of xml element tree to start recursively searching through
466-
_parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
467-
during the recursive call
468-
469-
Returns:
470-
list: array of (parent, child) tuples where the child element is of type
471-
"""
472-
# Initialize return array
473-
elem_pairs = []
474-
# If the parent exists and this is a desired element, we add this current (parent, element) combo to the output
475-
if _parent is not None and root.tag == type:
476-
elem_pairs.append((_parent, root))
477-
# Loop through all children elements recursively and add to pairs
478-
for child in root:
479-
elem_pairs += self._get_elements(child, type, _parent=root)
480-
481-
# Return all found pairs
482-
return elem_pairs
483-
484468
def set_pos(self, pos):
485469
"""
486470
Set position of object position is defined as center of bounding box
@@ -518,91 +502,14 @@ def set_scale(self, scale, obj=None):
518502

519503
self._scale = scale
520504

521-
# scale geoms
522-
geom_pairs = self._get_geoms(obj)
523-
for _, (_, element) in enumerate(geom_pairs):
524-
g_pos = element.get("pos")
525-
g_size = element.get("size")
526-
if g_pos is not None:
527-
g_pos = array_to_string(string_to_array(g_pos) * self._scale)
528-
element.set("pos", g_pos)
529-
if g_size is not None:
530-
g_size_np = string_to_array(g_size)
531-
# handle cases where size is not 3 dimensional
532-
if len(g_size_np) == 3:
533-
g_size_np = g_size_np * self._scale
534-
elif len(g_size_np) == 2:
535-
scale = np.array(self._scale).reshape(-1)
536-
if len(scale) == 1:
537-
g_size_np[1] *= scale
538-
elif len(scale) == 3:
539-
# g_size_np[0] *= np.mean(scale[:2])
540-
g_size_np[0] *= np.mean(scale[:2]) # width
541-
g_size_np[1] *= scale[2] # height
542-
else:
543-
raise ValueError
544-
else:
545-
raise ValueError
546-
g_size = array_to_string(g_size_np)
547-
element.set("size", g_size)
548-
549-
# scale meshes
550-
meshes = self.asset.findall("mesh")
551-
for elem in meshes:
552-
m_scale = elem.get("scale")
553-
if m_scale is not None:
554-
m_scale = string_to_array(m_scale)
555-
else:
556-
m_scale = np.ones(3)
557-
558-
m_scale *= self._scale
559-
elem.set("scale", array_to_string(m_scale))
560-
561-
# scale bodies
562-
body_pairs = self._get_elements(obj, "body")
563-
for (_, elem) in body_pairs:
564-
b_pos = elem.get("pos")
565-
if b_pos is not None:
566-
b_pos = string_to_array(b_pos) * self._scale
567-
elem.set("pos", array_to_string(b_pos))
568-
569-
# scale joints
570-
joint_pairs = self._get_elements(obj, "joint")
571-
for (_, elem) in joint_pairs:
572-
j_pos = elem.get("pos")
573-
if j_pos is not None:
574-
j_pos = string_to_array(j_pos) * self._scale
575-
elem.set("pos", array_to_string(j_pos))
576-
577-
# scale sites
578-
site_pairs = self._get_elements(self.worldbody, "site")
579-
for (_, elem) in site_pairs:
580-
s_pos = elem.get("pos")
581-
if s_pos is not None:
582-
s_pos = string_to_array(s_pos) * self._scale
583-
elem.set("pos", array_to_string(s_pos))
584-
585-
s_size = elem.get("size")
586-
if s_size is not None:
587-
s_size_np = string_to_array(s_size)
588-
# handle cases where size is not 3 dimensional
589-
if len(s_size_np) == 3:
590-
s_size_np = s_size_np * self._scale
591-
elif len(s_size_np) == 2:
592-
scale = np.array(self._scale).reshape(-1)
593-
if len(scale) == 1:
594-
s_size_np *= scale
595-
elif len(scale) == 3:
596-
s_size_np[0] *= np.mean(scale[:2]) # width
597-
s_size_np[1] *= scale[2] # height
598-
else:
599-
raise ValueError
600-
elif len(s_size_np) == 1:
601-
s_size_np *= np.mean(self._scale)
602-
else:
603-
raise ValueError
604-
s_size = array_to_string(s_size_np)
605-
elem.set("size", s_size)
505+
# Use the centralized scaling utility function
506+
scale_mjcf_model(
507+
obj=obj,
508+
asset_root=self.asset,
509+
scale=scale,
510+
get_elements_func=get_elements,
511+
scale_slide_joints=False, # MujocoXMLObject doesn't handle slide joints
512+
)
606513

607514
@property
608515
def bottom_offset(self):

0 commit comments

Comments
 (0)