Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `_atomic_save` staging local checkpoints through `$TMPDIR`, which caused "no space left on device" on SLURM/HPC setups with a small `/tmp` ([#21744](https://github.com/Lightning-AI/pytorch-lightning/pull/21744))

---

Expand Down
35 changes: 35 additions & 0 deletions src/lightning/fabric/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.
"""Utilities related to data saving/loading."""

import contextlib
import errno
import io
import logging
import os
import tempfile
from pathlib import Path
from typing import IO, Any, Optional, Union

Expand Down Expand Up @@ -97,6 +100,38 @@ def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None:
log.debug(f"Saving checkpoint: {filepath}")
torch.save(checkpoint, bytesbuffer)

if _is_local_file_protocol(filepath):
# Stage next to the destination so the temp file shares a filesystem with the target.
# fsspec's LocalFileSystem.transaction stages via tempfile.mkstemp() with no dir= argument,
# which puts the temp file in $TMPDIR and fails on setups where $TMPDIR is on a small
# partition (e.g. SLURM clusters). See https://github.com/Lightning-AI/pytorch-lightning/issues/21253.
target = os.fspath(filepath)
parent = os.path.dirname(target) or "."
fd, staging = tempfile.mkstemp(dir=parent, prefix=os.path.basename(target) + ".", suffix=".tmp")
try:
with os.fdopen(fd, "wb") as f:
f.write(bytesbuffer.getvalue())
# Flush contents to disk before rename so a crash can't leave a renamed-but-empty file.
f.flush()
os.fsync(f.fileno())
os.replace(staging, target) # atomic on same filesystem
except BaseException:
with contextlib.suppress(FileNotFoundError):
os.unlink(staging)
raise
# Best-effort: fsync the parent directory so the rename itself is durable.
# Not supported on every platform (e.g. Windows) — failures are non-fatal because the
# file contents are already on disk above.
try:
dir_fd = os.open(parent, os.O_RDONLY)
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)
except OSError:
pass
return

try:
# We use a transaction here to avoid file corruption if the save gets interrupted
fs, urlpath = fsspec.core.url_to_fs(str(filepath))
Expand Down
97 changes: 97 additions & 0 deletions tests/tests_fabric/utilities/test_cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from pathlib import Path
from unittest import mock

import fsspec
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from fsspec.spec import AbstractFileSystem
Expand Down Expand Up @@ -150,3 +153,97 @@ def test_atomic_save_uses_write_for_local(tmp_path):
assert filepath.exists()
loaded = torch.load(filepath, weights_only=True)
torch.testing.assert_close(loaded["key"], checkpoint["key"])


def test_atomic_save_local_stages_next_to_destination(tmp_path, monkeypatch):
"""Regression for #21253: local checkpoints must stage in the destination dir, not $TMPDIR.

Setups with a small $TMPDIR (SLURM, HPC clusters) hit "no space left on device" when a
large checkpoint is staged there. Stage next to the destination instead. Detected by
instrumenting ``tempfile.mkstemp`` — the staging directory is observable mid-write, not
after, because a successful transaction cleans the staging file up on commit. Paths are
resolved before comparison so this doesn't false-fail on macOS where /tmp is a symlink to
/private/tmp.

"""
sentinel_tmpdir = (tmp_path / "sentinel_tmpdir").resolve()
dest_dir = (tmp_path / "destination").resolve()
sentinel_tmpdir.mkdir()
dest_dir.mkdir()

monkeypatch.setenv("TMPDIR", str(sentinel_tmpdir))
monkeypatch.setattr(tempfile, "tempdir", str(sentinel_tmpdir))

real_mkstemp = tempfile.mkstemp
staged_parents = []

def traced_mkstemp(*args, **kwargs):
fd, name = real_mkstemp(*args, **kwargs)
staged_parents.append(Path(name).resolve().parent)
return fd, name

monkeypatch.setattr(tempfile, "mkstemp", traced_mkstemp)

filepath = dest_dir / "checkpoint.ckpt"
_atomic_save({"key": torch.tensor([1, 2, 3])}, filepath)

assert filepath.exists()
assert sentinel_tmpdir not in staged_parents, f"mkstemp staged through $TMPDIR: {staged_parents}"
assert dest_dir in staged_parents, f"mkstemp did not stage in destination dir: {staged_parents}"


def test_atomic_save_local_cleans_up_staging_on_failure(tmp_path):
"""If the rename fails, the staging file must not leak in the destination dir.

Patches os.replace to fail so we can observe what happens after a successful write but a failed rename — the
destination should not exist and the destination dir should be empty (i.e. no staging file under any naming
convention).

"""
filepath = tmp_path / "checkpoint.ckpt"

with (
mock.patch("lightning.fabric.utilities.cloud_io.os.replace", side_effect=OSError("boom")),
pytest.raises(OSError, match="boom"),
):
_atomic_save({"key": torch.tensor([1, 2, 3])}, filepath)

assert not filepath.exists()
assert list(tmp_path.iterdir()) == [], f"unexpected files left after failed save: {list(tmp_path.iterdir())}"


def test_atomic_save_local_preserves_existing_on_failure(tmp_path):
"""The atomicity guarantee: a failed save must not corrupt or destroy a prior checkpoint.

The most important property of _atomic_save. Writes a baseline checkpoint, then attempts a
save that fails at the rename step, then asserts the original bytes are still on disk
untouched.
"""
filepath = tmp_path / "checkpoint.ckpt"
_atomic_save({"key": torch.tensor([0, 0, 0])}, filepath)
original_bytes = filepath.read_bytes()

with (
mock.patch("lightning.fabric.utilities.cloud_io.os.replace", side_effect=OSError("boom")),
pytest.raises(OSError, match="boom"),
):
_atomic_save({"key": torch.tensor([42, 42, 42])}, filepath)

assert filepath.read_bytes() == original_bytes
assert [p for p in tmp_path.iterdir() if p != filepath] == []


def test_atomic_save_local_missing_parent_raises(tmp_path):
"""Parent directories are not auto-created — locks in current behavior.

Lightning's checkpoint code creates dirs upstream (ModelCheckpoint.setup); a future refactor silently creating them
here would mask caller bugs.

"""
filepath = tmp_path / "missing" / "checkpoint.ckpt"

with pytest.raises(FileNotFoundError):
_atomic_save({"key": torch.tensor([1, 2, 3])}, filepath)

assert not filepath.exists()
assert not filepath.parent.exists()
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import operator
import os
import re
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch
from lightning_utilities.core.imports import compare_version

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
Expand Down Expand Up @@ -109,31 +105,6 @@ def test_hpc_max_ckpt_version(tmp_path):
)


def test_local_cross_device_checkpoint(tmpdir):
"""Test that the _CheckpointConnector can write local cross-device files or raises an error if fsspec<2025.5.0."""
model = BoringModel()
# hardcoding dir since `tmp_path` can be windows path
trainer = Trainer(
default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1
)
trainer.fit(model)
# Simulate the behavior of fsspec when writing to a local file system but other device.
with (
mock.patch("os.rename", side_effect=OSError(errno.EXDEV, "Invalid cross-device link")),
mock.patch("os.chmod", side_effect=PermissionError("Operation not permitted")),
):
if compare_version("fsspec", operator.lt, "2025.5.0"):
with pytest.raises(
RuntimeError,
match=re.escape(
'Upgrade fsspec to enable cross-device local checkpoints: pip install "fsspec[http]>=2025.5.0"'
),
):
trainer.save_checkpoint(tmpdir + "/test_ckpt_for_fsspec/hpc_ckpt.ckpt")
else:
trainer.save_checkpoint(tmpdir + "/test_ckpt_for_fsspec/hpc_ckpt.ckpt")


def test_ckpt_for_fsspec():
"""Test that the _CheckpointConnector is able to write to fsspec file systems."""
model = BoringModel()
Expand Down
Loading