Skip to content

Commit 6e8c686

Browse files
authored
[BUG FIX] Fix CUDA runtime being initialized by 'get_device'. (Genesis-Embodied-AI#1634)
* Fix CUDA runtime initialized by 'get_device', breaking device selection for distributed unit tests. * Fix support of undefined device index. * Ensure consistency between EGL_DEVICE_ID and CUDA_VISIBLE_DEVICES in unit tests. * Fix flaky unit tests due to mutable fixture interference.
1 parent 2a9f2f7 commit 6e8c686

File tree

20 files changed

+120
-90
lines changed

20 files changed

+120
-90
lines changed

genesis/ext/pyrender/camera.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class PerspectiveCamera(Camera):
109109
"""
110110

111111
def __init__(self, yfov, znear=DEFAULT_Z_NEAR, zfar=None, aspectRatio=None, name=None):
112-
super(PerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
112+
super().__init__(znear=znear, zfar=zfar, name=name)
113113

114114
self.yfov = yfov
115115
self.aspectRatio = aspectRatio
@@ -209,7 +209,7 @@ class OrthographicCamera(Camera):
209209
"""
210210

211211
def __init__(self, xmag, ymag, znear=DEFAULT_Z_NEAR, zfar=DEFAULT_Z_FAR, name=None):
212-
super(OrthographicCamera, self).__init__(znear=znear, zfar=zfar, name=name)
212+
super().__init__(znear=znear, zfar=zfar, name=name)
213213

214214
self.xmag = xmag
215215
self.ymag = ymag
@@ -305,7 +305,7 @@ class IntrinsicsCamera(Camera):
305305
"""
306306

307307
def __init__(self, fx, fy, cx, cy, znear=DEFAULT_Z_NEAR, zfar=DEFAULT_Z_FAR, name=None):
308-
super(IntrinsicsCamera, self).__init__(znear=znear, zfar=zfar, name=name)
308+
super().__init__(znear=znear, zfar=zfar, name=name)
309309

310310
self.fx = fx
311311
self.fy = fy

genesis/ext/pyrender/light.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,7 @@ class DirectionalLight(Light):
135135
"""
136136

137137
def __init__(self, color=None, intensity=None, name=None):
138-
super(DirectionalLight, self).__init__(
139-
color=color,
140-
intensity=intensity,
141-
name=name,
142-
)
138+
super().__init__(color=color, intensity=intensity, name=name)
143139

144140
def _generate_shadow_texture(self, size=None):
145141
"""Generate a shadow texture for this light.
@@ -191,11 +187,7 @@ class PointLight(Light):
191187
"""
192188

193189
def __init__(self, color=None, intensity=None, range=None, name=None):
194-
super(PointLight, self).__init__(
195-
color=color,
196-
intensity=intensity,
197-
name=name,
198-
)
190+
super().__init__(color=color, intensity=intensity, name=name)
199191
self.range = range
200192

201193
@property
@@ -306,11 +298,7 @@ class SpotLight(Light):
306298
def __init__(
307299
self, color=None, intensity=None, range=None, innerConeAngle=0.0, outerConeAngle=(np.pi / 4.0), name=None
308300
):
309-
super(SpotLight, self).__init__(
310-
name=name,
311-
color=color,
312-
intensity=intensity,
313-
)
301+
super().__init__(name=name, color=color, intensity=intensity)
314302
self.outerConeAngle = outerConeAngle
315303
self.innerConeAngle = innerConeAngle
316304
self.range = range

genesis/ext/pyrender/material.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def __init__(
377377
roughnessFactor=1.0,
378378
metallicRoughnessTexture=None,
379379
):
380-
super(MetallicRoughnessMaterial, self).__init__(
380+
super().__init__(
381381
name=name,
382382
normalTexture=normalTexture,
383383
occlusionTexture=occlusionTexture,
@@ -586,7 +586,7 @@ def __init__(
586586
glossinessFactor=1.0,
587587
specularGlossinessTexture=None,
588588
):
589-
super(SpecularGlossinessMaterial, self).__init__(
589+
super().__init__(
590590
name=name,
591591
normalTexture=normalTexture,
592592
occlusionTexture=occlusionTexture,

genesis/ext/pyrender/platforms/egl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class EGLPlatform(Platform):
117117
"""Renders using EGL."""
118118

119119
def __init__(self, viewport_width, viewport_height, device_id: int | None = None):
120-
super(EGLPlatform, self).__init__(viewport_width, viewport_height)
120+
super().__init__(viewport_width, viewport_height)
121121
if _eglQueryDevicesEXT is None and device_id not in (0, None):
122122
raise RuntimeError("EGL platform plugin is not available. Enforcing specific EGL device not supported.")
123123
self._egl_device_id = device_id

genesis/ext/pyrender/platforms/osmesa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class OSMesaPlatform(Platform):
1010
"""
1111

1212
def __init__(self, viewport_width, viewport_height):
13-
super(OSMesaPlatform, self).__init__(viewport_width, viewport_height)
13+
super().__init__(viewport_width, viewport_height)
1414
self._context = None
1515
self._buffer = None
1616

genesis/ext/pyrender/platforms/pyglet_platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class PygletPlatform(Platform):
1313
"""
1414

1515
def __init__(self, viewport_width, viewport_height):
16-
super(PygletPlatform, self).__init__(viewport_width, viewport_height)
16+
super().__init__(viewport_width, viewport_height)
1717
self._window = None
1818

1919
def init_context(self):

genesis/ext/pyrender/viewer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def start(self, auto_refresh=True):
12381238
# This approach avoids "flickering" when creating and closing an invalid context. Besides, it avoids
12391239
# "frozen" graphical window during compilation that would be interpreted as as bug by the end-user.
12401240
try:
1241-
super(Viewer, self).__init__(
1241+
super().__init__(
12421242
config=conf,
12431243
visible=False,
12441244
resizable=True,

genesis/logging/logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class GenesisFormatter(logging.Formatter):
1212
def __init__(self, verbose_time=True):
13-
super(GenesisFormatter, self).__init__()
13+
super().__init__()
1414

1515
self.mapping = {
1616
logging.DEBUG: colors.GREEN,

genesis/utils/misc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __exit__(self, exc_type, exc_value, traceback):
119119
def assert_initialized(cls):
120120
original_init = cls.__init__
121121

122+
@functools.wraps(original_init)
122123
def new_init(self, *args, **kwargs):
123124
if not gs._initialized:
124125
raise RuntimeError("Genesis hasn't been initialized. Did you call `gs.init()`?")
@@ -180,8 +181,6 @@ def get_device(backend: gs_backend, device_idx: Optional[int] = None):
180181
if not torch.cuda.is_available():
181182
gs.raise_exception("torch cuda not available")
182183

183-
if device_idx is None:
184-
device_idx = torch.cuda.current_device()
185184
device = torch.device("cuda", device_idx)
186185
device_property = torch.cuda.get_device_properties(device)
187186
device_name = device_property.name

genesis/utils/particle.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,9 @@ def parse_args(backend):
362362
compute_normals=True,
363363
enable_multi_threading=True,
364364
)
365-
mesh = trimesh.Trimesh(
366-
vertices=mesh_with_data.mesh.vertices,
367-
faces=mesh_with_data.mesh.triangles,
368-
face_normals=mesh_with_data.get_point_attribute("normals"),
369-
)
365+
normals = mesh_with_data.get_point_attribute("normals")
366+
vertices, triangles = mesh_with_data.take_mesh().take_vertices_and_triangles()
367+
mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, face_normals=normals)
370368
gs.logger.debug(f"[splashsurf]: reconstruct vertices: {mesh.vertices.shape}, {mesh.faces.shape}")
371369
return mesh
372370

0 commit comments

Comments
 (0)