Skip to content

Commit c039d96

Browse files
NilouxwilsonCernWq
andauthored
fix: use JIT return module for CUDA 12.8.1 (Blackwell) to avoid import failures (#166)
* Fix JIT extension loading for CUDA 12.8.1/Blackwell: use returned module instead of re-import --------- Co-authored-by: Qi Wu <wilson.over.cloud@gmail.com>
1 parent 38664dd commit c039d96

File tree

5 files changed

+10
-16
lines changed

5 files changed

+10
-16
lines changed

threedgrt_tracer/setup_3dgrt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def to_cpp_bool(value):
7171
f"{os.path.join(slang_build_file_dir,'models/gaussianParticles.slang')}",
7272
f"{os.path.join(slang_build_file_dir,'models/shRadiativeParticles.slang')}",
7373
"-o",
74-
f"{os.path.join(slang_build_file_dir,'gaussianParticles.cuh')}",
74+
f"{os.path.join(slang_build_file_dir, 'gaussianParticles.cuh')}",
7575
],
7676
env=slang_build_env,
7777
)
@@ -84,4 +84,4 @@ def to_cpp_bool(value):
8484
extra_cflags=cflags,
8585
extra_cuda_cflags=cuda_flags,
8686
extra_include_paths=include_paths,
87-
)
87+
)

threedgrut_playground/setup_playground.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import os
17-
import sys
1817
from pathlib import Path
1918

2019
from threedgrut.utils import jit
@@ -72,11 +71,11 @@ def to_cpp_bool(value):
7271
f"-DGAUSSIAN_PARTICLE_MIN_ALPHA={conf.render.particle_kernel_min_alpha}",
7372
f"-DGAUSSIAN_PARTICLE_MAX_ALPHA={conf.render.particle_kernel_max_alpha}",
7473
f"-DGAUSSIAN_PARTICLE_ENABLE_NORMAL={to_cpp_bool(conf.render.enable_normals)}",
75-
f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type=='trisurfel')}",
76-
f"{os.path.join(slang_build_file_dir,'models/gaussianParticles.slang')}",
77-
f"{os.path.join(slang_build_file_dir,'models/shRadiativeParticles.slang')}",
74+
f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type == 'trisurfel')}",
75+
f"{os.path.join(slang_build_file_dir, 'models/gaussianParticles.slang')}",
76+
f"{os.path.join(slang_build_file_dir, 'models/shRadiativeParticles.slang')}",
7877
"-o",
79-
f"{os.path.join(slang_build_file_dir,'gaussianParticles.cuh')}",
78+
f"{os.path.join(slang_build_file_dir, 'gaussianParticles.cuh')}",
8079
],
8180
env=slang_build_env,
8281
)
@@ -87,4 +86,4 @@ def to_cpp_bool(value):
8786
name="libplayground_cc",
8887
sources=source_paths,
8988
extra_include_paths=include_paths,
90-
)
89+
)

threedgrut_playground/tracer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717
import os
18-
from enum import IntEnum
1918

2019
import torch
2120
import torch.utils.cpp_extension
@@ -46,9 +45,7 @@ def load_playground_plugin(conf):
4645

4746

4847
class Tracer:
49-
5048
def __init__(self, conf):
51-
5249
self.device = "cuda"
5350
self.conf = conf
5451
self.num_update_bvh = 0

threedgut_tracer/setup_3dgut.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
# ----------------------------------------------------------------------------
2525
#
2626
def setup_3dgut(conf):
27-
2827
build_dir = torch.utils.cpp_extension._get_build_directory("lib3dgut_cc", verbose=True)
2928

3029
include_paths = []
@@ -126,9 +125,9 @@ def to_cpp_bool(value):
126125
"-Wno-41018",
127126
"-O2",
128127
*defines,
129-
f"{os.path.join(slang_build_inc_dir,'threedgut.slang')}",
128+
f"{os.path.join(slang_build_inc_dir, 'threedgut.slang')}",
130129
"-o",
131-
f"{os.path.join(build_dir,'threedgutSlang.cuh')}",
130+
f"{os.path.join(build_dir, 'threedgutSlang.cuh')}",
132131
],
133132
env=slang_build_env,
134133
)
@@ -142,4 +141,4 @@ def to_cpp_bool(value):
142141
extra_cuda_cflags=cuda_cflags,
143142
extra_include_paths=include_paths,
144143
build_directory=build_dir,
145-
)
144+
)

threedgut_tracer/tracer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import numpy as np
2121
import torch
22-
import torch.utils.cpp_extension
2322
from omegaconf import OmegaConf
2423

2524
from threedgrut.datasets.protocols import Batch

0 commit comments

Comments
 (0)