|
| 1 | +import xml.etree.ElementTree as ET |
1 | 2 | from copy import deepcopy |
2 | 3 |
|
| 4 | +import mujoco |
| 5 | + |
3 | 6 | from robosuite.models.objects import MujocoObject |
4 | 7 | from robosuite.models.robots import RobotModel |
5 | 8 | from robosuite.models.world import MujocoWorldBase |
6 | 9 | from robosuite.utils.mjcf_utils import get_ids |
7 | 10 |
|
8 | 11 |
|
| 12 | +def get_subtree_geom_ids_by_group(model: mujoco.MjModel, body_id: int, group: int) -> list[int]: |
| 13 | + """Get all geoms belonging to a subtree starting at a given body, filtered by group. |
| 14 | +
|
| 15 | + Args: |
| 16 | + model: MuJoCo model. |
| 17 | + body_id: ID of body where subtree starts. |
| 18 | + group: Group ID to filter geoms. |
| 19 | +
|
| 20 | + Returns: |
| 21 | + A list containing all subtree geom ids in the specified group. |
| 22 | +
|
| 23 | + Adapted from https://github.com/kevinzakka/mink/blob/main/mink/utils.py |
| 24 | + """ |
| 25 | + |
| 26 | + def gather_geoms(body_id: int) -> list[int]: |
| 27 | + geoms: list[int] = [] |
| 28 | + geom_start = model.body_geomadr[body_id] |
| 29 | + geom_end = geom_start + model.body_geomnum[body_id] |
| 30 | + geoms.extend(geom_id for geom_id in range(geom_start, geom_end) if model.geom_group[geom_id] == group) |
| 31 | + children = [i for i in range(model.nbody) if model.body_parentid[i] == body_id] |
| 32 | + for child_id in children: |
| 33 | + geoms.extend(gather_geoms(child_id)) |
| 34 | + return geoms |
| 35 | + |
| 36 | + return gather_geoms(body_id) |
| 37 | + |
| 38 | + |
9 | 39 | class Task(MujocoWorldBase): |
10 | 40 | """ |
11 | 41 | Creates MJCF model for a task performed. |
@@ -106,15 +136,32 @@ def generate_id_mappings(self, sim): |
106 | 136 | for robot in self.mujoco_robots: |
107 | 137 | models += [robot] + robot.models |
108 | 138 |
|
| 139 | + worldbody = self.mujoco_arena.root.find("worldbody") |
| 140 | + exclude_bodies = ["table"] |
| 141 | + top_level_bodies = [ |
| 142 | + body.attrib.get("name") |
| 143 | + for body in worldbody.findall("body") |
| 144 | + if body.attrib.get("name") not in exclude_bodies |
| 145 | + ] |
| 146 | + models.extend(top_level_bodies) |
| 147 | + |
109 | 148 | # Parse all mujoco models from robots and objects |
110 | 149 | for model in models: |
111 | | - # Grab model class name and visual IDs |
112 | | - cls = str(type(model)).split("'")[1].split(".")[-1] |
113 | | - inst = model.name |
114 | | - id_groups = [ |
115 | | - get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"), |
116 | | - get_ids(sim=sim, elements=model.sites, element_type="site"), |
117 | | - ] |
| 150 | + if isinstance(model, str): |
| 151 | + body_name = model |
| 152 | + visual_group_number = 1 |
| 153 | + body_id = sim.model.body_name2id(body_name) |
| 154 | + inst, cls = body_name, body_name |
| 155 | + geom_ids = get_subtree_geom_ids_by_group(sim.model, body_id, visual_group_number) |
| 156 | + id_groups = [geom_ids, []] |
| 157 | + else: |
| 158 | + # Grab model class name and visual IDs |
| 159 | + cls = str(type(model)).split("'")[1].split(".")[-1] |
| 160 | + inst = model.name |
| 161 | + id_groups = [ |
| 162 | + get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"), |
| 163 | + get_ids(sim=sim, elements=model.sites, element_type="site"), |
| 164 | + ] |
118 | 165 | group_types = ("geom", "site") |
119 | 166 | ids_to_instances = (self._geom_ids_to_instances, self._site_ids_to_instances) |
120 | 167 | ids_to_classes = (self._geom_ids_to_classes, self._site_ids_to_classes) |
|
0 commit comments