
Error when calling .to('cpu') on a CUDA metric loaded on a CPU-only machine #2223

Closed

Reported by @tringwald

🐛 Bug

torchmetrics cannot handle a metric that was serialized on a CUDA-enabled installation when it is later loaded on a CPU-only installation and moved with metric.to('cpu') (see below). The error occurs even when loading with torch.load(..., map_location='cpu').

# On CUDA-enabled torch
import torch
import torchmetrics

metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5).cuda()
print(metric._device)
# device(type='cuda', index=0)
torch.save(metric, 'test.pth')

# On CPU-only torch (fresh interpreter)
import torch

m = torch.load('test.pth', map_location='cpu')
print(m._device)
# device(type='cuda', index=0)
m.to('cpu')
# [...]
#     raise AssertionError("Torch not compiled with CUDA enabled")
# AssertionError: Torch not compiled with CUDA enabled

The problematic code is here:

# make sure to update the device attribute
# if the dummy tensor moves device by fn function we should also update the attribute
self._device = fn(torch.zeros(1, device=self.device)).device

After loading, self.device still refers to the original CUDA device, because that is what was serialized:

>>> m = torch.load('test.pth', map_location='cpu')
>>> m._device
device(type='cuda', index=0)

Constructing a tensor with torch.zeros(1, device=self.device) will then error out if CUDA is not available.
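
A defensive variant of that update could fall back to the CPU when the recorded device is a CUDA device but the running torch build has no CUDA support. This is only a sketch of the idea, not the actual torchmetrics code:

# Sketch only: if the stored device cannot be used in this build, probe
# on CPU instead so that fn() can still report the target device.
probe_device = self.device
if probe_device.type == "cuda" and not torch.cuda.is_available():
    probe_device = torch.device("cpu")
self._device = fn(torch.zeros(1, device=probe_device)).device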

To Reproduce

See the code snippet above.

Expected behavior

Calling metric.to('cpu') on a metric whose tensors already live on the CPU should not throw an error, even if the serialized _device attribute still points to a CUDA device.
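
Until that is the case, a possible user-side workaround (my own assumption, and it relies on the private _device attribute) is to reset the stale attribute right after loading, so the probe tensor is created on an available device:

import torch

m = torch.load('test.pth', map_location='cpu')
# Overwrite the serialized device attribute, which still points at CUDA,
# before moving the metric; torch.zeros(1, device=m._device) then succeeds.
m._device = torch.device('cpu')
m.to('cpu')  # no longer raises "Torch not compiled with CUDA enabled"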

Environment

torchmetrics==1.2.0

Additional context

Originally reported here:
