[BUG] - TransformerModel does not work on multiple GPUs - darts [0.31 - 0.32] #2631
Description
Describe the bug
Training fails with: RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
I tried copy(), but it did not help. clone() does not seem to work inside darts either, unless I have done something wrong.
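For context, here is a minimal sketch of the general PyTorch situation this message refers to. It is independent of darts and does not identify where the overlap happens inside TransformerModel or the multi-GPU code path. The function names are hypothetical and chosen only for illustration. On recent PyTorch versions, copying one view of a tensor into another view that partially shares its memory is expected to raise this kind of RuntimeError, and cloning the source first avoids the overlap.

```python
import torch


def shift_right_in_place(x: torch.Tensor) -> str:
    """Try to copy x[:-1] into x[1:]; both views share the storage of x.

    Returns the error message if PyTorch rejects the operation,
    or an empty string if it does not.
    """
    try:
        x[1:].copy_(x[:-1])  # source and destination partially overlap
    except RuntimeError as err:
        return str(err)
    return ""


def shift_right_with_clone(x: torch.Tensor) -> torch.Tensor:
    """Same shift, but the source is cloned first so nothing overlaps."""
    x[1:].copy_(x[:-1].clone())
    return x


# Hypothetical illustration data, not taken from the attached script.
message = shift_right_in_place(torch.arange(5.0))
shifted = shift_right_with_clone(torch.arange(5.0))
print(message)
print(shifted)
```

If the same pattern occurs inside library code rather than user code, adding clone() in the calling script may not reach the tensor that actually overlaps, which could explain why it appears not to help here.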
To Reproduce
original code
src-TransformerModel-10-multigpus-3.txt
error_log.txt
System (please complete the following information):
- Python 3.11.10
- darts [0.31 - 0.32]
- NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7
-- NVIDIA H100 80GB HBM3
-- NVIDIA H100 80GB HBM3
PRETTY_NAME="Ubuntu 22.04.5 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.5 LTS (Jammy Jellyfish)"
Installed modules (pip output)
Requirement already satisfied: darts in /usr/local/lib/python3.11/dist-packages (0.32.0)
Requirement already satisfied: holidays>=0.11.1 in /usr/local/lib/python3.11/dist-packages (from darts) (0.63)
Requirement already satisfied: joblib>=0.16.0 in /usr/local/lib/python3.11/dist-packages (from darts) (1.4.2)
Requirement already satisfied: matplotlib>=3.3.0 in /usr/local/lib/python3.11/dist-packages (from darts) (3.10.0)
Requirement already satisfied: nfoursid>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from darts) (1.0.1)
Requirement already satisfied: numpy<2.0.0,>=1.19.0 in /usr/local/lib/python3.11/dist-packages (from darts) (1.26.3)
Requirement already satisfied: pandas>=1.0.5 in /usr/local/lib/python3.11/dist-packages (from darts) (2.2.3)
Requirement already satisfied: pmdarima>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2.0.4)
Requirement already satisfied: pyod>=0.9.5 in /usr/local/lib/python3.11/dist-packages (from darts) (2.0.3)
Requirement already satisfied: requests>=2.22.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2.32.3)
Requirement already satisfied: scikit-learn<1.6.0,>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from darts) (1.5.2)
Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.11/dist-packages (from darts) (1.14.1)
Requirement already satisfied: shap>=0.40.0 in /usr/local/lib/python3.11/dist-packages (from darts) (0.46.0)
Requirement already satisfied: statsforecast>=1.4 in /usr/local/lib/python3.11/dist-packages (from darts) (2.0.0)
Requirement already satisfied: statsmodels>=0.14.0 in /usr/local/lib/python3.11/dist-packages (from darts) (0.14.4)
Requirement already satisfied: tbats>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from darts) (1.1.3)
Requirement already satisfied: tqdm>=4.60.0 in /usr/local/lib/python3.11/dist-packages (from darts) (4.67.1)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.11/dist-packages (from darts) (4.9.0)
Requirement already satisfied: xarray>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2024.11.0)
Requirement already satisfied: xgboost>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2.1.3)
Requirement already satisfied: pytorch-lightning>=1.5.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2.5.0.post0)
Requirement already satisfied: tensorboardX>=2.1 in /usr/local/lib/python3.11/dist-packages (from darts) (2.6.2.2)
Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from darts) (2.4.1+cu124)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.11/dist-packages (from holidays>=0.11.1->darts) (2.9.0.post0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (1.3.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (4.55.3)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (1.4.7)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (24.1)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.3.0->darts) (10.2.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib>=3.3.0->darts) (2.4.7)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.5->darts) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.5->darts) (2024.2)
Requirement already satisfied: Cython!=0.29.18,!=0.29.31,>=0.29 in /usr/local/lib/python3.11/dist-packages (from pmdarima>=1.8.0->darts) (3.0.11)
Requirement already satisfied: urllib3 in /usr/local/lib/python3.11/dist-packages (from pmdarima>=1.8.0->darts) (2.2.3)
Requirement already satisfied: setuptools!=50.0.0,>=38.6.0 in /usr/local/lib/python3.11/dist-packages (from pmdarima>=1.8.0->darts) (75.1.0)
Requirement already satisfied: numba>=0.51 in /usr/local/lib/python3.11/dist-packages (from pyod>=0.9.5->darts) (0.60.0)
Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.5.0->darts) (6.0.2)
Requirement already satisfied: fsspec>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (2024.2.0)
Requirement already satisfied: torchmetrics>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.5.0->darts) (1.6.0)
Requirement already satisfied: lightning-utilities>=0.10.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.5.0->darts) (0.11.9)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.22.0->darts) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.22.0->darts) (3.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.22.0->darts) (2024.8.30)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<1.6.0,>=1.0.1->darts) (3.5.0)
Requirement already satisfied: slicer==0.0.8 in /usr/local/lib/python3.11/dist-packages (from shap>=0.40.0->darts) (0.0.8)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.11/dist-packages (from shap>=0.40.0->darts) (3.1.0)
Requirement already satisfied: coreforecast>=0.0.12 in /usr/local/lib/python3.11/dist-packages (from statsforecast>=1.4->darts) (0.0.15)
Requirement already satisfied: fugue>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from statsforecast>=1.4->darts) (0.9.1)
Requirement already satisfied: utilsforecast>=0.1.4 in /usr/local/lib/python3.11/dist-packages (from statsforecast>=1.4->darts) (0.2.10)
Requirement already satisfied: patsy>=0.5.6 in /usr/local/lib/python3.11/dist-packages (from statsmodels>=0.14.0->darts) (1.0.1)
Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.11/dist-packages (from tensorboardX>=2.1->darts) (5.29.2)
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (3.13.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (3.2.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (3.1.3)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.99)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.99)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.99)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (9.1.0.70)
Requirement already satisfied: nvidia-cublas-cu12==12.4.2.65 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.2.65)
Requirement already satisfied: nvidia-cufft-cu12==11.2.0.44 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (11.2.0.44)
Requirement already satisfied: nvidia-curand-cu12==10.3.5.119 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (10.3.5.119)
Requirement already satisfied: nvidia-cusolver-cu12==11.6.0.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (11.6.0.99)
Requirement already satisfied: nvidia-cusparse-cu12==12.3.0.142 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.3.0.142)
Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (2.20.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.99)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (12.4.99)
Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.8.0->darts) (3.0.0)
Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (3.11.11)
Requirement already satisfied: triad>=0.9.7 in /usr/local/lib/python3.11/dist-packages (from fugue>=0.8.1->statsforecast>=1.4->darts) (0.9.8)
Requirement already satisfied: adagio>=0.2.4 in /usr/local/lib/python3.11/dist-packages (from fugue>=0.8.1->statsforecast>=1.4->darts) (0.2.6)
Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.11/dist-packages (from numba>=0.51->pyod>=0.9.5->darts) (0.43.0)
Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil->holidays>=0.11.1->darts) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.8.0->darts) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.11/dist-packages (from sympy->torch>=1.8.0->darts) (1.3.0)
Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (2.4.4)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (1.3.2)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (24.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (1.5.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (6.1.0)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (0.2.1)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.5.0->darts) (1.18.3)
Requirement already satisfied: pyarrow>=6.0.1 in /usr/local/lib/python3.11/dist-packages (from triad>=0.9.7->fugue>=0.8.1->statsforecast>=1.4->darts) (18.1.0)
Requirement already satisfied: fs in /usr/local/lib/python3.11/dist-packages (from triad>=0.9.7->fugue>=0.8.1->statsforecast>=1.4->darts) (2.4.16)
Requirement already satisfied: appdirs~=1.4.3 in /usr/local/lib/python3.11/dist-packages (from fs->triad>=0.9.7->fugue>=0.8.1->statsforecast>=1.4->darts) (1.4.4)
Additional context
error_log.txt