Skip to content

Commit 9af6fdb

Browse files
committed
update review
1 parent 2116897 commit 9af6fdb

File tree

3 files changed

+252
-381
lines changed

3 files changed

+252
-381
lines changed

genesis/utils/mjcf.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,9 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
419419
name_start = mj.name_geomadr[i_g]
420420
metadata["name"] = mj.names[name_start : mj.names.find(b"\x00", name_start)].decode("utf-8")
421421

422-
if mj_geom.matid >= 0:
423-
mj_mat = mj.mat(mj_geom.matid[0])
422+
mj_mat_id = int(mj_geom.matid[0])
423+
if mj_mat_id >= 0:
424+
mj_mat = mj.mat(mj_mat_id)
424425
tex_id_RGB = mj_mat.texid[mujoco.mjtTextureRole.mjTEXROLE_RGB]
425426
tex_id_RGBA = mj_mat.texid[mujoco.mjtTextureRole.mjTEXROLE_RGBA]
426427
tex_id = tex_id_RGB if tex_id_RGB >= 0 else tex_id_RGBA
@@ -450,15 +451,16 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
450451
length = length or 1e3
451452
width = width or 1e3
452453

453-
tmesh = trimesh.Trimesh(
454+
uv = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
455+
mesh_params = dict(
454456
vertices=np.array(
455-
[[-length, width, 0.0], [length, width, 0.0], [-length, -width, 0.0], [length, -width, 0.0]]
457+
[[-length, width, 0.0], [length, width, 0.0], [-length, -width, 0.0], [length, -width, 0.0]],
458+
dtype=np.float32,
456459
),
457-
faces=np.array([[0, 2, 3], [0, 3, 1]]),
458-
face_normals=np.array([[0, 0, 1], [0, 0, 1]]),
460+
faces=np.array([[0, 2, 3], [0, 3, 1]], dtype=np.int64),
461+
face_normals=np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32),
459462
)
460-
uv = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
461-
geom_data = np.array([0.0, 0.0, 1.0])
463+
geom_data = np.array([0.0, 0.0, 1.0], dtype=np.float32)
462464
gs_type = gs.GEOM_TYPE.PLANE
463465

464466
elif mj_geom.type == mujoco.mjtGeom.mjGEOM_SPHERE:
@@ -467,6 +469,7 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
467469
tmesh = trimesh.creation.icosphere(radius=radius, subdivisions=2)
468470
else:
469471
tmesh = trimesh.creation.icosphere(radius=radius)
472+
mesh_params = dict(vertices=tmesh.vertices, faces=tmesh.faces)
470473
uv = None
471474
gs_type = gs.GEOM_TYPE.SPHERE
472475
geom_data = np.array([radius * scale])
@@ -476,6 +479,7 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
476479
tmesh = trimesh.creation.icosphere(radius=1.0, subdivisions=2)
477480
else:
478481
tmesh = trimesh.creation.icosphere(radius=1.0)
482+
mesh_params = dict(vertices=tmesh.vertices, faces=tmesh.faces)
479483
tmesh.apply_transform(np.diag([*geom_size, 1]))
480484
uv = None
481485
gs_type = gs.GEOM_TYPE.ELLIPSOID
@@ -488,6 +492,7 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
488492
tmesh = trimesh.creation.capsule(radius=radius, height=height, count=(8, 12))
489493
else:
490494
tmesh = trimesh.creation.capsule(radius=radius, height=height)
495+
mesh_params = dict(vertices=tmesh.vertices, faces=tmesh.faces)
491496
uv = None
492497
gs_type = gs.GEOM_TYPE.CAPSULE
493498
geom_data = np.array([radius * scale, height * scale])
@@ -496,12 +501,14 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
496501
radius = geom_size[0]
497502
height = geom_size[1] * 2
498503
tmesh = trimesh.creation.cylinder(radius=radius, height=height)
504+
mesh_params = dict(vertices=tmesh.vertices, faces=tmesh.faces)
499505
uv = None
500506
gs_type = gs.GEOM_TYPE.CYLINDER
501507
geom_data = np.array([radius * scale, height * scale])
502508

503509
elif mj_geom.type == mujoco.mjtGeom.mjGEOM_BOX:
504510
tmesh = trimesh.creation.box(extents=geom_size * 2)
511+
mesh_params = dict(vertices=tmesh.vertices, faces=tmesh.faces)
505512
uv = tmesh.vertices[:, :2].copy()
506513
uv -= uv.min(axis=0)
507514
uv /= uv.max(axis=0)
@@ -512,39 +519,34 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
512519
mj_mesh = mj.mesh(mj_geom.dataid[0])
513520

514521
vert_start = mj_mesh.vertadr[0]
515-
vert_num = mj_mesh.vertnum[0]
516-
vert_end = vert_start + vert_num
522+
vert_end = vert_start + mj_mesh.vertnum[0]
517523

518524
face_start = mj_mesh.faceadr[0]
519-
face_num = mj_mesh.facenum[0]
520-
face_end = face_start + face_num
525+
face_end = face_start + mj_mesh.facenum[0]
521526

522527
vertices = mj.mesh_vert[vert_start:vert_end]
528+
normals = mj.mesh_normal[vert_start:vert_end]
523529
faces = mj.mesh_face[face_start:face_end]
524-
face_normals = mj.mesh_normal[vert_start:vert_end]
525530

526531
tex_vert_start = int(mj.mesh_texcoordadr[mj_mesh.id])
527-
num_tex_vert = int(mj.mesh_texcoordnum[mj_mesh.id])
532+
tex_vert_end = tex_vert_start + int(mj.mesh_texcoordnum[mj_mesh.id])
533+
528534
if tex_vert_start != -1: # -1 means no texcoord
529-
vertices = np.zeros((num_tex_vert, 3))
530-
faces = mj.mesh_facetexcoord[face_start:face_end]
531-
for face_id in range(face_start, face_end):
532-
for i in range(3):
533-
mesh_vert_id = mj.mesh_face[face_id, i]
534-
tex_vert_id = mj.mesh_facetexcoord[face_id, i]
535-
vertices[tex_vert_id] = mj.mesh_vert[mesh_vert_id + vert_start]
536-
537-
uv = mj.mesh_texcoord[tex_vert_start : (tex_vert_start + num_tex_vert)]
535+
tex_faces = mj.mesh_facetexcoord[face_start:face_end]
536+
uv = mj.mesh_texcoord[tex_vert_start:tex_vert_end]
538537
uv[:, 1] = 1 - uv[:, 1]
538+
539+
pairs = np.stack([faces.ravel(), tex_faces.ravel()], axis=1) # (face_num * 3, 2)
540+
uniq, inv = np.unique(pairs, axis=0, return_inverse=True)
541+
542+
vertices = vertices[uniq[:, 0]]
543+
normals = normals[uniq[:, 0]]
544+
uv = uv[uniq[:, 1]]
545+
faces = inv.reshape(-1, 3).astype(np.int64)
539546
else:
540547
uv = None
541548

542-
tmesh = trimesh.Trimesh(
543-
vertices=vertices,
544-
faces=faces,
545-
face_normals=face_normals,
546-
process=False,
547-
)
549+
mesh_params = dict(vertices=vertices, faces=faces, vertex_normals=normals)
548550
gs_type = gs.GEOM_TYPE.MESH
549551
geom_data = None
550552

@@ -554,11 +556,13 @@ def parse_geom(mj, i_g, scale, surface, xml_path):
554556
gs.logger.warning(f"Unsupported MJCF geom type '{mj_geom.type}'.")
555557
return None
556558

557-
if uv is not None:
558-
if mj_mat is not None:
559-
uv *= mj_mat.texrepeat
560-
if uv is not None or tmesh_mat is not None:
561-
tmesh.visual = TextureVisuals(uv=uv, material=tmesh_mat)
559+
if uv is not None and mj_mat is not None:
560+
uv *= mj_mat.texrepeat
561+
tmesh = trimesh.Trimesh(
562+
**mesh_params,
563+
visual=TextureVisuals(uv=uv, material=tmesh_mat),
564+
process=False,
565+
)
562566
mesh = gs.Mesh.from_trimesh(
563567
tmesh, scale=scale, surface=gs.surfaces.Collision() if is_col else surface, metadata=metadata
564568
)

0 commit comments

Comments
 (0)