Skip to content

Commit 8979ef3

Browse files
committed
Fix recomputing inertia for primitive geometries.
1 parent 9158282 commit 8979ef3

File tree

2 files changed

+80
-49
lines changed

2 files changed

+80
-49
lines changed

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,51 @@ def _build(self):
161161
geom_pos = geom._init_pos
162162
geom_quat = geom._init_quat
163163

164-
if not inertia_mesh.is_watertight:
165-
inertia_mesh = trimesh.convex.convex_hull(inertia_mesh)
166-
167-
# FIXME: without this check, some geom will have negative volume even after the above convex
168-
# hull operation, e.g. 'tests/test_examples.py::test_example[rigid/terrain_from_mesh.py-None]'
169-
if inertia_mesh.volume < -gs.EPS:
170-
inertia_mesh.invert()
171-
172-
geom_mass = inertia_mesh.volume * rho
173-
geom_com_local = np.array(inertia_mesh.center_mass, dtype=gs.np_float)
174-
175-
geom_inertia_local = inertia_mesh.moment_inertia / inertia_mesh.mass * geom_mass
164+
geom_com_local = np.zeros(3)
165+
if geom.type == gs.GEOM_TYPE.PLANE:
166+
pass
167+
elif geom.type == gs.GEOM_TYPE.SPHERE:
168+
radius = geom.data[0]
169+
geom_mass = (4.0 / 3.0) * np.pi * radius**3 * rho
170+
I = (2.0 / 5.0) * geom_mass * radius**2
171+
geom_inertia_local = np.diag([I, I, I])
172+
elif geom.type == gs.GEOM_TYPE.ELLIPSOID:
173+
hx, hy, hz = geom.data[:3]
174+
geom_mass = (4.0 / 3.0) * np.pi * hx * hy * hz * rho
175+
geom_inertia_local = (geom_mass / 5.0) * np.diag([hy**2 + hz**2, hx**2 + hz**2, hx**2 + hy**2])
176+
elif geom.type == gs.GEOM_TYPE.CYLINDER:
177+
radius, height = geom.data[:2]
178+
geom_mass = np.pi * radius**2 * height * rho
179+
I_r = (geom_mass / 12.0) * (3.0 * radius**2 + height**2)
180+
I_z = 0.5 * geom_mass * radius**2
181+
geom_inertia_local = np.diag([I_r, I_r, I_z])
182+
elif geom.type == gs.GEOM_TYPE.CAPSULE:
183+
radius, height = geom.data[:2]
184+
m_cyl = np.pi * radius**2 * height * rho
185+
m_sph = (4.0 / 3.0) * np.pi * radius**3 * rho
186+
geom_mass = m_cyl + m_sph
187+
I_r = (m_cyl * radius**2 / 12.0 * (3.0 + height**2 / radius**2)) + (
188+
m_sph * radius**2 / 4.0 * (83.0 / 80.0 + (height / radius + 3.0 / 4.0) ** 2)
189+
)
190+
I_h = 0.5 * m_cyl * radius**2 + (2.0 / 5.0) * m_sph * radius**2
191+
geom_inertia_local = np.diag([I_r, I_r, I_h])
192+
elif geom.type == gs.GEOM_TYPE.BOX:
193+
hx, hy, hz = geom.data[:3]
194+
geom_mass = (hx * hy * hz) * rho
195+
geom_inertia_local = (geom_mass / 12.0) * np.diag([hy**2 + hz**2, hx**2 + hz**2, hx**2 + hy**2])
196+
else: # if geom.type == gs.GEOM_TYPE.MESH:
197+
if not inertia_mesh.is_watertight:
198+
inertia_mesh = trimesh.convex.convex_hull(inertia_mesh)
199+
200+
# FIXME: without this check, some geom will have negative volume even after the above convex
201+
# hull operation, e.g. 'tests/test_examples.py::test_example[rigid/terrain_from_mesh.py-None]'
202+
if inertia_mesh.volume < -gs.EPS:
203+
inertia_mesh.invert()
204+
205+
geom_mass = inertia_mesh.volume * rho
206+
geom_com_local = inertia_mesh.center_mass
207+
208+
geom_inertia_local = inertia_mesh.moment_inertia / inertia_mesh.mass * geom_mass
176209

177210
# Transform geom properties to link frame
178211
geom_com_link = gu.transform_by_quat(geom_com_local, geom_quat) + geom_pos

tests/test_usd.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@
2828
HAS_USD_SUPPORT = False
2929
HAS_OMNIVERSE_KIT_SUPPORT = False
3030

31-
USD_COLOR_TOL = 1e-07 # Parsing from .usd loses a little precision in color
32-
USD_NORMALS_TOL = 1e-02 # Conversion from .usd to .glb loses a little precision in normals
31+
32+
# Conversion from .usd to .glb significantly affects precision
33+
USD_COLOR_TOL = 1e-07
34+
USD_NORMALS_TOL = 1e-02
3335

3436

3537
def to_array(s: str) -> np.ndarray:
3638
"""Convert a string of space-separated floats to a numpy array."""
3739
return np.array([float(x) for x in s.split()])
3840

3941

40-
def compare_links(compared_links, usd_links, tol, strict=True):
42+
def compare_links(compared_links, usd_links, tol):
4143
"""Compare links between two scenes."""
4244
# Check number of links
4345
assert len(compared_links) == len(usd_links)
@@ -87,8 +89,7 @@ def compare_links(compared_links, usd_links, tol, strict=True):
8789
assert_allclose(compared_link.inertial_quat, usd_link.inertial_quat, tol=tol, err_msg=err_msg)
8890

8991
# Skip mass and inertia checks for fixed links - they're not used in simulation
90-
if strict and not compared_link.is_fixed:
91-
# Both scenes now use the same material density (1000 kg/m³), so values should match closely
92+
if not compared_link.is_fixed:
9293
assert_allclose(compared_link.inertial_mass, usd_link.inertial_mass, atol=tol, err_msg=err_msg)
9394
assert_allclose(compared_link.inertial_i, usd_link.inertial_i, atol=tol, err_msg=err_msg)
9495

@@ -162,7 +163,7 @@ def compare_geoms(compared_geoms, usd_geoms, tol):
162163
assert_allclose(compared_geom.get_AABB(), usd_geom.get_AABB(), tol=tol, err_msg=err_msg)
163164

164165

165-
def compare_vgeoms(compared_vgeoms, usd_vgeoms, tol, strict=True):
166+
def compare_vgeoms(compared_vgeoms, usd_vgeoms, tol):
166167
"""Compare visual geoms between two scenes."""
167168
assert len(compared_vgeoms) == len(usd_vgeoms)
168169

@@ -171,31 +172,28 @@ def compare_vgeoms(compared_vgeoms, usd_vgeoms, tol, strict=True):
171172
usd_vgeoms_sorted = sorted(usd_vgeoms, key=lambda g: g.vmesh.metadata["name"].split("/")[-1])
172173

173174
for compared_vgeom, usd_vgeom in zip(compared_vgeoms_sorted, usd_vgeoms_sorted):
174-
if strict:
175-
compared_vgeom_pos, compared_vgeom_quat = gu.transform_pos_quat_by_trans_quat(
176-
compared_vgeom.init_pos, compared_vgeom.init_quat, compared_vgeom.link.pos, compared_vgeom.link.quat
177-
)
178-
usd_vgeom_pos, usd_vgeom_quat = gu.transform_pos_quat_by_trans_quat(
179-
usd_vgeom.init_pos, usd_vgeom.init_quat, usd_vgeom.link.pos, usd_vgeom.link.quat
180-
)
181-
compared_vgeom_T = gu.trans_quat_to_T(compared_vgeom_pos, compared_vgeom_quat)
182-
usd_vgeom_T = gu.trans_quat_to_T(usd_vgeom_pos, usd_vgeom_quat)
183-
184-
compared_vgeom_mesh = compared_vgeom.vmesh.copy()
185-
usd_vgeom_mesh = usd_vgeom.vmesh.copy()
186-
mesh_name = usd_vgeom_mesh.metadata["name"]
187-
compared_vgeom_mesh.apply_transform(compared_vgeom_T)
188-
usd_vgeom_mesh.apply_transform(usd_vgeom_T)
189-
check_gs_meshes(compared_vgeom_mesh, usd_vgeom_mesh, mesh_name, tol, USD_NORMALS_TOL)
190-
else:
191-
assert_allclose(compared_vgeom.get_AABB(), usd_vgeom.get_AABB(), tol=tol)
175+
compared_vgeom_pos, compared_vgeom_quat = gu.transform_pos_quat_by_trans_quat(
176+
compared_vgeom.init_pos, compared_vgeom.init_quat, compared_vgeom.link.pos, compared_vgeom.link.quat
177+
)
178+
usd_vgeom_pos, usd_vgeom_quat = gu.transform_pos_quat_by_trans_quat(
179+
usd_vgeom.init_pos, usd_vgeom.init_quat, usd_vgeom.link.pos, usd_vgeom.link.quat
180+
)
181+
compared_vgeom_T = gu.trans_quat_to_T(compared_vgeom_pos, compared_vgeom_quat)
182+
usd_vgeom_T = gu.trans_quat_to_T(usd_vgeom_pos, usd_vgeom_quat)
183+
184+
compared_vgeom_mesh = compared_vgeom.vmesh.copy()
185+
usd_vgeom_mesh = usd_vgeom.vmesh.copy()
186+
mesh_name = usd_vgeom_mesh.metadata["name"]
187+
compared_vgeom_mesh.apply_transform(compared_vgeom_T)
188+
usd_vgeom_mesh.apply_transform(usd_vgeom_T)
189+
check_gs_meshes(compared_vgeom_mesh, usd_vgeom_mesh, mesh_name, tol, USD_NORMALS_TOL)
192190

193191
compared_vgeom_surface = compared_vgeom_mesh.surface
194192
usd_vgeom_surface = usd_vgeom_mesh.surface
195193
check_gs_surfaces(compared_vgeom_surface, usd_vgeom_surface, mesh_name)
196194

197195

198-
def compare_scene(compared_scene: gs.Scene, usd_scene: gs.Scene, tol: float, strict: bool = True):
196+
def compare_scene(compared_scene: gs.Scene, usd_scene: gs.Scene, tol: float):
199197
"""Compare structure and data between compared scene and USD scene."""
200198
compared_entities = compared_scene.entities
201199
usd_entities = usd_scene.entities
@@ -210,7 +208,7 @@ def compare_scene(compared_scene: gs.Scene, usd_scene: gs.Scene, tol: float, str
210208

211209
compared_links = [link for entity in compared_entities for link in entity.links]
212210
usd_links = [link for entity in usd_entities for link in entity.links]
213-
compare_links(compared_links, usd_links, tol=tol, strict=strict)
211+
compare_links(compared_links, usd_links, tol=tol)
214212

215213

216214
def compare_mesh_scene(compared_scene: gs.Scene, usd_scene: gs.Scene, tol: float):
@@ -289,8 +287,14 @@ def build_mesh_scene(mesh_file: str, scale: float):
289287
scale=scale,
290288
euler=(-90, 0, 0),
291289
group_by_material=False,
290+
convexify=False,
291+
)
292+
mesh_scene.add_entity(
293+
mesh_morph,
294+
material=gs.materials.Rigid(
295+
rho=1000.0,
296+
),
292297
)
293-
mesh_scene.add_entity(mesh_morph, material=gs.materials.Rigid(rho=1000.0))
294298
mesh_scene.build()
295299
return mesh_scene
296300

@@ -465,19 +469,15 @@ def all_primitives_usd(asset_tmp_path, all_primitives_mjcf: ET.ElementTree):
465469
return usd_file
466470

467471

468-
@pytest.mark.parametrize("precision", ["32"])
472+
@pytest.mark.required
469473
@pytest.mark.parametrize("model_name", ["all_primitives_mjcf"])
470474
@pytest.mark.parametrize("scale", [1.0, 2.0])
471475
@pytest.mark.skipif(not HAS_USD_SUPPORT, reason="USD support not available")
472476
def test_primitives_mjcf_vs_usd(xml_path, all_primitives_usd, scale, tol):
473477
"""Test that MJCF and USD scenes produce equivalent Genesis entities."""
474478
mjcf_scene = build_mjcf_scene(xml_path, scale=scale)
475479
usd_scene = build_usd_scene(all_primitives_usd, scale=scale)
476-
# FIXME: Now parsed primitives have the same geometry for both visual and collision meshes
477-
# which is different from how we parsed in MJCF.
478-
# Additionally, in MuJoCo, primitives' masses are computed directly from their analytical
479-
# parameters rather than using the actual mesh volume. This should be considered in USD and URDF parsing.
480-
compare_scene(mjcf_scene, usd_scene, tol=tol, strict=False)
480+
compare_scene(mjcf_scene, usd_scene, tol=tol)
481481

482482

483483
# ==================== Joint Tests ====================
@@ -741,10 +741,11 @@ def all_joints_usd(asset_tmp_path, all_joints_mjcf: ET.ElementTree):
741741
free_joint_prim.CreateLocalPos1Attr().Set(Gf.Vec3f(0.0, 0.0, 0.0))
742742

743743
stage.Save()
744+
744745
return usd_file
745746

746747

747-
@pytest.mark.parametrize("precision", ["32"])
748+
@pytest.mark.required
748749
@pytest.mark.parametrize("model_name", ["all_joints_mjcf"])
749750
@pytest.mark.parametrize("scale", [1.0, 2.0])
750751
@pytest.mark.skipif(not HAS_USD_SUPPORT, reason="USD support not available")
@@ -764,7 +765,6 @@ def test_joints_mjcf_vs_usd(xml_path, all_joints_usd, scale, tol):
764765

765766

766767
@pytest.mark.required
767-
@pytest.mark.parametrize("precision", ["32"])
768768
@pytest.mark.parametrize("model_name", ["usd/sneaker_airforce", "usd/RoughnessTest"])
769769
@pytest.mark.skipif(not HAS_USD_SUPPORT, reason="USD support not available")
770770
def test_usd_visual_parse(model_name, tol):
@@ -780,7 +780,6 @@ def test_usd_visual_parse(model_name, tol):
780780

781781

782782
@pytest.mark.required
783-
@pytest.mark.parametrize("precision", ["32"])
784783
@pytest.mark.parametrize("usd_file", ["usd/nodegraph.usda"])
785784
@pytest.mark.skipif(not HAS_USD_SUPPORT, reason="USD support not available")
786785
def test_usd_parse_nodegraph(usd_file):
@@ -798,7 +797,6 @@ def test_usd_parse_nodegraph(usd_file):
798797

799798

800799
@pytest.mark.required
801-
@pytest.mark.parametrize("precision", ["32"])
802800
@pytest.mark.parametrize(
803801
"usd_file", ["usd/WoodenCrate/WoodenCrate_D1_1002.usda", "usd/franka_mocap_teleop/table_scene.usd"]
804802
)

0 commit comments

Comments
 (0)