Skip to content

Commit 502131b

Browse files
Merge pull request #638 from ARISE-Initiative/seg-bodies-in-arena
Enable instance segmentation to include bodies in arena
2 parents 51cc017 + 904251b commit 502131b

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

robosuite/demos/demo_segmentation.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,17 @@ def segmentation_to_rgb(seg_im, random_colors=False):
5757
parser.add_argument("--video-path", type=str, default="/tmp/video.mp4", help="Path to video file")
5858
parser.add_argument("--random-colors", action="store_true", help="Radnomize segmentation colors")
5959
parser.add_argument("--segmentation-level", type=str, default="element", help="instance, class, or element")
60+
parser.add_argument("--env-name", type=str, default="TwoArmHandover", help="Environment name")
61+
parser.add_argument("--camera", type=str, default="frontview", help="Camera name")
6062
args = parser.parse_args()
6163

6264
# Create dict to hold options that will be passed to env creation call
6365
options = {}
6466

6567
# Choose environment and add it to options
66-
options["env_name"] = "TwoArmHandover"
68+
options["env_name"] = args.env_name
6769
options["robots"] = ["Panda", "Panda"]
6870

69-
# Choose camera
70-
camera = "frontview"
71-
7271
# Choose segmentation type
7372
segmentation_level = args.segmentation_level # Options are {instance, class, element}
7473

@@ -80,7 +79,7 @@ def segmentation_to_rgb(seg_im, random_colors=False):
8079
ignore_done=True,
8180
use_camera_obs=True,
8281
control_freq=20,
83-
camera_names=camera,
82+
camera_names=args.camera,
8483
camera_segmentations=segmentation_level,
8584
camera_heights=512,
8685
camera_widths=512,
@@ -98,7 +97,7 @@ def segmentation_to_rgb(seg_im, random_colors=False):
9897
action = 0.5 * np.random.uniform(low, high)
9998
obs, reward, done, _ = env.step(action)
10099

101-
video_img = obs[f"{camera}_segmentation_{segmentation_level}"].squeeze(-1)[::-1]
100+
video_img = obs[f"{args.camera}_segmentation_{segmentation_level}"].squeeze(-1)[::-1]
102101
np.savetxt("/tmp/seg_{}.txt".format(i), video_img, fmt="%.2f")
103102
video_img = segmentation_to_rgb(video_img, args.random_colors)
104103
video_writer.append_data(video_img)

robosuite/models/tasks/task.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,41 @@
1+
import xml.etree.ElementTree as ET
12
from copy import deepcopy
23

4+
import mujoco
5+
36
from robosuite.models.objects import MujocoObject
47
from robosuite.models.robots import RobotModel
58
from robosuite.models.world import MujocoWorldBase
69
from robosuite.utils.mjcf_utils import get_ids
710

811

12+
def get_subtree_geom_ids_by_group(model: mujoco.MjModel, body_id: int, group: int) -> list[int]:
13+
"""Get all geoms belonging to a subtree starting at a given body, filtered by group.
14+
15+
Args:
16+
model: MuJoCo model.
17+
body_id: ID of body where subtree starts.
18+
group: Group ID to filter geoms.
19+
20+
Returns:
21+
A list containing all subtree geom ids in the specified group.
22+
23+
Adapted from https://github.com/kevinzakka/mink/blob/main/mink/utils.py
24+
"""
25+
26+
def gather_geoms(body_id: int) -> list[int]:
27+
geoms: list[int] = []
28+
geom_start = model.body_geomadr[body_id]
29+
geom_end = geom_start + model.body_geomnum[body_id]
30+
geoms.extend(geom_id for geom_id in range(geom_start, geom_end) if model.geom_group[geom_id] == group)
31+
children = [i for i in range(model.nbody) if model.body_parentid[i] == body_id]
32+
for child_id in children:
33+
geoms.extend(gather_geoms(child_id))
34+
return geoms
35+
36+
return gather_geoms(body_id)
37+
38+
939
class Task(MujocoWorldBase):
1040
"""
1141
Creates MJCF model for a task performed.
@@ -106,15 +136,32 @@ def generate_id_mappings(self, sim):
106136
for robot in self.mujoco_robots:
107137
models += [robot] + robot.models
108138

139+
worldbody = self.mujoco_arena.root.find("worldbody")
140+
exclude_bodies = ["table"]
141+
top_level_bodies = [
142+
body.attrib.get("name")
143+
for body in worldbody.findall("body")
144+
if body.attrib.get("name") not in exclude_bodies
145+
]
146+
models.extend(top_level_bodies)
147+
109148
# Parse all mujoco models from robots and objects
110149
for model in models:
111-
# Grab model class name and visual IDs
112-
cls = str(type(model)).split("'")[1].split(".")[-1]
113-
inst = model.name
114-
id_groups = [
115-
get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"),
116-
get_ids(sim=sim, elements=model.sites, element_type="site"),
117-
]
150+
if isinstance(model, str):
151+
body_name = model
152+
visual_group_number = 1
153+
body_id = sim.model.body_name2id(body_name)
154+
inst, cls = body_name, body_name
155+
geom_ids = get_subtree_geom_ids_by_group(sim.model, body_id, visual_group_number)
156+
id_groups = [geom_ids, []]
157+
else:
158+
# Grab model class name and visual IDs
159+
cls = str(type(model)).split("'")[1].split(".")[-1]
160+
inst = model.name
161+
id_groups = [
162+
get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"),
163+
get_ids(sim=sim, elements=model.sites, element_type="site"),
164+
]
118165
group_types = ("geom", "site")
119166
ids_to_instances = (self._geom_ids_to_instances, self._site_ids_to_instances)
120167
ids_to_classes = (self._geom_ids_to_classes, self._site_ids_to_classes)

0 commit comments

Comments
 (0)