
[BUG] import cutlass crashes on CPU-only machines when jax is installed (v4.4 regression) #3083

@henrylhtsang

Which component has the problem?

CuTe DSL

Bug Report

Description

Since v4.4, import cutlass crashes with CudaDriverDependencyError on machines without GPU drivers when jax is installed. This worked fine in v4.3.x.

v4.4 (9fba3195) added from . import jax to cutlass/__init__.py. When jax is importable, that import eagerly runs initialize_cutlass_dsl() → cutlass.cute.compile() → cuda.cuInit(0) at import time. cuInit fails hard when libcuda.so.1 is absent, so import cutlass itself raises.

Repro

import cuda.bindings.driver

# Simulate CPU-only machine (or run on an actual CPU-only machine / container)
cuda.bindings.driver.cuInit = lambda flags: (_ for _ in ()).throw(
    RuntimeError("Failed to dlopen libcuda.so.1"))

import cutlass  # v4.3.x: OK, v4.4+: CudaDriverDependencyError

Import chain

cutlass/__init__.py:79      from . import jax
cutlass/jax/__init__.py:53  initialize_cutlass_dsl()  # called if jax is importable
cutlass/jax/compile.py:300  cutlass.cute.compile(kernel.init)
cutlass/base_dsl/compiler.py:169  _check_cuda_dependencies_once()
cutlass/base_dsl/compiler.py:211  cuda.cuInit(0)  # CRASHES: Failed to dlopen libcuda.so.1
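The same failure mode can be reproduced without cutlass or a GPU. This sketch builds a throwaway toy package (toy_eager, a hypothetical name) whose __init__.py eagerly imports a jax subpackage that raises at import time. The raise stands in for the initialize_cutlass_dsl() → cuInit(0) chain above. It illustrates the mechanism under that assumption and is not cutlass code.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical toy layout mirroring cutlass/__init__.py -> cutlass/jax/__init__.py
root = Path(tempfile.mkdtemp())
pkg = root / "toy_eager"
(pkg / "jax").mkdir(parents=True)

# Top-level package imports the subpackage eagerly, like line 79 of cutlass/__init__.py in v4.4.
(pkg / "__init__.py").write_text("from . import jax\n")

# Subpackage runs its initializer at import time; the initializer stands in for
# initialize_cutlass_dsl() -> cute.compile() -> cuInit(0) on a CPU-only machine.
(pkg / "jax" / "__init__.py").write_text(textwrap.dedent("""
    def initialize_dsl():
        raise RuntimeError("Failed to dlopen libcuda.so.1")

    initialize_dsl()
"""))

sys.path.insert(0, str(root))
importlib.invalidate_caches()


def try_import(name: str) -> str:
    """Import `name` and report whether the import succeeded."""
    try:
        importlib.import_module(name)
        return "ok"
    except RuntimeError as exc:
        return f"failed: {exc}"


print(try_import("toy_eager"))  # failed: Failed to dlopen libcuda.so.1
```

Because the subpackage import runs inside the parent's __init__.py, the error escapes from the top-level import even though the caller never asked for the jax integration.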

Suggested fix

Lazy-load the jax submodule instead of importing it eagerly in cutlass/__init__.py:

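# Replace line 79: from . import jax as jax
# With:
import importlib

def __getattr__(name):
    # Module-level __getattr__ (PEP 562) only runs when normal lookup fails,
    # i.e. on the first access to cutlass.jax.
    if name == "jax":
        # Use importlib: `from . import jax` here would call hasattr(cutlass, "jax")
        # inside the import system and re-enter this __getattr__ (RecursionError).
        return importlib.import_module(".jax", __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")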
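A self-contained check of the lazy-loading pattern, again using a hypothetical toy package (toy_lazy) rather than cutlass itself. Importing the package succeeds, and the failing subpackage initializer only runs when .jax is first accessed. The toy __init__.py calls importlib.import_module from __getattr__, as in the PEP 562 example. This avoids re-entering __getattr__ through the import system's hasattr check, which a plain from . import jax inside __getattr__ would trigger.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical toy package; only the lazy __init__.py pattern matters here.
root = Path(tempfile.mkdtemp())
pkg = root / "toy_lazy"
(pkg / "jax").mkdir(parents=True)

# Top-level __init__.py defers the subpackage import to first attribute access.
(pkg / "__init__.py").write_text(textwrap.dedent("""
    import importlib

    def __getattr__(name):
        if name == "jax":
            return importlib.import_module(".jax", __name__)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
"""))

# Subpackage fails at import time, standing in for cuInit(0) without libcuda.so.1.
(pkg / "jax" / "__init__.py").write_text(
    'raise RuntimeError("Failed to dlopen libcuda.so.1")\n'
)

sys.path.insert(0, str(root))
importlib.invalidate_caches()

toy_lazy = importlib.import_module("toy_lazy")
print("import ok")  # the failing subpackage has not been touched yet

try:
    toy_lazy.jax  # first access triggers the deferred import
except RuntimeError as exc:
    print(f"access failed: {exc}")  # access failed: Failed to dlopen libcuda.so.1
```

On a CPU-only machine this moves the failure from import cutlass to the first use of cutlass.jax. Once the submodule imports successfully, the import system binds it as an attribute of the package, so later accesses bypass __getattr__ entirely.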

Alternatively, defer the initialize_cutlass_dsl() call in cutlass/jax/__init__.py until first use instead of running it at import time.
