|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +from unittest import mock |
15 | 16 |
|
16 | 17 | import fsspec |
| 18 | +import torch |
17 | 19 | from fsspec.implementations.local import LocalFileSystem |
18 | 20 | from fsspec.spec import AbstractFileSystem |
19 | 21 |
|
20 | | -from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem |
| 22 | +from lightning.fabric.utilities.cloud_io import _atomic_save, _is_dir, get_filesystem |
21 | 23 |
|
22 | 24 |
|
23 | 25 | def test_get_filesystem_custom_filesystem(): |
@@ -90,3 +92,61 @@ def isfile(self, path): |
90 | 92 | assert _is_dir(get_filesystem(s3_directory), s3_directory, strict=True) |
91 | 93 | assert not _is_dir(get_filesystem(s3_directory), s3_file) |
92 | 94 | assert not _is_dir(get_filesystem(s3_directory), s3_file, strict=True) |
| 95 | + |
| 96 | + |
| 97 | +def test_atomic_save_uses_pipe_for_s3(tmp_path): |
| 98 | + """Test that _atomic_save uses fs.pipe() for S3 filesystems.""" |
| 99 | + checkpoint = {"key": torch.tensor([1, 2, 3])} |
| 100 | + filepath = "s3://bucket/checkpoint.ckpt" |
| 101 | + |
| 102 | + mock_fs = mock.MagicMock() |
| 103 | + mock_fs.__class__.__name__ = "S3FileSystem" |
| 104 | + |
| 105 | + with ( |
| 106 | + mock.patch("lightning.fabric.utilities.cloud_io._is_object_storage", return_value=True), |
| 107 | + mock.patch("fsspec.core.url_to_fs", return_value=(mock_fs, "bucket/checkpoint.ckpt")), |
| 108 | + ): |
| 109 | + _atomic_save(checkpoint, filepath) |
| 110 | + |
| 111 | + mock_fs.pipe.assert_called_once() |
| 112 | + mock_fs.open.assert_not_called() |
| 113 | + |
| 114 | + |
| 115 | +def test_atomic_save_uses_write_for_azure(tmp_path): |
| 116 | + """Test that _atomic_save uses f.write() for Azure filesystems.""" |
| 117 | + import sys |
| 118 | + import types |
| 119 | + |
| 120 | + checkpoint = {"key": torch.tensor([1, 2, 3])} |
| 121 | + filepath = "azure://container/checkpoint.ckpt" |
| 122 | + |
| 123 | + # Create a fake adlfs module so isinstance check works |
| 124 | + AzureBlobFileSystem = type("AzureBlobFileSystem", (), {}) |
| 125 | + fake_adlfs = types.ModuleType("adlfs") |
| 126 | + fake_adlfs.AzureBlobFileSystem = AzureBlobFileSystem |
| 127 | + |
| 128 | + mock_fs = mock.MagicMock() |
| 129 | + mock_fs.__class__ = AzureBlobFileSystem |
| 130 | + |
| 131 | + with ( |
| 132 | + mock.patch.dict(sys.modules, {"adlfs": fake_adlfs}), |
| 133 | + mock.patch("lightning.fabric.utilities.cloud_io.module_available", return_value=True), |
| 134 | + mock.patch("lightning.fabric.utilities.cloud_io._is_object_storage", return_value=True), |
| 135 | + mock.patch("fsspec.core.url_to_fs", return_value=(mock_fs, "container/checkpoint.ckpt")), |
| 136 | + ): |
| 137 | + _atomic_save(checkpoint, filepath) |
| 138 | + |
| 139 | + mock_fs.pipe.assert_not_called() |
| 140 | + mock_fs.open.assert_called_once() |
| 141 | + |
| 142 | + |
| 143 | +def test_atomic_save_uses_write_for_local(tmp_path): |
| 144 | + """Test that _atomic_save uses f.write() for local filesystems.""" |
| 145 | + checkpoint = {"key": torch.tensor([1, 2, 3])} |
| 146 | + filepath = tmp_path / "checkpoint.ckpt" |
| 147 | + |
| 148 | + _atomic_save(checkpoint, filepath) |
| 149 | + |
| 150 | + assert filepath.exists() |
| 151 | + loaded = torch.load(filepath, weights_only=True) |
| 152 | + torch.testing.assert_close(loaded["key"], checkpoint["key"]) |
0 commit comments