
Version Conflict Between Torch and JAX for NVIDIA cuDNN-cu12 #858

Open
@apivovarov

Description


I am trying to run the full pytest suite on a GPU instance.

To set up the environment, I installed both the [dev] and [gpu] extras, but ran into the following conflict:

pip install -e .[dev]
  torchvision 0.16.1 requires torch 2.1.1
  torch 2.1.1 requires nvidia-cudnn-cu12 8.9.2.26
pip install -e .[gpu]
  jax[cuda12] 0.4.33 pulls in nvidia-cudnn-cu12 9.5.1.17

This leads to the following conflict:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.5.1.17 which is incompatible.
Successfully installed nvidia-cudnn-cu12-9.5.1.17
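To confirm which cuDNN wheel is actually installed and which pins each framework declares, the installed package metadata can be queried directly. This is a minimal sketch using only the standard library's `importlib.metadata`; the helper names are hypothetical, and it assumes (unverified here) that on JAX 0.4.33 the cuDNN requirement is declared by the `jax-cuda12-plugin` distribution rather than by `jax` itself.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, requires, version

CUDNN_DIST = "nvidia-cudnn-cu12"


def installed_version(dist: str) -> str | None:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def cudnn_requirements(dist: str) -> list[str]:
    """Return the requirement strings of `dist` that refer to the cuDNN wheel."""
    try:
        reqs = requires(dist) or []
    except PackageNotFoundError:
        return []
    # Metadata may spell the name with '_' or '-'; normalize before matching.
    return [r for r in reqs if r.replace("_", "-").lower().startswith(CUDNN_DIST)]


if __name__ == "__main__":
    print(f"{CUDNN_DIST} installed: {installed_version(CUDNN_DIST)}")
    # Assumption: these are the distributions that declare a cuDNN dependency.
    for dist in ("torch", "jax", "jax-cuda12-plugin"):
        print(f"{dist} {installed_version(dist)}: {cudnn_requirements(dist)}")
```

In the environment described in this issue, the torch entry would be expected to show the `==8.9.2.26` pin reported in the resolver error, which no single installed cuDNN version can satisfy alongside the 9.x wheel that JAX pulls in.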

As a result, I cannot get torch and jax working in the same environment at the same time.

When nvidia-cudnn-cu12 9.5.1.17 (the newer version) is installed, importing torch 2.1.1 fails with the following error:

import torch

ImportError: libcudnn.so.8: cannot open shared object file: No such file or directory
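The ImportError suggests that torch 2.1.1 loads the cuDNN 8 library by its major-versioned name (`libcudnn.so.8`), while the cuDNN 9 wheel ships `libcudnn.so.9` instead, so upgrading the wheel removes the file torch looks for. The sketch below lists which major versions of the main `libcudnn.so.N` library the installed wheel provides, using the file list recorded in its metadata. The helper names are hypothetical, and it assumes the wheel records its shared libraries in its `RECORD` file, as pip-installed wheels normally do.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, files

# Matches the main cuDNN library name and captures its major version,
# e.g. "libcudnn.so.8" -> "8". Sub-libraries such as "libcudnn_ops.so.9" do not match.
SONAME = re.compile(r"^libcudnn\.so\.(\d+)$")


def soname_major(filename: str) -> str | None:
    """Return the cuDNN major version encoded in a library file name, if any."""
    match = SONAME.match(filename)
    return match.group(1) if match else None


def cudnn_major_versions(dist: str = "nvidia-cudnn-cu12") -> set[str]:
    """Collect the libcudnn major versions shipped by the installed wheel."""
    try:
        paths = files(dist) or []
    except PackageNotFoundError:
        return set()
    return {m for p in paths if (m := soname_major(p.name)) is not None}


if __name__ == "__main__":
    print(cudnn_major_versions())
```

If this prints `{'9'}`, the `libcudnn.so.8` lookup in the ImportError above has nothing to find in the pip-installed cuDNN directory.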

When nvidia-cudnn-cu12 8.9.2.26 (the older version) is installed, jax fails with this error instead:

import jax.numpy as jnp
x = jnp.ones((1000, 1000))

FAILED_PRECONDITION: DNN library initialization failed
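Both failures can be reproduced from one script by running each framework's smoke test in a separate interpreter, so that the libraries one framework loads cannot influence the other and one failure does not abort the whole check. This is a minimal sketch: the two snippets mirror the reproductions in this issue (`torch.backends.cudnn.version()` is added as an assumed extra trigger for torch's cuDNN loading), and `probe` and `CHECKS` are hypothetical names.

```python
from __future__ import annotations

import subprocess
import sys

# Smoke tests mirroring the failures reported in this issue.
CHECKS = {
    "torch": "import torch; print(torch.backends.cudnn.version())",
    "jax": "import jax.numpy as jnp; jnp.ones((1000, 1000)).block_until_ready()",
}


def probe(snippet: str, timeout: float = 300.0) -> tuple[bool, str]:
    """Run `snippet` in a fresh interpreter; return (succeeded, last stderr line)."""
    proc = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    lines = proc.stderr.strip().splitlines()
    return proc.returncode == 0, (lines[-1] if lines else "")


if __name__ == "__main__":
    for name, snippet in CHECKS.items():
        ok, err = probe(snippet)
        print(f"{name}: {'OK' if ok else 'FAILED: ' + err}")
```

With either cuDNN version installed, one of the two lines would be expected to report a failure, which is the conflict this issue describes.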

Approximately 30 test files use the torch package.

It is unclear to me how to run the full pytest suite on a GPU instance, since torch/torchvision and jax[cuda12] cannot work together in one environment because of this conflict.
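To see exactly which test files depend on torch (and so which ones cannot run in a JAX-GPU environment), the imports can be collected statically without importing anything. This is a minimal sketch that parses each file with the standard `ast` module; it assumes pytest's default `test_*.py` / `*_test.py` naming, and the `tests` directory in the usage line is a hypothetical location for the suite.

```python
from __future__ import annotations

import ast
from pathlib import Path


def imports_module(path: Path, module: str) -> bool:
    """True if the file contains `import module[.x]` or `from module[.x] import ...`."""
    try:
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
    except (SyntaxError, UnicodeDecodeError):
        return False
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            names = [node.module]  # absolute imports only
        else:
            continue
        # Exact or dotted-prefix match, so "torchvision" does not count as "torch".
        if any(n == module or n.startswith(module + ".") for n in names):
            return True
    return False


def files_importing(
    root: str | Path,
    module: str,
    patterns: tuple[str, ...] = ("test_*.py", "*_test.py"),
) -> list[Path]:
    """List test files under `root` that import `module`."""
    found = {p for pattern in patterns for p in Path(root).rglob(pattern)}
    return sorted(p for p in found if imports_module(p, module))


if __name__ == "__main__":
    hits = files_importing("tests", "torch")  # hypothetical test directory
    print(len(hits))
```

Running it for both `"torch"` and `"torchvision"` gives the list of files that would need to run in a separate, torch-only environment.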

OS: Ubuntu 22.04
