diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 2c68c4102c48d..72db974c511b8 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed `_validate_checkpoint_directory` in DeepSpeed strategy failing for remote filesystem URIs (S3, GCS, HDFS) by replacing `pathlib.Path` with fsspec-based checks ([#21636](https://github.com/Lightning-AI/pytorch-lightning/pull/21636)) -- diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 883546fea1f2d..40b344bf92dcc 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,6 +16,7 @@ import logging import os import platform +import posixpath from collections.abc import Mapping from contextlib import AbstractContextManager, ExitStack from datetime import timedelta @@ -36,6 +37,7 @@ from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded +from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.distributed import log from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.load import _move_state_into @@ -45,6 +47,7 @@ if TYPE_CHECKING: from deepspeed import DeepSpeedEngine + from fsspec import AbstractFileSystem from torch.optim.lr_scheduler import _LRScheduler _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") @@ -885,9 +888,9 @@ def _validate_device_index_selection(parallel_devices: list[torch.device]) -> No ) -def _is_deepspeed_checkpoint(path: Path) -> bool: +def _is_deepspeed_checkpoint(path: str, fs: "AbstractFileSystem") -> bool: """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory.""" - return path.is_dir() and (path / "checkpoint").is_dir() + return fs.isdir(path) and fs.isdir(f"{path.rstrip('/')}/checkpoint") def _validate_checkpoint_directory(path: _PATH) -> None: @@ -903,24 +906,28 @@ def _validate_checkpoint_directory(path: _PATH) -> None: # ├── latest # └── zero_to_fp32.py - path = Path(path) - path_is_ds_checkpoint = _is_deepspeed_checkpoint(path) - default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}" + path_str = str(path) + fs = get_filesystem(path_str) + path_is_ds_checkpoint = _is_deepspeed_checkpoint(path_str, fs) + default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path_str}" if not path_is_ds_checkpoint: # Case 1: User may have accidentally passed the subfolder "checkpoint" - parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent) + parent_is_ds_checkpoint = _is_deepspeed_checkpoint(posixpath.dirname(path_str), fs) if parent_is_ds_checkpoint: raise FileNotFoundError( f"{default_message}. It looks like you passed the path to a subfolder." - f" Try to load using this parent directory instead: {path.parent}" + f" Try to load using this parent directory instead: {posixpath.dirname(path_str)}" ) # Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder - parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent) + parent_parent_is_ds_checkpoint = fs.isfile(path_str) and _is_deepspeed_checkpoint( + posixpath.dirname(posixpath.dirname(path_str)), fs + ) if parent_parent_is_ds_checkpoint: raise FileNotFoundError( - f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder." - f" Try to load using this parent directory instead: {path.parent.parent}" + f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed" + f" checkpoint folder." + f" Try to load using this parent directory instead: {posixpath.dirname(posixpath.dirname(path_str))}" ) raise FileNotFoundError(default_message) diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 0194c7b87820a..cceaf4828e40a 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -264,6 +264,58 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path): strategy.load_checkpoint(path=checkpoint_path, state={"model": Mock()}) +def _make_s3_mock_fs(dirs, files=()): + """Create a mock fsspec filesystem for S3-like remote URI tests.""" + fs = Mock() + fs.isdir = Mock(side_effect=lambda p: p in dirs) + fs.isfile = Mock(side_effect=lambda p: p in files) + return fs + + +def test_validate_checkpoint_directory_remote_uri(): + """Test that _validate_checkpoint_directory works with remote filesystem URIs (e.g., S3, HDFS).""" + from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory + + mock_fs = _make_s3_mock_fs(dirs={"s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"}) + + with mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs): + # Should not raise when the remote path is a valid DeepSpeed checkpoint + _validate_checkpoint_directory("s3://bucket/ckpt") + + # Verify the URI was NOT mangled (s3:// must stay as s3://, not s3:/) + mock_fs.isdir.assert_any_call("s3://bucket/ckpt") + mock_fs.isdir.assert_any_call("s3://bucket/ckpt/checkpoint") + + +def test_validate_checkpoint_directory_remote_uri_subfolder_suggestion(): + """Test that the subfolder suggestion works with remote URIs.""" + from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory + + mock_fs = _make_s3_mock_fs(dirs={"s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"}) + + with ( + mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs), + pytest.raises(FileNotFoundError, match="Try to load using this parent directory instead: s3://bucket/ckpt"), + ): + _validate_checkpoint_directory("s3://bucket/ckpt/checkpoint") + + +def test_validate_checkpoint_directory_remote_uri_file_inside_checkpoint(): + """Test that the file-inside-checkpoint suggestion works with remote URIs.""" + from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory + + mock_fs = _make_s3_mock_fs( + dirs={"s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"}, + files={"s3://bucket/ckpt/checkpoint/zero_pp_rank_0_mp_rank_00_model_states.pt"}, + ) + + with ( + mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs), + pytest.raises(FileNotFoundError, match="Try to load using this parent directory instead: s3://bucket/ckpt"), + ): + _validate_checkpoint_directory("s3://bucket/ckpt/checkpoint/zero_pp_rank_0_mp_rank_00_model_states.pt") + + @RunIf(deepspeed=True) def test_deepspeed_load_checkpoint_no_state(tmp_path): """Test that DeepSpeed can't load the full state without access to a model instance from the user."""