Skip to content

Commit 355ddaf

Browse files
authored
Fix _jit_compile arguments for torch>=2.7.0 (#651)
* Fix _jit_compile arguments for torch>=2.7.0 * reformatted using black
1 parent b416792 commit 355ddaf

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

gsplat/cuda/_backend.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Trigger compiling (for debugging):
33
44
VERBOSE=1 FAST_COMPILE=1 TORCH_CUDA_ARCH_LIST="8.9" python -c "from gsplat.cuda._backend import _C"
@@ -75,20 +75,46 @@ def load_extension(
7575
head_file = os.path.join(_TORCH_PATH, "include", "torch", "extension.h")
7676
extra_cflags += ["-include", head_file, "-Winvalid-pch"]
7777

78-
return _jit_compile(
79-
name,
80-
sources,
81-
extra_cflags,
82-
extra_cuda_cflags,
83-
extra_ldflags,
84-
extra_include_paths,
85-
build_directory,
86-
verbose,
87-
with_cuda=None,
88-
is_python_module=True,
89-
is_standalone=False,
90-
keep_intermediates=True,
91-
)
78+
try:
79+
compiled = _jit_compile(
80+
name,
81+
sources,
82+
extra_cflags,
83+
extra_cuda_cflags,
84+
extra_ldflags,
85+
extra_include_paths,
86+
build_directory,
87+
verbose,
88+
with_cuda=None,
89+
is_python_module=True,
90+
is_standalone=False,
91+
keep_intermediates=True,
92+
)
93+
except (
94+
TypeError
95+
) as e: # torch>=2.7.0 has added arguments to _jit_compile to support SYCL.
96+
# Narrow the scope of catch: only retry if it's due to unexpected argument(s)
97+
if "_jit_compile() missing" in str(e):
98+
compiled = _jit_compile(
99+
name,
100+
sources,
101+
extra_cflags,
102+
extra_cuda_cflags,
103+
None, # SYCL fallback
104+
extra_ldflags,
105+
extra_include_paths,
106+
build_directory,
107+
verbose,
108+
with_cuda=None,
109+
with_sycl=None,
110+
is_python_module=True,
111+
is_standalone=False,
112+
keep_intermediates=True,
113+
)
114+
else:
115+
raise e
116+
117+
return compiled
92118
except OSError:
93119
# The module should already be compiled if we get OSError
94120
return _import_module_from_library(name, build_directory, True)

0 commit comments

Comments
 (0)