🐛 Bug
`torchmetrics` seems unable to handle a metric that was serialized on a CUDA-enabled installation and is later loaded on a CPU-only installation: calling `metric.to('cpu')` raises an error (see below). This happens even with `torch.load(..., map_location='cpu')`.
```python
# On CUDA-enabled torch
import torch
import torchmetrics

metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5).cuda()
print(metric._device)
# device(type='cuda', index=0)
torch.save(metric, 'test.pth')
```

```python
# On CPU-only torch
import torch

m = torch.load('test.pth', map_location='cpu')
print(m._device)
# device(type='cuda', index=0)
m.to('cpu')
# [...]
#     raise AssertionError("Torch not compiled with CUDA enabled")
# AssertionError: Torch not compiled with CUDA enabled
```
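The root cause can be sketched without torchmetrics: `map_location` remaps tensor storages during unpickling, but a plain `torch.device` attribute stored on the object is restored verbatim. A minimal sketch, where `CachedDeviceModule` is a hypothetical stand-in for `Metric`:

```python
import io
import torch

class CachedDeviceModule(torch.nn.Module):
    """Hypothetical stand-in for Metric: caches its device as an attribute."""

    def __init__(self):
        super().__init__()
        # Cached at save time; a torch.device object is just a descriptor,
        # so constructing it does not require CUDA.
        self._device = torch.device('cuda', 0)

buf = io.BytesIO()
torch.save(CachedDeviceModule(), buf)
buf.seek(0)
# weights_only=False is needed on newer torch, where it defaults to True.
m = torch.load(buf, map_location='cpu', weights_only=False)
print(m._device)  # still device(type='cuda', index=0): map_location did not touch it
```

Any tensor storages in the checkpoint would have been remapped to CPU, but the cached `_device` attribute survives unchanged.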
The problematic code is in `src/torchmetrics/metric.py`, lines 811 to 813 (at commit `894de4c`).
After loading, `self.device` will still refer to the original CUDA device, since that is what was serialized:

```python
>>> m = torch.load('test.pth', map_location='cpu')
>>> m._device
device(type='cuda', index=0)
```

Constructing a tensor with `torch.zeros(1, device=self.device)` then errors out because CUDA is not available.
To Reproduce
See above.
Expected behavior
Calling `metric.to('cpu')` on a metric whose tensors have been loaded onto the CPU should not raise an error.
Environment
torchmetrics==1.2.0
Additional context
Originally reported here: