
Cannot load a GPU-trained nn.Module on a CPU-only machine #113973

Closed
@amorehead

Description


🐛 Describe the bug

  • With PyTorch 2.1, I cannot train an nn.Module using CUDA on a GPU-enabled system and then load its checkpoints on a second, CPU-only machine that has a CPU-only PyTorch build installed.

  • The issue seems to be caused by calling .to() in any form (e.g., my_gpu_trained_model.to('cpu')) on an nn.Module that was originally trained on the GPU. Even if the module is already on the system's CPU device, a CUDA call still appears to be made, which raises AssertionError: Torch not compiled with CUDA enabled.

  • This issue was originally discovered from the perspective of PyTorch Lightning: Add support to saving.py for loading GPU-trained models on CPU-only machines Lightning-AI/pytorch-lightning#19024.

Versions (for the CPU-only machine)

PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: AlmaLinux release 8.8 (Sapphire Caracal) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-18)
Clang version: 15.0.7 (Red Hat 15.0.7-1.module_el8.8.0+3466+dfcbc058)
CMake version: version 3.20.2
Libc version: glibc-2.28

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-477.27.2.el8_8.x86_64-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 63
Model name: Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
Stepping: 2
CPU MHz: 3300.000
CPU max MHz: 3300.0000
CPU min MHz: 1200.0000
BogoMIPS: 5000.24
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 30720K
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] lovely-numpy==0.2.9
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.1.2
[pip3] torch==2.1.1
[pip3] torch-cluster==1.6.2
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.2
[pip3] torch-sparse==0.6.18
[pip3] torch-spline-conv==1.2.2
[pip3] torchaudio==2.1.1
[pip3] torchmetrics==1.2.0
[pip3] torchvision==0.16.1
[conda] blas 2.116 mkl conda-forge
[conda] blas-devel 3.9.0 16_linux64_mkl conda-forge
[conda] cpuonly 2.0 0 pytorch
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] liblapacke 3.9.0 16_linux64_mkl conda-forge
[conda] lovely-numpy 0.2.9 pypi_0 pypi
[conda] mkl 2022.1.0 h84fe81f_915 conda-forge
[conda] mkl-devel 2022.1.0 ha770c72_916 conda-forge
[conda] mkl-include 2022.1.0 h84fe81f_915 conda-forge
[conda] numpy 1.23.5 pypi_0 pypi
[conda] pytorch 2.1.1 py3.10_cpu_0 pytorch
[conda] pytorch-lightning 2.1.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torch-cluster 1.6.2 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi
[conda] torchaudio 2.1.1 py310_cpu pytorch
[conda] torchmetrics 1.2.0 pypi_0 pypi
[conda] torchvision 0.16.1 py310_cpu pytorch

Activity


Luonic commented on Nov 18, 2023


Try loading the checkpoint with checkpoint = torch.load(path_to_checkpoint, map_location=torch.device("cpu")) and then calling model.load_state_dict(checkpoint).
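A minimal sketch of this suggestion, assuming the checkpoint holds a plain state_dict and that the model class can be rebuilt on the CPU-only machine. `load_weights_on_cpu` is a hypothetical helper name, `nn.Linear(4, 2)` stands in for the real architecture, and an in-memory `io.BytesIO` buffer stands in for the checkpoint file so the snippet runs anywhere (on the real training machine, `trained` would live on CUDA). Because only tensors are serialized, no pickled Python attributes, such as a cached device, travel with the file.

```python
import io

import torch
from torch import nn


def load_weights_on_cpu(model: nn.Module, source) -> nn.Module:
    """Hypothetical helper: load a state_dict saved on any device into `model` on the CPU."""
    # map_location remaps every tensor storage in the file to the CPU, so a
    # checkpoint written from CUDA tensors can be read by a CPU-only build.
    state_dict = torch.load(source, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    return model


# Training machine: save only the tensors, not the pickled module.
trained = nn.Linear(4, 2)
buffer = io.BytesIO()  # stands in for path_to_checkpoint
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# CPU-only machine: rebuild the architecture, then load the weights into it.
restored = load_weights_on_cpu(nn.Linear(4, 2), buffer)
print(torch.equal(restored.weight, trained.weight))  # True
```

For a Lightning `.ckpt` file, the weights sit under the checkpoint's `"state_dict"` key, and the file also stores other objects, so `weights_only=True` may have to be dropped when reading such a file (only for files you trust).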


tringwald commented on Nov 19, 2023

Collaborator

This seems to be a bug in torchmetrics. When a metric is serialized, its m._device attribute is serialized too; for a metric that was moved to CUDA, this is a CUDA device. Even when loading with map_location='cpu', only the underlying tensors are moved, not the cached m._device attribute. When .to(...) is later called on the metric, the following code tries to construct a dummy tensor on m._device, which fails because CUDA isn't available.

https://github.com/Lightning-AI/torchmetrics/blob/894de4caeeae820f60f1871d75334873241e5633/src/torchmetrics/metric.py#L811-L813
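The mechanism can be checked without a GPU or torchmetrics. The sketch below uses a hypothetical `DeviceRecordingModule` that caches its device in a plain `_device` attribute, and sets it to `cuda:0` by hand to mimic a module pickled on a GPU machine. `map_location` in `torch.load` only remaps tensor storages; ordinary attributes are restored exactly as they were pickled. Note that newer PyTorch releases default to `weights_only=True`, so unpickling a whole module needs `weights_only=False` passed explicitly.

```python
import io

import torch
from torch import nn


class DeviceRecordingModule(nn.Module):
    """Hypothetical module that caches its device in a plain attribute."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.zeros(()))
        # Mimic a module pickled on a GPU machine. Creating a torch.device object
        # does not initialize CUDA, so this also runs on a CPU-only build.
        self._device = torch.device("cuda", 0)


buffer = io.BytesIO()  # stands in for 'test.pth'
torch.save(DeviceRecordingModule(), buffer)
buffer.seek(0)

# Unpickling a whole nn.Module requires weights_only=False (trusted files only).
loaded = torch.load(buffer, map_location="cpu", weights_only=False)
print(loaded.total.device)  # cpu: tensor storages are remapped
print(loaded._device)       # cuda:0: the plain attribute is restored verbatim
```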

A possible fix would be to wrap the call in a try/except block, or to construct the dummy tensor on the CPU device by default. The device-tracking logic should still work in that case.
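A sketch of the try/except variant, written against a hypothetical `DeviceTrackingMetric` rather than the actual torchmetrics source; only the dummy-tensor probe mirrors the pattern described above. The override of `nn.Module._apply` first tries the cached device and falls back to the CPU only when that device cannot be used, which is the situation after loading a CUDA-pickled module with `map_location='cpu'` on a CPU-only build.

```python
import torch
from torch import nn


class DeviceTrackingMetric(nn.Module):
    """Hypothetical metric-like module that caches its device, patched with try/except."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.zeros(()))
        self._device = torch.device("cpu")

    @property
    def device(self) -> torch.device:
        return self._device

    def _apply(self, fn, *args, **kwargs):
        this = super()._apply(fn, *args, **kwargs)  # move/convert the real tensors
        try:
            # Probe the conversion with a dummy tensor on the cached device.
            dummy = torch.zeros(1, device=self._device)
        except (AssertionError, RuntimeError):
            # The cached device is unusable here (e.g. CUDA on a CPU-only build),
            # so the real tensors were most likely mapped to the CPU on load.
            dummy = torch.zeros(1)
        self._device = fn(dummy).device
        return this


# Simulate an unpickled metric whose cached device is stale.
metric = DeviceTrackingMetric()
metric._device = torch.device("cuda", 0)
metric.to("cpu")
print(metric.device)  # cpu
```

One design consideration: dtype-only conversions such as `.half()` pass the dummy through without moving it, so a probe that is always built on the CPU would report `cpu` even for a module whose tensors live on a GPU. Trying the cached device first avoids that, though this is worth confirming against the real torchmetrics code.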


tringwald commented on Nov 19, 2023

Collaborator

Short reproducer:

>>> # On CUDA-enabled torch
>>> import torch
>>> import torchmetrics
>>> metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5).cuda()
>>> torch.save(metric, 'test.pth')
>>> # On CPU-only torch
>>> import torch
>>> m = torch.load('test.pth', map_location='cpu')
>>> m._device
device(type='cuda', index=0)
>>> m.to('cpu')
[...]
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
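For checkpoints that were already saved this way, one stopgap is to overwrite the stale attribute right after loading and before calling `.to(...)`. The sketch below assumes the metrics keep their device in the private `_device` attribute shown in the reproducer; `reset_cached_devices` is a hypothetical helper and `FakeMetric` stands in for a torchmetrics metric. If a CUDA machine is still available, a cleaner route is to load the checkpoint there, call `.cpu()` (which updates `_device` through the normal code path), and save it again.

```python
import torch
from torch import nn


def reset_cached_devices(root: nn.Module, device: torch.device = torch.device("cpu")) -> int:
    """Overwrite a cached `_device` on every submodule that has one; return how many.

    Hypothetical helper that touches a private attribute, so treat it as a stopgap.
    """
    count = 0
    for module in root.modules():  # `root` itself plus all nested submodules
        if isinstance(getattr(module, "_device", None), torch.device):
            module._device = device
            count += 1
    return count


class FakeMetric(nn.Module):
    """Stand-in for an unpickled metric whose cached device still says CUDA."""

    def __init__(self):
        super().__init__()
        self._device = torch.device("cuda", 0)


# In practice: model = torch.load(path, map_location="cpu", weights_only=False)
model = nn.Sequential(nn.Linear(2, 2), FakeMetric())
print(reset_cached_devices(model))  # 1
print(model[1]._device)             # cpu
# With a real metric, the dummy-tensor probe in .to(...) would now target the CPU.
model.to("cpu")
```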

amorehead commented on Nov 19, 2023

Contributor, Author

Thanks, @tringwald, for describing this so precisely. Yes, this is exactly what I am experiencing: my checkpoints appear to contain torchmetrics instances whose cached devices are not mapped to the CPU.


jbschlosser commented on Nov 20, 2023

Contributor

Closing this as it is a bug within torchmetrics. Feel free to reopen if there is a legitimate PyTorch-side issue to solve.


Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
Participants: @amorehead @Luonic @jbschlosser @tringwald


          Cannot load a GPU-trained `nn.Module` on a CPU-only machine · Issue #113973 · pytorch/pytorch