|
6 | 6 |
|
7 | 7 | from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO |
8 | 8 | from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO |
| 9 | +from tests_pytorch.helpers.runif import RunIf |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class _CaptureCheckpointIO(CheckpointIO): |
@@ -53,20 +54,22 @@ def test_async_checkpoint_should_snapshot_values_before_mutation(): |
53 | 54 | ) |
54 | 55 |
|
55 | 56 |
|
56 | | -@pytest.mark.filterwarnings("ignore::DeprecationWarning") |
57 | | -def test_async_checkpoint_clones_tensors_to_cpu(): |
| 57 | +@RunIf(min_cuda_gpus=1) |
| 58 | +@pytest.mark.parametrize(("device"), ["cpu", "cuda:0"]) |
| 59 | +def test_async_checkpoint_clones_tensors_to_cpu(device): |
58 | 60 | """Verify that _clone_tensor produces a CPU snapshot that does not share storage.""" |
59 | 61 | from lightning.pytorch.plugins.io.async_plugin import _clone_tensor |
60 | 62 |
|
61 | | - t = torch.tensor([1.0, 2.0, 3.0]) |
| 63 | + t = torch.tensor([1.0, 2.0, 3.0], device=device) |
62 | 64 | cloned = _clone_tensor(t) |
63 | 65 |
|
64 | 66 | # cloned tensor should be on CPU |
65 | 67 | assert cloned.device == torch.device("cpu"), f"Expected CPU tensor, got {cloned.device}" |
66 | 68 | # values should match |
67 | | - assert torch.equal(cloned, t) |
| 69 | + assert torch.equal(cloned, t.cpu()) |
68 | 70 | # cloned tensor should not share storage with the original |
69 | 71 | assert cloned.data_ptr() != t.data_ptr() |
70 | 72 | # mutation of the original must not affect the clone |
71 | 73 | t.add_(1.0) |
72 | 74 | assert torch.equal(cloned, torch.tensor([1.0, 2.0, 3.0])) |
| 75 | + assert t.device == torch.device(device), f"Original tensor should remain on {device}, got {t.device}" |
0 commit comments