|
13 | 13 | add_prefix, |
14 | 14 | array_to_string, |
15 | 15 | find_elements, |
| 16 | + get_elements, |
16 | 17 | new_joint, |
| 18 | + scale_mjcf_model, |
17 | 19 | sort_elements, |
18 | 20 | string_to_array, |
19 | 21 | ) |
@@ -95,6 +97,28 @@ def get_obj(self): |
95 | 97 | assert self._obj is not None, "Object XML tree has not been generated yet!" |
96 | 98 | return self._obj |
97 | 99 |
|
| 100 | + def set_scale(self, scale, obj=None): |
| 101 | + """ |
| 102 | + Scales each geom, mesh, site, body, and joint ranges (for slide joints). |
| 103 | + Called during initialization but can also be used externally. |
| 104 | + Args: |
| 105 | + scale (float or list of floats): Scale factor (1 or 3 dims) |
| 106 | + obj (ET.Element): Root object to apply scaling to. Defaults to root object of model. |
| 107 | + """ |
| 108 | + if obj is None: |
| 109 | + obj = self._obj |
| 110 | + |
| 111 | + self._scale = scale |
| 112 | + |
| 113 | + # Use the centralized scaling utility function |
| 114 | + scale_mjcf_model( |
| 115 | + obj=obj, |
| 116 | + asset_root=self.asset, |
| 117 | + scale=scale, |
| 118 | + get_elements_func=get_elements, |
| 119 | + scale_slide_joints=True, |
| 120 | + ) |
| 121 | + |
98 | 122 | def exclude_from_prefixing(self, inp): |
99 | 123 | """ |
100 | 124 | A function that should take in either an ET.Element or its attribute (str) and return either True or False, |
@@ -371,7 +395,7 @@ def _get_object_subtree(self): |
371 | 395 | # Rename this top level object body (will have self.naming_prefix added later) |
372 | 396 | obj.attrib["name"] = "main" |
373 | 397 | # Get all geom_pairs in this tree |
374 | | - geom_pairs = self._get_geoms(obj) |
| 398 | + geom_pairs = get_elements(obj, "geom") |
375 | 399 |
|
376 | 400 | # Define a temp function so we don't duplicate so much code |
377 | 401 | obj_type = self.obj_type |
@@ -441,46 +465,6 @@ def _duplicate_visual_from_collision(element): |
441 | 465 | vis_element.set("name", vis_element.get("name") + "_visual") |
442 | 466 | return vis_element |
443 | 467 |
|
444 | | - def _get_geoms(self, root, _parent=None): |
445 | | - """ |
446 | | - Helper function to recursively search through element tree starting at @root and returns |
447 | | - a list of (parent, child) tuples where the child is a geom element |
448 | | -
|
449 | | - Args: |
450 | | - root (ET.Element): Root of xml element tree to start recursively searching through |
451 | | - _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set |
452 | | - during the recursive call |
453 | | -
|
454 | | - Returns: |
455 | | - list: array of (parent, child) tuples where the child element is a geom type |
456 | | - """ |
457 | | - return self._get_elements(root, "geom", _parent) |
458 | | - |
459 | | - def _get_elements(self, root, type, _parent=None): |
460 | | - """ |
461 | | - Helper function to recursively search through element tree starting at @root and returns |
462 | | - a list of (parent, child) tuples where the child is a specific type of element |
463 | | -
|
464 | | - Args: |
465 | | - root (ET.Element): Root of xml element tree to start recursively searching through |
466 | | - _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set |
467 | | - during the recursive call |
468 | | -
|
469 | | - Returns: |
470 | | - list: array of (parent, child) tuples where the child element is of type |
471 | | - """ |
472 | | - # Initialize return array |
473 | | - elem_pairs = [] |
474 | | - # If the parent exists and this is a desired element, we add this current (parent, element) combo to the output |
475 | | - if _parent is not None and root.tag == type: |
476 | | - elem_pairs.append((_parent, root)) |
477 | | - # Loop through all children elements recursively and add to pairs |
478 | | - for child in root: |
479 | | - elem_pairs += self._get_elements(child, type, _parent=root) |
480 | | - |
481 | | - # Return all found pairs |
482 | | - return elem_pairs |
483 | | - |
484 | 468 | def set_pos(self, pos): |
485 | 469 | """ |
486 | 470 | Set position of object position is defined as center of bounding box |
@@ -518,91 +502,14 @@ def set_scale(self, scale, obj=None): |
518 | 502 |
|
519 | 503 | self._scale = scale |
520 | 504 |
|
521 | | - # scale geoms |
522 | | - geom_pairs = self._get_geoms(obj) |
523 | | - for _, (_, element) in enumerate(geom_pairs): |
524 | | - g_pos = element.get("pos") |
525 | | - g_size = element.get("size") |
526 | | - if g_pos is not None: |
527 | | - g_pos = array_to_string(string_to_array(g_pos) * self._scale) |
528 | | - element.set("pos", g_pos) |
529 | | - if g_size is not None: |
530 | | - g_size_np = string_to_array(g_size) |
531 | | - # handle cases where size is not 3 dimensional |
532 | | - if len(g_size_np) == 3: |
533 | | - g_size_np = g_size_np * self._scale |
534 | | - elif len(g_size_np) == 2: |
535 | | - scale = np.array(self._scale).reshape(-1) |
536 | | - if len(scale) == 1: |
537 | | - g_size_np[1] *= scale |
538 | | - elif len(scale) == 3: |
539 | | - # g_size_np[0] *= np.mean(scale[:2]) |
540 | | - g_size_np[0] *= np.mean(scale[:2]) # width |
541 | | - g_size_np[1] *= scale[2] # height |
542 | | - else: |
543 | | - raise ValueError |
544 | | - else: |
545 | | - raise ValueError |
546 | | - g_size = array_to_string(g_size_np) |
547 | | - element.set("size", g_size) |
548 | | - |
549 | | - # scale meshes |
550 | | - meshes = self.asset.findall("mesh") |
551 | | - for elem in meshes: |
552 | | - m_scale = elem.get("scale") |
553 | | - if m_scale is not None: |
554 | | - m_scale = string_to_array(m_scale) |
555 | | - else: |
556 | | - m_scale = np.ones(3) |
557 | | - |
558 | | - m_scale *= self._scale |
559 | | - elem.set("scale", array_to_string(m_scale)) |
560 | | - |
561 | | - # scale bodies |
562 | | - body_pairs = self._get_elements(obj, "body") |
563 | | - for (_, elem) in body_pairs: |
564 | | - b_pos = elem.get("pos") |
565 | | - if b_pos is not None: |
566 | | - b_pos = string_to_array(b_pos) * self._scale |
567 | | - elem.set("pos", array_to_string(b_pos)) |
568 | | - |
569 | | - # scale joints |
570 | | - joint_pairs = self._get_elements(obj, "joint") |
571 | | - for (_, elem) in joint_pairs: |
572 | | - j_pos = elem.get("pos") |
573 | | - if j_pos is not None: |
574 | | - j_pos = string_to_array(j_pos) * self._scale |
575 | | - elem.set("pos", array_to_string(j_pos)) |
576 | | - |
577 | | - # scale sites |
578 | | - site_pairs = self._get_elements(self.worldbody, "site") |
579 | | - for (_, elem) in site_pairs: |
580 | | - s_pos = elem.get("pos") |
581 | | - if s_pos is not None: |
582 | | - s_pos = string_to_array(s_pos) * self._scale |
583 | | - elem.set("pos", array_to_string(s_pos)) |
584 | | - |
585 | | - s_size = elem.get("size") |
586 | | - if s_size is not None: |
587 | | - s_size_np = string_to_array(s_size) |
588 | | - # handle cases where size is not 3 dimensional |
589 | | - if len(s_size_np) == 3: |
590 | | - s_size_np = s_size_np * self._scale |
591 | | - elif len(s_size_np) == 2: |
592 | | - scale = np.array(self._scale).reshape(-1) |
593 | | - if len(scale) == 1: |
594 | | - s_size_np *= scale |
595 | | - elif len(scale) == 3: |
596 | | - s_size_np[0] *= np.mean(scale[:2]) # width |
597 | | - s_size_np[1] *= scale[2] # height |
598 | | - else: |
599 | | - raise ValueError |
600 | | - elif len(s_size_np) == 1: |
601 | | - s_size_np *= np.mean(self._scale) |
602 | | - else: |
603 | | - raise ValueError |
604 | | - s_size = array_to_string(s_size_np) |
605 | | - elem.set("size", s_size) |
| 505 | + # Use the centralized scaling utility function |
| 506 | + scale_mjcf_model( |
| 507 | + obj=obj, |
| 508 | + asset_root=self.asset, |
| 509 | + scale=scale, |
| 510 | + get_elements_func=get_elements, |
| 511 | + scale_slide_joints=False, # MujocoXMLObject doesn't handle slide joints |
| 512 | + ) |
606 | 513 |
|
607 | 514 | @property |
608 | 515 | def bottom_offset(self): |
|
0 commit comments