Which component has the problem?
CuTe DSL
Bug Report
Description
Since v4.4, import cutlass crashes with CudaDriverDependencyError on machines without GPU drivers when jax is installed. This worked fine in v4.3.x.
v4.4 (9fba3195) added from . import jax to cutlass/__init__.py. When jax is importable, importing that submodule eagerly runs initialize_cutlass_dsl() → cutlass.cute.compile() → cuda.cuInit(0) at import time, and that call fails hard when libcuda.so.1 is absent.
Repro
import cuda.bindings.driver

# Simulate a CPU-only machine (or run on an actual CPU-only machine / container)
cuda.bindings.driver.cuInit = lambda flags: (_ for _ in ()).throw(
    RuntimeError("Failed to dlopen libcuda.so.1"))

import cutlass  # v4.3.x: OK, v4.4+: CudaDriverDependencyError

Import chain
cutlass/__init__.py:79 from . import jax
cutlass/jax/__init__.py:53 initialize_cutlass_dsl() # called if jax is importable
cutlass/jax/compile.py:300 cutlass.cute.compile(kernel.init)
cutlass/base_dsl/compiler.py:169 _check_cuda_dependencies_once()
cutlass/base_dsl/compiler.py:211 cuda.cuInit(0) # CRASHES: Failed to dlopen libcuda.so.1
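The failure mode in the chain above can be reproduced without CUDA or JAX installed, using a throwaway package. This is a minimal sketch: the package name eagerpkg and the function initialize_dsl are hypothetical stand-ins, not cutlass code. It shows that an exception raised while a submodule initializes at import time propagates straight out of the parent package's import.

```python
import importlib
import sys
import tempfile
from pathlib import Path

def make_eager_package(root: Path) -> None:
    """Write a throwaway package whose __init__.py eagerly imports a submodule."""
    pkg = root / "eagerpkg"
    pkg.mkdir()
    # Mirrors cutlass/__init__.py importing its jax submodule at top level.
    (pkg / "__init__.py").write_text("from . import jax\n")
    # Hypothetical stand-in for the cutlass.jax chain: initialization runs at
    # import time and fails the way cuInit(0) does without libcuda.so.1.
    (pkg / "jax.py").write_text(
        "def initialize_dsl():\n"
        "    raise RuntimeError('Failed to dlopen libcuda.so.1')\n"
        "initialize_dsl()\n"
    )

def try_import(name: str) -> str:
    """Import a module by name and report the outcome instead of raising."""
    importlib.invalidate_caches()  # make sure freshly written files are seen
    try:
        importlib.import_module(name)
        return "ok"
    except RuntimeError as exc:
        return f"error: {exc}"

root = Path(tempfile.mkdtemp())
make_eager_package(root)
sys.path.insert(0, str(root))
print(try_import("eagerpkg"))  # error: Failed to dlopen libcuda.so.1
```

Because the failing import is rolled back, eagerpkg is not left in sys.modules, so every later attempt re-runs the eager import and fails again.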
Suggested fix
Lazy-load the jax submodule instead of importing it eagerly in cutlass/__init__.py:
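# Replace line 79: from . import jax as jax
# With:
def __getattr__(name):
    if name == "jax":
        # Resolve via importlib rather than a from-import: the from-import's
        # hasattr() check on this package would re-enter __getattr__ and recurse.
        import importlib
        return importlib.import_module(".jax", __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Or defer initialize_cutlass_dsl() in jax/__init__.py to its first use instead of calling it at import time.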
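Both options can be sketched with hypothetical stand-ins (lazypkg, _initialize_dsl and defer_init are illustrative names, not cutlass APIs). The first part follows the PEP 562 pattern of a module-level __getattr__ that resolves the submodule with importlib.import_module. This matters because a from-import inside __getattr__ makes the import system call hasattr() on the package, which re-enters __getattr__. The second part defers a one-time, lock-guarded initialization to the first call and retries if that initialization raised.

```python
import functools
import importlib
import sys
import tempfile
import threading
from pathlib import Path

# --- Option 1: lazy submodule via module-level __getattr__ (PEP 562) ---
# "lazypkg" is a hypothetical throwaway package standing in for cutlass.
root = Path(tempfile.mkdtemp())
(root / "lazypkg").mkdir()
(root / "lazypkg" / "__init__.py").write_text(
    "def __getattr__(name):\n"
    "    if name == 'jax':\n"
    "        import importlib  # runs only on first access to lazypkg.jax\n"
    "        return importlib.import_module('.jax', __name__)\n"
    "    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')\n"
)
# Submodule that fails at import time, like cuInit(0) without libcuda.so.1.
(root / "lazypkg" / "jax.py").write_text(
    "raise RuntimeError('Failed to dlopen libcuda.so.1')\n"
)
sys.path.insert(0, str(root))
importlib.invalidate_caches()

lazypkg = importlib.import_module("lazypkg")  # succeeds: no driver work yet
try:
    lazypkg.jax  # the error surfaces only when the submodule is first used
    lazy_error = None
except RuntimeError as exc:
    lazy_error = str(exc)

# --- Option 2: defer the driver-touching init to first use ---
# _initialize_dsl and defer_init are hypothetical names, not cutlass APIs.
_init_lock = threading.Lock()
_initialized = False
init_calls = 0

def _initialize_dsl() -> None:
    """Stand-in for initialize_cutlass_dsl(), which in cutlass reaches cuInit(0)."""
    global init_calls
    init_calls += 1

def _ensure_initialized() -> None:
    global _initialized
    if _initialized:  # fast path once initialization has succeeded
        return
    with _init_lock:
        if not _initialized:
            _initialize_dsl()  # if this raises, the flag stays False and a later call retries
            _initialized = True

def defer_init(fn):
    """Decorator: run the one-time initialization on the first call, not at import."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _ensure_initialized()
        return fn(*args, **kwargs)
    return wrapper

@defer_init
def add(a, b):
    return a + b

print(lazy_error)                         # Failed to dlopen libcuda.so.1
print(init_calls, add(1, 2), init_calls)  # 0 3 1
```

In this sketch, option 1 keeps import cutlass free of driver calls but still raises on the first access to cutlass.jax. Option 2 goes further and moves the error to the first call that actually needs the driver, so merely importing or touching the jax integration stays safe on CPU-only machines.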