Skip to content

Commit 019a502

Browse files
committed
fix: update tensor cloning logic to ensure CPU snapshots are created correctly
1 parent bded2bb commit 019a502

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,6 @@ def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
105105
``clone()`` is required to break storage sharing.
106106
107107
"""
108-
if t.is_cuda:
109-
return t.detach().cpu()
110-
return t.detach().clone()
108+
if t.is_cpu:
109+
return t.detach().clone()
110+
return t.detach().cpu()

tests/tests_pytorch/plugins/test_async_checkpoint.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
88
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
9+
from tests_pytorch.helpers.runif import RunIf
910

1011

1112
class _CaptureCheckpointIO(CheckpointIO):
@@ -53,20 +54,22 @@ def test_async_checkpoint_should_snapshot_values_before_mutation():
5354
)
5455

5556

56-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
57-
def test_async_checkpoint_clones_tensors_to_cpu():
57+
@RunIf(min_cuda_gpus=1)
58+
@pytest.mark.parametrize(("device"), ["cpu", "cuda:0"])
59+
def test_async_checkpoint_clones_tensors_to_cpu(device):
5860
"""Verify that _clone_tensor produces a CPU snapshot that does not share storage."""
5961
from lightning.pytorch.plugins.io.async_plugin import _clone_tensor
6062

61-
t = torch.tensor([1.0, 2.0, 3.0])
63+
t = torch.tensor([1.0, 2.0, 3.0], device=device)
6264
cloned = _clone_tensor(t)
6365

6466
# cloned tensor should be on CPU
6567
assert cloned.device == torch.device("cpu"), f"Expected CPU tensor, got {cloned.device}"
6668
# values should match
67-
assert torch.equal(cloned, t)
69+
assert torch.equal(cloned, t.cpu())
6870
# cloned tensor should not share storage with the original
6971
assert cloned.data_ptr() != t.data_ptr()
7072
# mutation of the original must not affect the clone
7173
t.add_(1.0)
7274
assert torch.equal(cloned, torch.tensor([1.0, 2.0, 3.0]))
75+
assert t.device == torch.device(device), f"Original tensor should remain on {device}, got {t.device}"

0 commit comments

Comments
 (0)