Description
I am trying to run the full pytest suite on a GPU instance.
To set up the environment, I installed the [dev] and [gpu] extras, but ran into the following issue:
pip install -e .[dev]
torchvision-0.16.1 requires torch-2.1.1
torch-2.1.1 requires nvidia_cudnn_cu12-8.9.2.26
pip install -e .[gpu]
jax[cuda12]-0.4.33 needs nvidia-cudnn-cu12 9.5.1.17
This leads to the following conflict:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.5.1.17 which is incompatible.
Successfully installed nvidia-cudnn-cu12-9.5.1.17
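To see which cuDNN wheel actually ended up in the environment, and what torch itself pins, a small check using the standard library's importlib.metadata may help. This is a minimal diagnostic sketch, not part of the project. The helper names (installed_version, declared_pins) are hypothetical. The inclusion of "jaxlib" in the list is an assumption: it is jax's usual compiled companion package, but the issue does not mention it.

```python
from importlib import metadata

# Distributions relevant to the conflict. "jaxlib" is an assumption (jax's
# usual compiled companion package); the issue does not name it.
PACKAGES = ("torch", "torchvision", "jax", "jaxlib", "nvidia-cudnn-cu12")


def installed_version(dist_name):
    """Return the installed version of dist_name, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def declared_pins(dist_name, dependency="nvidia-cudnn-cu12"):
    """Return the requirement strings dist_name declares for `dependency`.

    Underscores and case are normalized so that nvidia_cudnn_cu12 and
    nvidia-cudnn-cu12 are treated as the same project name.
    """
    try:
        requirements = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return []
    wanted = dependency.replace("_", "-").lower()
    return [r for r in requirements if r.replace("_", "-").lower().startswith(wanted)]


if __name__ == "__main__":
    for name in PACKAGES:
        print(f"{name}: {installed_version(name) or 'not installed'}")
    # With torch 2.1.1 installed, this should include its
    # nvidia-cudnn-cu12==8.9.2.26 requirement (with its platform marker).
    print("torch cuDNN pins:", declared_pins("torch"))
```

If the torch pin reads 8.9.2.26 while the installed nvidia-cudnn-cu12 reports 9.5.1.17, that matches the resolver warning above.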
As a result, I cannot have working GPU installs of torch and jax at the same time: whichever cuDNN version is present, one of them fails.
When nvidia-cudnn-cu12-9.5.1.17 (the newer version) is installed, torch-2.1.1 crashes with the following error:
import torch
ImportError: libcudnn.so.8: cannot open shared object file: No such file or directory
When nvidia-cudnn-cu12-8.9.2.26 (the older version) is installed, jax crashes with this error:
import jax.numpy as jnp
x = jnp.ones((1000, 1000))
FAILED_PRECONDITION: DNN library initialization failed
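To reproduce both failures without one masking the other, each snippet can be run in a fresh interpreter. The following is a minimal sketch that reuses the two snippets from this issue. The only addition is the torch.cuda.is_available() call, which is an assumption added to confirm that the GPU is visible. PROBES and run_probe are hypothetical names.

```python
import subprocess
import sys

# Snippets from the issue. Each runs in its own child interpreter so a crash
# in one framework cannot hide the result of the other. The
# torch.cuda.is_available() call is an added check, not from the issue.
PROBES = {
    "torch": "import torch; print(torch.cuda.is_available())",
    "jax": "import jax.numpy as jnp; x = jnp.ones((1000, 1000))",
}


def run_probe(code, timeout=300):
    """Run code in a child Python process; return (succeeded, combined output)."""
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode == 0, (proc.stdout + proc.stderr).strip()


if __name__ == "__main__":
    for name, code in PROBES.items():
        ok, output = run_probe(code)
        last_line = output.splitlines()[-1] if output else ""
        print(f"{name}: {'OK' if ok else 'FAILED'} {last_line}")
```

Running this once with each cuDNN version installed should show the libcudnn.so.8 ImportError for torch in one case and the DNN initialization error for jax in the other.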
Approximately 30 test files use the torch package.
I am not sure how to run all the tests on the GPU instance, since these conflicts prevent torch/torchvision and jax[cuda12] from being installed at the same time.
OS: Ubuntu 22.04