Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `_validate_checkpoint_directory` in the DeepSpeed strategy failing for remote filesystem URIs (S3, GCS, HDFS) by replacing `pathlib.Path` with fsspec-based checks ([#21636](https://github.com/Lightning-AI/pytorch-lightning/pull/21636))

--

Expand Down
27 changes: 17 additions & 10 deletions src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import os
import platform
import posixpath
from collections.abc import Mapping
from contextlib import AbstractContextManager, ExitStack
from datetime import timedelta
Expand All @@ -36,6 +37,7 @@
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.distributed import log
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
from lightning.fabric.utilities.load import _move_state_into
Expand All @@ -45,6 +47,7 @@

if TYPE_CHECKING:
from deepspeed import DeepSpeedEngine
from fsspec import AbstractFileSystem
from torch.optim.lr_scheduler import _LRScheduler

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
Expand Down Expand Up @@ -885,9 +888,9 @@ def _validate_device_index_selection(parallel_devices: list[torch.device]) -> No
)


def _is_deepspeed_checkpoint(path: str, fs: "AbstractFileSystem") -> bool:
    """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory.

    A path qualifies when it is a directory that directly contains a ``checkpoint`` subdirectory.

    Args:
        path: Local path or remote URI (e.g. ``s3://bucket/ckpt``). It is handed to ``fs`` unchanged.
        fs: Filesystem object whose ``isdir`` method answers both lookups.

    Returns:
        ``True`` if both ``path`` and its ``checkpoint`` child are directories according to ``fs``.

    """
    if not fs.isdir(path):
        return False
    # ``posixpath.join`` inserts a separator only when one is needed: a trailing slash is not doubled and an empty
    # (relative) parent yields ``checkpoint`` instead of the unrelated root-level ``/checkpoint``. Unlike
    # ``rstrip('/')`` it also never strips the slashes of a bare scheme such as ``s3://``.
    return fs.isdir(posixpath.join(path, "checkpoint"))


def _validate_checkpoint_directory(path: _PATH) -> None:
Expand All @@ -903,24 +906,28 @@ def _validate_checkpoint_directory(path: _PATH) -> None:
# ├── latest
# └── zero_to_fp32.py

path = Path(path)
path_is_ds_checkpoint = _is_deepspeed_checkpoint(path)
default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}"
path_str = str(path)
fs = get_filesystem(path_str)
path_is_ds_checkpoint = _is_deepspeed_checkpoint(path_str, fs)
default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path_str}"

if not path_is_ds_checkpoint:
# Case 1: User may have accidentally passed the subfolder "checkpoint"
parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent)
parent_is_ds_checkpoint = _is_deepspeed_checkpoint(posixpath.dirname(path_str), fs)
if parent_is_ds_checkpoint:
raise FileNotFoundError(
f"{default_message}. It looks like you passed the path to a subfolder."
f" Try to load using this parent directory instead: {path.parent}"
f" Try to load using this parent directory instead: {posixpath.dirname(path_str)}"
)
# Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder
parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent)
parent_parent_is_ds_checkpoint = fs.isfile(path_str) and _is_deepspeed_checkpoint(
posixpath.dirname(posixpath.dirname(path_str)), fs
)
if parent_parent_is_ds_checkpoint:
raise FileNotFoundError(
f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder."
f" Try to load using this parent directory instead: {path.parent.parent}"
f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed"
f" checkpoint folder."
f" Try to load using this parent directory instead: {posixpath.dirname(posixpath.dirname(path_str))}"
)
raise FileNotFoundError(default_message)

Expand Down
52 changes: 52 additions & 0 deletions tests/tests_fabric/strategies/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,58 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path):
strategy.load_checkpoint(path=checkpoint_path, state={"model": Mock()})


def _make_s3_mock_fs(dirs, files=()):
    """Create a mock fsspec filesystem for S3-like remote URI tests.

    Args:
        dirs: Iterable of paths for which ``isdir`` returns ``True``.
        files: Iterable of paths for which ``isfile`` returns ``True``.

    Returns:
        A ``Mock`` that exposes only ``isdir`` and ``isfile``; both answer by exact string membership.

    """
    # Snapshot the inputs so any iterable works and later mutation by the caller cannot change the answers
    known_dirs = frozenset(dirs)
    known_files = frozenset(files)

    # Restrict the mock to the two methods under test: an unrestricted ``Mock`` would answer any other call
    # (e.g. ``fs.exists(...)``) with a truthy mock object and silently hide a wrong filesystem call.
    fs = Mock(spec=["isdir", "isfile"])
    fs.isdir.side_effect = lambda p: p in known_dirs
    fs.isfile.side_effect = lambda p: p in known_files
    return fs


def test_validate_checkpoint_directory_remote_uri():
    """A valid DeepSpeed checkpoint addressed by a remote filesystem URI (e.g., S3, HDFS) passes validation."""
    from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory

    checkpoint_uri = "s3://bucket/ckpt"
    remote_fs = _make_s3_mock_fs(dirs={checkpoint_uri, "s3://bucket/ckpt/checkpoint"})

    with mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=remote_fs):
        # A valid remote checkpoint directory must be accepted without raising
        _validate_checkpoint_directory(checkpoint_uri)

    # The filesystem must have been queried with the untouched URIs (``s3://`` must not collapse to ``s3:/``)
    for queried in (checkpoint_uri, "s3://bucket/ckpt/checkpoint"):
        remote_fs.isdir.assert_any_call(queried)


def test_validate_checkpoint_directory_remote_uri_subfolder_suggestion():
    """Test that the subfolder suggestion works with remote URIs."""
    from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory

    mock_fs = _make_s3_mock_fs(dirs={"s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"})

    # ``match`` is applied with ``re.search``: without the trailing ``$`` the pattern would also accept a wrong
    # suggestion such as ``s3://bucket/ckpt/checkpoint``. The leading phrase pins the "subfolder" branch.
    expected_message = (
        r"passed the path to a subfolder\. Try to load using this parent directory instead: s3://bucket/ckpt$"
    )

    with (
        mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs),
        pytest.raises(FileNotFoundError, match=expected_message),
    ):
        _validate_checkpoint_directory("s3://bucket/ckpt/checkpoint")


def test_validate_checkpoint_directory_remote_uri_file_inside_checkpoint():
    """Test that the file-inside-checkpoint suggestion works with remote URIs."""
    from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory

    state_file = "s3://bucket/ckpt/checkpoint/zero_pp_rank_0_mp_rank_00_model_states.pt"
    mock_fs = _make_s3_mock_fs(
        dirs={"s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"},
        files={state_file},
    )

    # ``match`` is applied with ``re.search``: without the trailing ``$`` the pattern would also accept the
    # direct parent ``s3://bucket/ckpt/checkpoint``. The leading phrase pins the "file inside checkpoint" branch.
    expected_message = (
        r"file inside a DeepSpeed checkpoint folder\."
        r" Try to load using this parent directory instead: s3://bucket/ckpt$"
    )

    with (
        mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs),
        pytest.raises(FileNotFoundError, match=expected_message),
    ):
        _validate_checkpoint_directory(state_file)


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""
Expand Down
Loading