diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py
index ba9e706..759d56b 100644
--- a/src/ml_flashpoint/adapter/megatron/save_strategies.py
+++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py
@@ -28,6 +28,8 @@
     _replace_state_dict_keys_with_sharded_keys,
     mcore_to_pyt_state_dict,
 )
+from torch.distributed.checkpoint.metadata import Metadata
+from torch.distributed.checkpoint.planner import SavePlan
 from torch.distributed.checkpoint.utils import _DistWrapper
 from typing_extensions import override
@@ -83,17 +85,27 @@ def __init__(
         storage_writer: MemoryStorageWriter,
         backend: str = default_backend_format_name(),
         version: int = default_backend_format_version(),
+        use_cached_ckpt_structure: bool = False,
     ):
         """
         Args:
             storage_writer (MemoryStorageWriter): The storage writer to use for saving operations.
             backend (str, optional): The name of the backend format. Defaults to "ml_flashpoint", which is recommended.
             version (int, optional): The version of the checkpoint format. Defaults to the latest version.
+            use_cached_ckpt_structure (bool, optional): Whether to reuse the checkpoint structure (plan)
+                from the previous save. Defaults to False.
         """
         super().__init__(backend=backend, version=version)
         self._storage_writer: MemoryStorageWriter = storage_writer
         self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver
+        # Cache for state dict saving
+        self._cached_central_plan: SavePlan | None = None
+        self._cached_local_plan: SavePlan | None = None
+        self._cached_global_metadata: Metadata | None = None
+        self._validated_cache_reuse: bool = False
+        self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure
+
     @override
     def can_handle_sharded_objects(self) -> bool:
         # Not currently used, but in case it is, ensure this strategy is used for ShardedObjects as well.
@@ -157,14 +169,38 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
         # we also use Megatron's SavePlanner during saving for compatibility.
         planner: MCoreSavePlanner = MCoreSavePlanner(can_run_decentralized_global_plan=False)
         world_dist_wrapper = _DistWrapper(group=None, use_dist=not disable_dist, coordinator_rank=0)
-        plan, write_buckets, global_metadata = statedictsaver.generate_plan(
+        # The first two saves validate that the generated `central_plan` is identical across
+        # iterations; once it is, `cached_central_plan` and `cached_global_metadata` are reused.
+        # From the 3rd save onward, `generate_plan` no longer generates `global_metadata`
+        # (it returns None), so `self._cached_global_metadata` is reused.
+        cached_structure_args = None
+        if self._use_cached_ckpt_structure:
+            cached_structure_args = (
+                self._cached_central_plan,
+                self._cached_local_plan,
+                self._validated_cache_reuse,
+            )
+
+        (
+            write_buckets,
+            global_metadata,
+            self._cached_central_plan,
+            self._cached_local_plan,
+            self._validated_cache_reuse,
+        ) = statedictsaver.generate_plan(
             checkpoint_id=checkpoint_id,
             state_dict=pyt_state_dict,
             storage_writer=self._storage_writer,
             planner=planner,
             world_dist_wrapper=world_dist_wrapper,
+            cached_ckpt_structure=cached_structure_args,
         )
+        if global_metadata is None:
+            global_metadata = self._cached_global_metadata
+        else:
+            self._cached_global_metadata = global_metadata
+
         # 5. Stage to CPU.
         staged_write_buckets = self._storage_writer.stage_write_data_buckets(
             checkpoint_id, write_buckets, non_blocking=True
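The strategy now threads four pieces of cached state through every save. As a reading aid, here is a minimal, self-contained sketch of that lifecycle; `PlanCache` is hypothetical and only mirrors the `_cached_*` attributes added above, it is not part of the codebase:

```python
# Hypothetical illustration of the cache MLFlashpointMegatronAsyncSaveStrategy
# carries between async_save calls. The fields mirror _cached_central_plan,
# _cached_local_plan, _cached_global_metadata, and _validated_cache_reuse.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PlanCache:
    enabled: bool = False
    central_plan: Optional[Any] = None
    local_plan: Optional[Any] = None
    metadata: Optional[Any] = None
    validated: bool = False

    def args_for_generate_plan(self):
        # Feed the cache into planning only when the feature is switched on;
        # otherwise generate_plan sees None and plans from scratch every save.
        return (self.central_plan, self.local_plan, self.validated) if self.enabled else None

    def absorb(self, central_plan, local_plan, validated, metadata):
        # Always record the newest plans; keep the previous metadata when the
        # planner skipped regenerating it (it returns None from the 3rd save on).
        self.central_plan, self.local_plan, self.validated = central_plan, local_plan, validated
        if metadata is not None:
            self.metadata = metadata
        return self.metadata
```

The key invariant is that cached metadata is never dropped: a `None` from the planner means "nothing changed", so the previously broadcast metadata keeps being served.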
diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py
index 257f889..3b30570 100644
--- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py
+++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py
@@ -45,6 +45,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
     write_thread_count: int = 1,
     initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
     use_optimized_save: bool = True,
+    use_cached_ckpt_structure: bool = False,
 ) -> MLFlashpointAutoResume:
     """Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.
@@ -62,6 +63,8 @@
         write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
         initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
             in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
+        use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
+            Defaults to False.
 
     Returns:
         An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`.
     """
@@ -90,6 +93,7 @@
         write_thread_count=write_thread_count,
         initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
         use_optimized_save=use_optimized_save,
+        use_cached_ckpt_structure=use_cached_ckpt_structure,
     )
 
     default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
@@ -111,6 +115,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
     write_thread_count: int = 1,
     initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
     use_optimized_save: bool = True,
+    use_cached_ckpt_structure: bool = False,
 ):
     """Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.
@@ -138,6 +143,8 @@
         write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
         initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
             in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
+        use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
+            Defaults to False.
 
     Returns:
         None. The trainer's checkpoint_io is modified in-place.
@@ -218,7 +225,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
             ),
             mp_manager=ctx.Manager(),
             thread_count=write_thread_count,
-        )
+        ),
+        use_cached_ckpt_structure=use_cached_ckpt_structure,
     )
     load_strategy = MLFlashpointMegatronLoadStrategy(
         replication_manager=replication_manager,
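For orientation, a hypothetical opt-in mirroring the call shape exercised by the tests later in this diff; the trainer, managers, and container path are placeholders for real application objects, and `use_cached_ckpt_structure` is the only new knob (it defaults to `False` in both wrapper functions):

```python
# Hypothetical usage sketch; trainer, ckpt_obj_manager, and replication_manager
# are placeholders standing in for objects created elsewhere in the application.
from ml_flashpoint.adapter.nemo.wrapper_util import wrap_trainer_checkpoint_io_with_mlflashpoint

wrap_trainer_checkpoint_io_with_mlflashpoint(
    trainer,                    # placeholder: a configured nemo.lightning Trainer
    "/flashpoint_checkpoints",  # placeholder: base container path
    ckpt_obj_manager,           # placeholder: checkpoint object manager
    replication_manager,        # placeholder: replication manager
    async_save=True,
    use_cached_ckpt_structure=True,  # reuse the save plan once two saves agree
)
```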
diff --git a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py
index 01e77da..2f40c55 100644
--- a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py
+++ b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py
@@ -22,7 +22,6 @@
 import torch.cuda
 from torch import distributed as torchdist
-from torch.distributed.checkpoint import Metadata
 from torch.distributed.checkpoint import state_dict_saver as torchdistsaver
 from torch.distributed.checkpoint.logger import _dcp_method_logger
 from torch.distributed.checkpoint.planner import SavePlan
@@ -46,7 +45,14 @@ def generate_plan(
     storage_writer: MemoryStorageWriter,
     planner: torchdistsaver.SavePlanner,
     world_dist_wrapper: _DistWrapper,
-) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]:
+    cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None,
+) -> tuple[
+    list[ObjectWriteBucket],
+    torchdistsaver.Metadata | None,
+    SavePlan,
+    SavePlan,
+    bool,
+]:
     """Performs the planning phase of checkpointing.
 
     This function is similar to PyTorch's `state_dict_saver.save` but only
@@ -62,15 +68,28 @@
         planner: The SavePlanner to use for the save.
         world_dist_wrapper: The distributed wrapper for world (all ranks) communication.
             Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`.
+        cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse).
 
     Returns:
-        A tuple containing the updated local plan, write buckets, and global metadata.
+        A tuple containing:
+            - write_buckets
+            - global_metadata
+            - central_plan (for caching)
+            - local_plan (for caching)
+            - validated_cache_reuse (bool)
     """
+    cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
+    if cached_ckpt_structure:
+        cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure
+
     global_metadata: torchdistsaver.Metadata | None = None
     ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group}
+    local_plan = cached_local_plan
 
     @_dcp_method_logger(**ckpt_kwargs)
     def local_step() -> SavePlan:
+        nonlocal local_plan
         storage_meta = storage_writer.storage_meta()
         planner.set_up_planner(
             state_dict=state_dict,
@@ -79,7 +98,9 @@ def local_step() -> SavePlan:
         )
         storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)
 
-        local_plan = planner.create_local_plan()
+        if not validated_cache_reuse:
+            local_plan = planner.create_local_plan()
+
         local_plan = storage_writer.prepare_local_plan(local_plan)
         return local_plan
 
@@ -91,19 +112,30 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
         all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
         return all_local_plans
 
-    with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
-        _LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...")
-        updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)
-
-    with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
-        _LOGGER.debug("Executing global_metadata broadcast...")
-        # TODO(perf): - can broadcast only to local rank 0 to reduce comms
-        global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
-
-    final_local_plan = planner.finish_plan(updated_local_plan)
+    central_plan = None
+    if validated_cache_reuse and cached_central_plan:
+        _LOGGER.debug("Cache validated; reusing cached central plan")
+        local_step()
+        central_plan = cached_central_plan
+    else:
+        with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
+            _LOGGER.debug("Executing plan reduce_scatter to get central_plan...")
+            central_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)
+
+        with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
+            _LOGGER.debug("Executing global_metadata broadcast...")
+            global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
+
+    final_local_plan = planner.finish_plan(central_plan)
     write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner)
 
-    return final_local_plan, write_buckets, global_metadata
+    return (
+        write_buckets,
+        global_metadata,
+        central_plan,
+        local_plan,
+        cached_central_plan == central_plan,
+    )
 
 
 @log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO)
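The contract `generate_plan` now implements is easiest to see as a three-save protocol. The sketch below is a self-contained toy, not the library code: strings stand in for `SavePlan`s, and plan equality replaces the collective planning round.

```python
# Toy model (hypothetical) of the caching protocol across consecutive saves.
def plan_step(make_plan, cached_plan, validated):
    """Returns (plan, metadata_or_None, validated_for_next_save)."""
    if validated and cached_plan is not None:
        # 3rd save onward: skip the collectives, reuse the plan, emit no metadata.
        return cached_plan, None, True
    plan = make_plan()  # saves 1 and 2: full planning round, metadata generated
    # The plan is only trusted for reuse once two consecutive saves agree.
    return plan, {"state_dict_metadata": "..."}, cached_plan == plan


cached_plan, validated = None, False
for save in range(1, 5):
    plan, metadata, validated = plan_step(lambda: "plan-A", cached_plan, validated)
    cached_plan = plan
    print(f"save {save}: metadata_generated={metadata is not None}, validated={validated}")
# save 1: metadata_generated=True, validated=False   (nothing to compare against yet)
# save 2: metadata_generated=True, validated=True    (plan matched save 1)
# save 3: metadata_generated=False, validated=True   (cached plan reused, no collectives)
# save 4: metadata_generated=False, validated=True
```

This is also why the else-branch above still broadcasts metadata on the second save: validation only succeeds after two full planning rounds have produced identical central plans.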
diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py
index d842283..72b0fbe 100644
--- a/tests/adapter/megatron/test_save_strategies.py
+++ b/tests/adapter/megatron/test_save_strategies.py
@@ -169,9 +169,11 @@ def test_async_save_initialization_calls_success(
             _,
         ) = async_save_setup
         mock_statedictsaver.generate_plan.return_value = (
-            mocker.MagicMock(),
             dummy_write_buckets,
             mocker.MagicMock(),
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
         )
 
         mock_memory_storage_writer_cls = mocker.patch(
@@ -211,9 +213,11 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
             _,
         ) = async_save_setup
         mock_statedictsaver.generate_plan.return_value = (
-            mocker.MagicMock(),
             dummy_write_buckets,
             mocker.MagicMock(),
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
         )
 
         # Set a specific thread_count on the original storage_writer
@@ -259,9 +263,18 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
             mocker.MagicMock(),
             mocker.MagicMock(),
             mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
         )
-        expected_kwarg_keys = {"checkpoint_id", "state_dict", "storage_writer", "planner", "world_dist_wrapper"}
+        expected_kwarg_keys = {
+            "checkpoint_id",
+            "state_dict",
+            "storage_writer",
+            "planner",
+            "world_dist_wrapper",
+            "cached_ckpt_structure",
+        }
 
         # When
         strategy.async_save(sharded_state_dict, checkpoint_id.data)
@@ -281,6 +294,8 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
         assert kwargs["planner"] is mock_planner
         assert "world_dist_wrapper" in kwargs
         assert kwargs["world_dist_wrapper"].use_dist is False
+        assert "cached_ckpt_structure" in kwargs
+        assert "cached_global_metadata" not in kwargs
 
     def test_generate_plan_failure(self, mocker, async_save_setup):
         """Tests that an exception in generate_plan is propagated."""
@@ -303,7 +318,13 @@ def test_async_save_async_fn_call_success(
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            dummy_write_buckets,
+            dummy_metadata,
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
         staged_write_buckets = [
             ObjectWriteBucket(
                 object_id=CheckpointObjectId(f"/test_checkpoint/staged_obj_{i}"),
@@ -342,6 +363,8 @@ def test_async_save_async_fn_failure(self, mocker, async_save_setup, checkpoint_
             mocker.MagicMock(),
             mocker.MagicMock(),
             mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
         )
         mock_statedictsaver.write_data.side_effect = Exception("Test Exception")
@@ -368,7 +391,13 @@ def test_async_save_finalize_fns_calls(
         finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint")
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            dummy_write_buckets,
+            dummy_metadata,
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
 
         mock_memory_storage_writer_cls = mocker.patch(
             "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
@@ -427,7 +456,13 @@ def test_finalize_fns_failure(
         finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint")
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, mocker.MagicMock(), dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            mocker.MagicMock(),
+            dummy_metadata,
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
         mock_statedictsaver.finish_write.side_effect = ValueError("Finish Write Failed")
 
         # When
@@ -468,6 +503,8 @@ def test_async_save_rank_determination(
             mocker.MagicMock(),
             mocker.MagicMock(),
             mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
        )
 
         # When
@@ -475,3 +512,97 @@
         # Then
         assert actual_async_request.async_fn_kwargs["rank"] == expected_rank
+
+    def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer):
+        """Tests the caching flow across multiple async_save calls."""
+        # Given
+        mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
+        strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
+        cached_plan = mocker.MagicMock()
+        cached_metadata = mocker.MagicMock()
+
+        # --- Call 1: No cache ---
+        # Given
+        mock_statedictsaver.generate_plan.return_value = (
+            [],
+            mocker.MagicMock(),
+            cached_plan,  # cached_central_plan returned
+            mocker.MagicMock(),
+            False,
+        )
+
+        # When
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then
+        assert strategy._cached_central_plan == cached_plan
+        assert strategy._validated_cache_reuse is False
+
+        # --- Call 2: Cache validation success ---
+        # Given
+        mock_statedictsaver.generate_plan.return_value = (
+            [],
+            cached_metadata,
+            cached_plan,
+            mocker.MagicMock(),
+            True,
+        )
+
+        # When
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then
+        assert strategy._validated_cache_reuse is True
+        assert strategy._cached_global_metadata == cached_metadata
+
+        # --- Call 3: Reuse cache ---
+        # Given
+        mock_statedictsaver.generate_plan.return_value = (
+            [],
+            None,  # Returns None for metadata
+            cached_plan,
+            mocker.MagicMock(),
+            True,
+        )
+
+        # When
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then
+        # Ensure generate_plan was called without cached_global_metadata
+        _, kwargs = mock_statedictsaver.generate_plan.call_args
+        assert "cached_global_metadata" not in kwargs
+        # And the cached_global_metadata held by the strategy should be unchanged
+        assert strategy._cached_global_metadata == cached_metadata
+
+    def test_async_save_caching_disabled_by_default(self, mocker, async_save_setup, storage_writer):
+        """Tests that caching is disabled by default."""
+        # Given
+        mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
+        strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
+
+        cached_plan = mocker.MagicMock()
+
+        # Call: Returns a plan that could be cached
+        mock_statedictsaver.generate_plan.return_value = (
+            [],
+            None,
+            cached_plan,
+            mocker.MagicMock(),
+            False,
+        )
+        # When
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then
+        # generate_plan always returns the plan to be cached, so the strategy still
+        # records it internally; what matters is that, with use_cached_ckpt_structure
+        # left at False, the cache is never passed back in on the next save.
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        _, kwargs = mock_statedictsaver.generate_plan.call_args
+        assert kwargs["cached_ckpt_structure"] is None
+        assert strategy._use_cached_ckpt_structure is False
diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py
index f00a197..25bf4f0 100644
--- a/tests/adapter/nemo/test_wrapper_util.py
+++ b/tests/adapter/nemo/test_wrapper_util.py
@@ -103,6 +103,7 @@ def test_successful_wrap_and_resume_creation(self, mocker):
             write_thread_count=1,
             initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
             use_optimized_save=True,
+            use_cached_ckpt_structure=False,
         )
 
         # 3. Result is correct type and has correct attributes
@@ -794,6 +795,46 @@ def test_write_thread_count_forwarding(
         _, kwargs = spy_memory_storage_writer_init.call_args
         assert kwargs["thread_count"] == expected_thread_count
 
+    @pytest.mark.parametrize("use_cached_ckpt_structure", [True, False])
+    def test_cached_ckpt_structure_forwarding(
+        self, mocker, mock_ckpt_obj_manager, mock_replication_manager, use_cached_ckpt_structure
+    ):
+        """Tests that use_cached_ckpt_structure is forwarded correctly."""
+        # Given
+        trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
+        trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
+        trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
+        trainer.strategy.checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
+        base_container = "/test_base_container"
+
+        # Mock the SaveStrategy to check initialization arguments
+        mock_save_strategy_cls = mocker.patch(
+            "ml_flashpoint.adapter.nemo.wrapper_util.MLFlashpointMegatronAsyncSaveStrategy"
+        )
+
+        # Mock dependencies
+        mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager")
+        mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.MemoryStorageWriter")
+        mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.DefaultMLFlashpointCheckpointSaver")
+        mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.torch_mp.get_context")
+        mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.MLFlashpointMegatronLoadStrategy")
+
+        # When
+        wrap_trainer_checkpoint_io_with_mlflashpoint(
+            trainer,
+            base_container,
+            mock_ckpt_obj_manager,
+            mock_replication_manager,
+            async_save=True,
+            checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
+            use_cached_ckpt_structure=use_cached_ckpt_structure,
+        )
+
+        # Then
+        mock_save_strategy_cls.assert_called_once()
+        _, kwargs = mock_save_strategy_cls.call_args
+        assert kwargs["use_cached_ckpt_structure"] == use_cached_ckpt_structure
+
     def test_spawn_context_used_for_mp_manager(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
         """Tests that torch_mp.get_context('spawn').Manager() is correctly instantiated and passed."""
         # Given
diff --git a/tests/adapter/pytorch/test_custom_state_dict_saver.py b/tests/adapter/pytorch/test_custom_state_dict_saver.py
index f98fd29..42d367b 100644
--- a/tests/adapter/pytorch/test_custom_state_dict_saver.py
+++ b/tests/adapter/pytorch/test_custom_state_dict_saver.py
@@ -100,7 +100,13 @@ def test_generate_plan_calls_dependencies_correctly(
         mock_storage_writer.prepare_write_data_buckets.return_value = expected_write_buckets
 
         # When
-        actual_plan, actual_write_buckets, actual_metadata = custom_state_dict_saver.generate_plan(
+        (
+            actual_write_buckets,
+            actual_metadata,
+            _,
+            _,
+            actual_reused,
+        ) = custom_state_dict_saver.generate_plan(
             checkpoint_id, state_dict, mock_storage_writer, mock_save_planner, dist_wrapper
         )
 
@@ -115,9 +121,80 @@
         mock_storage_writer.prepare_write_data_buckets.assert_called_once_with(
             checkpoint_id, local_plan, mock_save_planner
         )
-        assert actual_plan == local_plan
         assert actual_write_buckets == expected_write_buckets
         assert actual_metadata == expected_global_metadata
+        assert actual_reused is False
+
+    def test_generate_plan_reuses_cache(self, mocker, mock_storage_writer, mock_save_planner, dist_wrapper):
+        """Tests that generate_plan reuses the cache when validated_cache_reuse is True."""
+        # Given
+        checkpoint_id = CheckpointContainerId("/test_checkpoint")
+        state_dict = {"model": "test"}
+        cached_plan = SavePlan([WriteItem(index=MetadataIndex("cached"), type=WriteItemType.TENSOR)])
+
+        mock_save_planner.finish_plan.return_value = cached_plan
+        mock_storage_writer.prepare_write_data_buckets.return_value = []
+
+        # When
+        (
+            _,
+            actual_metadata,
+            _,
+            _,
+            actual_reused,
+        ) = custom_state_dict_saver.generate_plan(
+            checkpoint_id,
+            state_dict,
+            mock_storage_writer,
+            mock_save_planner,
+            dist_wrapper,
+            cached_ckpt_structure=(cached_plan, None, True),
+        )
+
+        # Then
+        # On the reuse path, reduce_scatter and broadcast_object are skipped entirely;
+        # dist_wrapper is not a mock here, so the returned values stand in as the
+        # assertion that no fresh metadata or local plan was produced.
+        assert actual_metadata is None
+        assert actual_reused is True
+        mock_save_planner.create_local_plan.assert_not_called()
+
+    def test_generate_plan_validates_cache_success(self, mocker, mock_storage_writer, mock_save_planner, dist_wrapper):
+        """Tests that generate_plan validates cache successfully."""
+        # Given
+        local_plan = SavePlan([])
+        global_plans = [local_plan]
+        mock_save_planner.create_local_plan.return_value = local_plan
+        mock_storage_writer.prepare_local_plan.return_value = local_plan
+        mock_save_planner.create_global_plan.return_value = (global_plans, None)
+        mock_storage_writer.prepare_global_plan.return_value = global_plans
+        # Assume reduce_scatter returns the SAME plan as cached
+        mocker.patch.object(dist_wrapper, "reduce_scatter", return_value=local_plan)
+        mocker.patch.object(dist_wrapper, "broadcast_object", return_value=None)
+        mock_save_planner.finish_plan.return_value = local_plan
+        mock_storage_writer.prepare_write_data_buckets.return_value = []
+
+        # When
+        (
+            _,
+            _,
+            _,
+            _,
+            actual_reused,
+        ) = custom_state_dict_saver.generate_plan(
+            CheckpointContainerId("/test"),
+            {},
+            mock_storage_writer,
+            mock_save_planner,
+            dist_wrapper,
+            cached_ckpt_structure=(local_plan, None, False),
+        )
+
+        # Then
+        assert actual_reused is True
 
     def test_generate_plan_reduce_scatters_local_plan(
         self, mock_storage_writer, mock_save_planner, dist_wrapper, mocker
@@ -144,7 +221,13 @@ def test_generate_plan_reduce_scatters_local_plan(
         mock_save_planner.finish_plan.return_value = expected_returned_plan
 
         # When
-        returned_local_plan, _, _ = custom_state_dict_saver.generate_plan(
+        custom_state_dict_saver.generate_plan(
             CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper
         )
 
@@ -153,7 +236,6 @@
         # We can't directly assert the call arguments for local_step and global_step as they are inner functions.
         # However, we can assert that reduce_scatter was called with 'plan' as the tag.
         assert mock_reduce_scatter.call_args[0][0] == "plan"
-        assert returned_local_plan == expected_returned_plan
 
     def test_generate_plan_broadcasts_global_metadata(
         self, mock_storage_writer, mock_save_planner, dist_wrapper, mocker
@@ -176,7 +258,13 @@
         mocker.patch.object(dist_wrapper, "broadcast_object", return_value=expected_broadcasted_metadata)
 
         # When
-        _, _, returned_metadata = custom_state_dict_saver.generate_plan(
+        (
+            _,
+            returned_metadata,
+            _,
+            _,
+            _,
+        ) = custom_state_dict_saver.generate_plan(
             CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper
         )
 
@@ -203,13 +291,43 @@ def test_generate_plan_returns_write_buckets(self, mock_storage_writer, mock_sav
         mock_storage_writer.prepare_write_data_buckets.return_value = expected_write_buckets
 
         # When
-        _, returned_write_buckets, _ = custom_state_dict_saver.generate_plan(
+        (
+            returned_write_buckets,
+            _,
+            _,
+            _,
+            _,
+        ) = custom_state_dict_saver.generate_plan(
             CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper
         )
 
         # Then
         assert returned_write_buckets == expected_write_buckets
 
+    def test_generate_plan_signature_compatibility(self, mock_storage_writer, mock_save_planner, dist_wrapper):
+        """Tests that generate_plan returns exactly 5 elements."""
+        # Given
+        state_dict = {"model": "test"}
+        mock_save_planner.create_local_plan.return_value = SavePlan([])
+        mock_storage_writer.prepare_local_plan.return_value = SavePlan([])
+        mock_save_planner.create_global_plan.return_value = ([SavePlan([])], None)
+        mock_storage_writer.prepare_global_plan.return_value = [SavePlan([])]
+        mock_save_planner.finish_plan.return_value = SavePlan([])
+        mock_storage_writer.prepare_write_data_buckets.return_value = []
+
+        # When
+        result = custom_state_dict_saver.generate_plan(
+            CheckpointContainerId("/test_checkpoint"),
+            state_dict,
+            mock_storage_writer,
+            mock_save_planner,
+            dist_wrapper,
+        )
+
+        # Then
+        assert isinstance(result, tuple)
+        assert len(result) == 5
+
 
 class TestWriteData:
     """Tests for the write_data function."""