From 8d08f6237f5849f12a299ece3a9bfd358abbe79b Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Wed, 25 Feb 2026 15:34:28 +0000 Subject: [PATCH 1/4] Use megatron method to cache plan --- .../adapter/megatron/save_strategies.py | 42 +++++- .../pytorch/custom_state_dict_saver.py | 61 ++++++--- .../adapter/megatron/test_save_strategies.py | 108 +++++++++++++-- .../pytorch/test_custom_state_dict_saver.py | 124 +++++++++++++++++- 4 files changed, 306 insertions(+), 29 deletions(-) diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py index ba9e706..48caf9f 100644 --- a/src/ml_flashpoint/adapter/megatron/save_strategies.py +++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py @@ -28,6 +28,8 @@ _replace_state_dict_keys_with_sharded_keys, mcore_to_pyt_state_dict, ) +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner import SavePlan from torch.distributed.checkpoint.utils import _DistWrapper from typing_extensions import override @@ -94,6 +96,13 @@ def __init__( self._storage_writer: MemoryStorageWriter = storage_writer self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver + # Cache for state dict saving + self.cached_central_plan: SavePlan | None = None + self.cached_local_plan: SavePlan | None = None + self.cached_global_metadata: Metadata | None = None + self.validated_cache_reuse: bool = False + self.use_cached_ckpt_structure: bool = True + @override def can_handle_sharded_objects(self) -> bool: # Not currently used, but in case it is, ensure this strategy is used for ShardedObjects as well. @@ -157,14 +166,45 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union # we also use Megatron's SavePlanner during saving for compatibility. planner: MCoreSavePlanner = MCoreSavePlanner(can_run_decentralized_global_plan=False) world_dist_wrapper = _DistWrapper(group=None, use_dist=not disable_dist, coordinator_rank=0) - plan, write_buckets, global_metadata = statedictsaver.generate_plan( + # Try twice to validate the generated `central_plan` is the same across iterations + # If so, reuse `cached_central_plan` and `cached_global_metadata` + # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` + # (return None) so `self.cached_global_metadata` is reused + args_cached_plans = None + loaded_all_plans = None + if self.use_cached_ckpt_structure: + if self.cached_global_metadata: + loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None) + + args_cached_plans = ( + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + ) + + ( + (plan, write_buckets, global_metadata), + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + ) = statedictsaver.generate_plan( checkpoint_id=checkpoint_id, state_dict=pyt_state_dict, storage_writer=self._storage_writer, planner=planner, world_dist_wrapper=world_dist_wrapper, + cached_ckpt_structure=args_cached_plans, + loaded_all_plans=loaded_all_plans, ) + if self.validated_cache_reuse: + if global_metadata is None and self.cached_global_metadata: + global_metadata = self.cached_global_metadata + + # If we have a valid global_metadata (either new or reused), cache it for next time + if global_metadata is not None: + self.cached_global_metadata = global_metadata + # 5. Stage to CPU. staged_write_buckets = self._storage_writer.stage_write_data_buckets( checkpoint_id, write_buckets, non_blocking=True diff --git a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py index 01e77da..510d78d 100644 --- a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py +++ b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py @@ -22,7 +22,6 @@ import torch.cuda from torch import distributed as torchdist -from torch.distributed.checkpoint import Metadata from torch.distributed.checkpoint import state_dict_saver as torchdistsaver from torch.distributed.checkpoint.logger import _dcp_method_logger from torch.distributed.checkpoint.planner import SavePlan @@ -46,7 +45,14 @@ def generate_plan( storage_writer: MemoryStorageWriter, planner: torchdistsaver.SavePlanner, world_dist_wrapper: _DistWrapper, -) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]: + cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None, + loaded_all_plans: list[SavePlan] | None = None, +) -> tuple[ + tuple[SavePlan, list[ObjectWriteBucket], torchdistsaver.Metadata | None], + SavePlan, + SavePlan, + bool, +]: """Performs the planning phase of checkpointing. This function is similar to PyTorch's `state_dict_saver.save` but only @@ -62,15 +68,28 @@ def generate_plan( planner: The SavePlanner to use for the save. world_dist_wrapper: The distributed wrapper for world (all ranks) communication. Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`. + cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse). + loaded_all_plans: List of all local plans from the previous checkpoint (for validation). + Returns: - A tuple containing the updated local plan, write buckets, and global metadata. + A tuple containing: + - (final_local_plan, write_buckets, global_metadata) + - final_local_plan (for caching) + - local_plan (for caching) + - validated_cache_reuse (bool) """ + cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False) + if cached_ckpt_structure: + cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure + global_metadata: torchdistsaver.Metadata | None = None ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group} + local_plan = cached_local_plan @_dcp_method_logger(**ckpt_kwargs) def local_step() -> SavePlan: + nonlocal local_plan storage_meta = storage_writer.storage_meta() planner.set_up_planner( state_dict=state_dict, @@ -79,7 +98,9 @@ def local_step() -> SavePlan: ) storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator) - local_plan = planner.create_local_plan() + if not validated_cache_reuse: + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) return local_plan @@ -91,19 +112,29 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]: all_local_plans = storage_writer.prepare_global_plan(all_local_plans) return all_local_plans - with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"): - _LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...") - updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step) - - with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"): - _LOGGER.debug("Executing global_metadata broadcast...") - # TODO(perf): - can broadcast only to local rank 0 to reduce comms - global_metadata = world_dist_wrapper.broadcast_object(global_metadata) - - final_local_plan = planner.finish_plan(updated_local_plan) + central_plan = None + if validated_cache_reuse and cached_central_plan: + _LOGGER.debug("Passed cache reusable") + local_step() + central_plan = cached_central_plan + else: + with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"): + _LOGGER.debug("Executing plan reduce_scatter to get central_plan...") + central_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step) + + with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"): + _LOGGER.debug("Executing global_metadata broadcast...") + global_metadata = world_dist_wrapper.broadcast_object(global_metadata) + + final_local_plan = planner.finish_plan(central_plan) write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner) - return final_local_plan, write_buckets, global_metadata + return ( + (final_local_plan, write_buckets, global_metadata), + central_plan, # cached_central_plan + local_plan, # cached_local_plan + cached_central_plan == central_plan, # validated_cache_reuse + ) @log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO) diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index d842283..ddb92d2 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -169,9 +169,10 @@ def test_async_save_initialization_calls_success( _, ) = async_save_setup mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()), mocker.MagicMock(), - dummy_write_buckets, mocker.MagicMock(), + False, ) mock_memory_storage_writer_cls = mocker.patch( @@ -211,9 +212,10 @@ def test_async_save_reinitializes_storage_writer_with_thread_count( _, ) = async_save_setup mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()), mocker.MagicMock(), - dummy_write_buckets, mocker.MagicMock(), + False, ) # Set a specific thread_count on the original storage_writer @@ -256,12 +258,21 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s ) = async_save_setup mock_planner = MockMCoreSavePlanner.return_value mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), mocker.MagicMock(), mocker.MagicMock(), - mocker.MagicMock(), + False, ) - expected_kwarg_keys = {"checkpoint_id", "state_dict", "storage_writer", "planner", "world_dist_wrapper"} + expected_kwarg_keys = { + "checkpoint_id", + "state_dict", + "storage_writer", + "planner", + "world_dist_wrapper", + "cached_ckpt_structure", + "loaded_all_plans", + } # When strategy.async_save(sharded_state_dict, checkpoint_id.data) @@ -281,6 +292,9 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s assert kwargs["planner"] is mock_planner assert "world_dist_wrapper" in kwargs assert kwargs["world_dist_wrapper"].use_dist is False + assert "cached_ckpt_structure" in kwargs + assert "loaded_all_plans" in kwargs + assert "cached_global_metadata" not in kwargs def test_generate_plan_failure(self, mocker, async_save_setup): """Tests that an exception in generate_plan is propagated.""" @@ -303,7 +317,12 @@ def test_async_save_async_fn_call_success( mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup - mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata) + mock_statedictsaver.generate_plan.return_value = ( + (dummy_save_plan, dummy_write_buckets, dummy_metadata), + mocker.MagicMock(), + mocker.MagicMock(), + False, + ) staged_write_buckets = [ ObjectWriteBucket( object_id=CheckpointObjectId(f"/test_checkpoint/staged_obj_{i}"), @@ -339,9 +358,10 @@ def test_async_save_async_fn_failure(self, mocker, async_save_setup, checkpoint_ mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), mocker.MagicMock(), mocker.MagicMock(), - mocker.MagicMock(), + False, ) mock_statedictsaver.write_data.side_effect = Exception("Test Exception") @@ -368,7 +388,12 @@ def test_async_save_finalize_fns_calls( finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint") mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup - mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata) + mock_statedictsaver.generate_plan.return_value = ( + (dummy_save_plan, dummy_write_buckets, dummy_metadata), + mocker.MagicMock(), + mocker.MagicMock(), + False, + ) mock_memory_storage_writer_cls = mocker.patch( "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter" @@ -427,7 +452,12 @@ def test_finalize_fns_failure( finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint") mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup - mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, mocker.MagicMock(), dummy_metadata) + mock_statedictsaver.generate_plan.return_value = ( + (dummy_save_plan, mocker.MagicMock(), dummy_metadata), + mocker.MagicMock(), + mocker.MagicMock(), + False, + ) mock_statedictsaver.finish_write.side_effect = ValueError("Finish Write Failed") # When @@ -465,9 +495,10 @@ def test_async_save_rank_determination( # Mock dependencies to ensure success path mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), mocker.MagicMock(), mocker.MagicMock(), - mocker.MagicMock(), + False, ) # When @@ -475,3 +506,62 @@ def test_async_save_rank_determination( # Then assert actual_async_request.async_fn_kwargs["rank"] == expected_rank + + def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer): + """Tests the caching flow across multiple async_save calls.""" + # Given + mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") + strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup + + cached_plan = mocker.MagicMock() + cached_metadata = mocker.MagicMock() + + # First call: No cache + mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), [], mocker.MagicMock()), + cached_plan, # cached_central_plan returned + mocker.MagicMock(), + False, + ) + + # When 1 + strategy.async_save(sharded_state_dict, checkpoint_id.data) + + # Then 1 + assert strategy.cached_central_plan == cached_plan + assert strategy.validated_cache_reuse is False + + # Second call: Cache validation success + mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), [], cached_metadata), + cached_plan, + mocker.MagicMock(), + True, # validated_cache_reuse + ) + + # When 2 + strategy.async_save(sharded_state_dict, checkpoint_id.data) + + # Then 2 + assert strategy.validated_cache_reuse is True + assert strategy.cached_global_metadata == cached_metadata + + # Third call: Reuse cache + mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), [], None), # Returns None for metadata + cached_plan, + mocker.MagicMock(), + True, + ) + + # During third call, async_save should use self.cached_global_metadata + + # When 3 + strategy.async_save(sharded_state_dict, checkpoint_id.data) + + # Then 3 + # Ensure generate_plan was called without cached_global_metadata + _, kwargs = mock_statedictsaver.generate_plan.call_args + assert "cached_global_metadata" not in kwargs + # And cached_global_metadata in strategy should still be the same + assert strategy.cached_global_metadata == cached_metadata diff --git a/tests/adapter/pytorch/test_custom_state_dict_saver.py b/tests/adapter/pytorch/test_custom_state_dict_saver.py index f98fd29..27a16c9 100644 --- a/tests/adapter/pytorch/test_custom_state_dict_saver.py +++ b/tests/adapter/pytorch/test_custom_state_dict_saver.py @@ -100,7 +100,12 @@ def test_generate_plan_calls_dependencies_correctly( mock_storage_writer.prepare_write_data_buckets.return_value = expected_write_buckets # When - actual_plan, actual_write_buckets, actual_metadata = custom_state_dict_saver.generate_plan( + ( + (actual_plan, actual_write_buckets, actual_metadata), + _, + _, + actual_reused, + ) = custom_state_dict_saver.generate_plan( checkpoint_id, state_dict, mock_storage_writer, mock_save_planner, dist_wrapper ) @@ -117,7 +122,79 @@ def test_generate_plan_calls_dependencies_correctly( ) assert actual_plan == local_plan assert actual_write_buckets == expected_write_buckets + assert actual_plan == local_plan + assert actual_write_buckets == expected_write_buckets assert actual_metadata == expected_global_metadata + assert actual_reused is False + + def test_generate_plan_reuses_cache(self, mocker, mock_storage_writer, mock_save_planner, dist_wrapper): + """Tests that generate_plan reuse the cache when validated_cache_reuse is True.""" + # Given + checkpoint_id = CheckpointContainerId("/test_checkpoint") + state_dict = {"model": "test"} + cached_plan = SavePlan([WriteItem(index=MetadataIndex("cached"), type=WriteItemType.TENSOR)]) + # cached_local_plan = SavePlan([WriteItem(index=MetadataIndex("local"), type=WriteItemType.TENSOR)]) + # cached_metadata = Metadata(state_dict_metadata={"cached": "meta"}) + + mock_save_planner.finish_plan.return_value = cached_plan + mock_storage_writer.prepare_write_data_buckets.return_value = [] + + # When + ( + (actual_plan, _, actual_metadata), + _, + _, + actual_reused, + ) = custom_state_dict_saver.generate_plan( + checkpoint_id, + state_dict, + mock_storage_writer, + mock_save_planner, + dist_wrapper, + cached_ckpt_structure=(cached_plan, None, True), + ) + + # Then + # Should not call reduce_scatter or broadcast_object + # But we can't easily assert on dist_wrapper methods as they are not mocks here unless we mock them + # However, we can check if they are NOT called by checking side effects if we had mocked them + # For now, checking return values is good enough proxy + assert actual_plan == cached_plan + assert actual_metadata is None + assert actual_reused is True + + def test_generate_plan_validates_cache_success(self, mocker, mock_storage_writer, mock_save_planner, dist_wrapper): + """Tests that generate_plan validates cache successfully.""" + # Given + local_plan = SavePlan([]) + global_plans = [local_plan] + mock_save_planner.create_local_plan.return_value = local_plan + mock_storage_writer.prepare_local_plan.return_value = local_plan + mock_save_planner.create_global_plan.return_value = (global_plans, None) + mock_storage_writer.prepare_global_plan.return_value = global_plans + # Assume reduce_scatter returns the SAME plan as cached + mocker.patch.object(dist_wrapper, "reduce_scatter", return_value=local_plan) + mocker.patch.object(dist_wrapper, "broadcast_object", return_value=None) + mock_save_planner.finish_plan.return_value = local_plan + mock_storage_writer.prepare_write_data_buckets.return_value = [] + + # When + ( + _, + _, + _, + actual_reused, + ) = custom_state_dict_saver.generate_plan( + CheckpointContainerId("/test"), + {}, + mock_storage_writer, + mock_save_planner, + dist_wrapper, + cached_ckpt_structure=(local_plan, None, False), + ) + + # Then + assert actual_reused is True def test_generate_plan_reduce_scatters_local_plan( self, mock_storage_writer, mock_save_planner, dist_wrapper, mocker @@ -144,7 +221,12 @@ def test_generate_plan_reduce_scatters_local_plan( mock_save_planner.finish_plan.return_value = expected_returned_plan # When - returned_local_plan, _, _ = custom_state_dict_saver.generate_plan( + ( + (returned_local_plan, _, _), + _, + _, + _, + ) = custom_state_dict_saver.generate_plan( CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper ) @@ -176,7 +258,12 @@ def test_generate_plan_broadcasts_global_metadata( mocker.patch.object(dist_wrapper, "broadcast_object", return_value=expected_broadcasted_metadata) # When - _, _, returned_metadata = custom_state_dict_saver.generate_plan( + ( + (_, _, returned_metadata), + _, + _, + _, + ) = custom_state_dict_saver.generate_plan( CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper ) @@ -203,13 +290,42 @@ def test_generate_plan_returns_write_buckets(self, mock_storage_writer, mock_sav mock_storage_writer.prepare_write_data_buckets.return_value = expected_write_buckets # When - _, returned_write_buckets, _ = custom_state_dict_saver.generate_plan( + ( + (_, returned_write_buckets, _), + _, + _, + _, + ) = custom_state_dict_saver.generate_plan( CheckpointContainerId("/test_checkpoint"), state_dict, mock_storage_writer, mock_save_planner, dist_wrapper ) # Then assert returned_write_buckets == expected_write_buckets + def test_generate_plan_signature_compatibility(self, mock_storage_writer, mock_save_planner, dist_wrapper): + """Tests that generate_plan returns exactly 4 elements (updated).""" + # Given + state_dict = {"model": "test"} + mock_save_planner.create_local_plan.return_value = SavePlan([]) + mock_storage_writer.prepare_local_plan.return_value = SavePlan([]) + mock_save_planner.create_global_plan.return_value = ([SavePlan([])], None) + mock_storage_writer.prepare_global_plan.return_value = [SavePlan([])] + mock_save_planner.finish_plan.return_value = SavePlan([]) + mock_storage_writer.prepare_write_data_buckets.return_value = [] + + # When + result = custom_state_dict_saver.generate_plan( + CheckpointContainerId("/test_checkpoint"), + state_dict, + mock_storage_writer, + mock_save_planner, + dist_wrapper, + ) + + # Then + assert isinstance(result, tuple) + assert len(result) == 4 + class TestWriteData: """Tests for the write_data function.""" From d0885f18e0ee032b2e0c7ac337258cac36ac795d Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Thu, 26 Feb 2026 15:45:43 +0000 Subject: [PATCH 2/4] add cache configurable arg in wrapper --- .../adapter/megatron/save_strategies.py | 5 ++- .../adapter/nemo/wrapper_util.py | 10 ++++- .../adapter/megatron/test_save_strategies.py | 34 +++++++++++++++ tests/adapter/nemo/test_wrapper_util.py | 41 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py index 48caf9f..2dc0bd4 100644 --- a/src/ml_flashpoint/adapter/megatron/save_strategies.py +++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py @@ -85,12 +85,15 @@ def __init__( storage_writer: MemoryStorageWriter, backend: str = default_backend_format_name(), version: int = default_backend_format_version(), + use_cached_ckpt_structure: bool = False, ): """ Args: storage_writer (MemoryStorageWriter): The storage writer to use for saving operations. backend (str, optional): The name of the backend format. Defaults to "ml_flashpoint", which is recommended. version (int, optional): The version of the checkpoint format. Defaults to the latest version. + use_cached_ckpt_structure (bool, optional): Whether to reuse the checkpoint structure (plan) + from the previous save. Defaults to False. """ super().__init__(backend=backend, version=version) self._storage_writer: MemoryStorageWriter = storage_writer @@ -101,7 +104,7 @@ def __init__( self.cached_local_plan: SavePlan | None = None self.cached_global_metadata: Metadata | None = None self.validated_cache_reuse: bool = False - self.use_cached_ckpt_structure: bool = True + self.use_cached_ckpt_structure: bool = use_cached_ckpt_structure @override def can_handle_sharded_objects(self) -> bool: diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index 257f889..3b30570 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -45,6 +45,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count: int = 1, initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, + use_cached_ckpt_structure: bool = False, ) -> MLFlashpointAutoResume: """Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`. @@ -62,6 +63,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`. + use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save. + Defaults to False. Returns: An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`. """ @@ -90,6 +93,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count=write_thread_count, initial_write_buffer_size_bytes=initial_write_buffer_size_bytes, use_optimized_save=use_optimized_save, + use_cached_ckpt_structure=use_cached_ckpt_structure, ) default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {} @@ -111,6 +115,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( write_thread_count: int = 1, initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, + use_cached_ckpt_structure: bool = False, ): """Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities. @@ -138,6 +143,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`. + use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save. + Defaults to False. Returns: None. The trainer's checkpoint_io is modified in-place. @@ -218,7 +225,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( ), mp_manager=ctx.Manager(), thread_count=write_thread_count, - ) + ), + use_cached_ckpt_structure=use_cached_ckpt_structure, ) load_strategy = MLFlashpointMegatronLoadStrategy( replication_manager=replication_manager, diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index ddb92d2..fc37973 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -512,6 +512,8 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) # Given mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup + # Enable caching significantly for this test + strategy.use_cached_ckpt_structure = True cached_plan = mocker.MagicMock() cached_metadata = mocker.MagicMock() @@ -565,3 +567,35 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) assert "cached_global_metadata" not in kwargs # And cached_global_metadata in strategy should still be the same assert strategy.cached_global_metadata == cached_metadata + + def test_async_save_caching_disabled_by_default(self, mocker, async_save_setup, storage_writer): + """Tests that caching is disabled by default.""" + # Given + mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") + strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup + + cached_plan = mocker.MagicMock() + + # Call: Returns a plan that could be cached + mock_statedictsaver.generate_plan.return_value = ( + (mocker.MagicMock(), [], mocker.MagicMock()), + cached_plan, + mocker.MagicMock(), + False, + ) + + # When + strategy.async_save(sharded_state_dict, checkpoint_id.data) + + # Then + # Should NOT have updated the specific cached plan attribute if we assume + # generate_plan returns it regardless? + # actually statedictsaver.generate_plan returns the plan to be cached. + # But the strategy should NOT pass it back in the next call if use_cached_ckpt_structure is False. + + # Let's verify the next call doesn't pass it. + strategy.async_save(sharded_state_dict, checkpoint_id.data) + + _, kwargs = mock_statedictsaver.generate_plan.call_args + assert kwargs["cached_ckpt_structure"] is None + assert strategy.use_cached_ckpt_structure is False diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index f00a197..25bf4f0 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -103,6 +103,7 @@ def test_successful_wrap_and_resume_creation(self, mocker): write_thread_count=1, initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save=True, + use_cached_ckpt_structure=False, ) # 3. Result is correct type and has correct attributes @@ -794,6 +795,46 @@ def test_write_thread_count_forwarding( _, kwargs = spy_memory_storage_writer_init.call_args assert kwargs["thread_count"] == expected_thread_count + @pytest.mark.parametrize("use_cached_ckpt_structure", [True, False]) + def test_cached_ckpt_structure_forwarding( + self, mocker, mock_ckpt_obj_manager, mock_replication_manager, use_cached_ckpt_structure + ): + """Tests that use_cached_ckpt_structure is forwarded correctly.""" + # Given + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) + trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) + trainer.strategy.checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) + base_container = "/test_base_container" + + # Mock the SaveStrategy to check initialization arguments + mock_save_strategy_cls = mocker.patch( + "ml_flashpoint.adapter.nemo.wrapper_util.MLFlashpointMegatronAsyncSaveStrategy" + ) + + # Mock dependencies + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager") + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.MemoryStorageWriter") + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.DefaultMLFlashpointCheckpointSaver") + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.torch_mp.get_context") + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.MLFlashpointMegatronLoadStrategy") + + # When + wrap_trainer_checkpoint_io_with_mlflashpoint( + trainer, + base_container, + mock_ckpt_obj_manager, + mock_replication_manager, + async_save=True, + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), + use_cached_ckpt_structure=use_cached_ckpt_structure, + ) + + # Then + mock_save_strategy_cls.assert_called_once() + _, kwargs = mock_save_strategy_cls.call_args + assert kwargs["use_cached_ckpt_structure"] == use_cached_ckpt_structure + def test_spawn_context_used_for_mp_manager(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): """Tests that torch_mp.get_context('spawn').Manager() is correctly instantiated and passed.""" # Given From a837e67f7956cdbbfcbcd26c05d40981de23eb82 Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Thu, 26 Feb 2026 20:47:43 +0000 Subject: [PATCH 3/4] resolve comments --- .../pytorch/custom_state_dict_saver.py | 6 ++--- .../adapter/megatron/test_save_strategies.py | 26 +++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py index 510d78d..b05f67e 100644 --- a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py +++ b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py @@ -131,9 +131,9 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]: return ( (final_local_plan, write_buckets, global_metadata), - central_plan, # cached_central_plan - local_plan, # cached_local_plan - cached_central_plan == central_plan, # validated_cache_reuse + central_plan, + local_plan, + cached_central_plan == central_plan, ) diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index fc37973..f3b17f0 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -512,13 +512,11 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) # Given mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup - # Enable caching significantly for this test - strategy.use_cached_ckpt_structure = True - cached_plan = mocker.MagicMock() cached_metadata = mocker.MagicMock() - # First call: No cache + # --- Call 1: No cache --- + # Given mock_statedictsaver.generate_plan.return_value = ( (mocker.MagicMock(), [], mocker.MagicMock()), cached_plan, # cached_central_plan returned @@ -526,14 +524,15 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) False, ) - # When 1 + # When strategy.async_save(sharded_state_dict, checkpoint_id.data) - # Then 1 + # Then assert strategy.cached_central_plan == cached_plan assert strategy.validated_cache_reuse is False - # Second call: Cache validation success + # --- Call 2: Cache validation success --- + # Given mock_statedictsaver.generate_plan.return_value = ( (mocker.MagicMock(), [], cached_metadata), cached_plan, @@ -541,14 +540,15 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) True, # validated_cache_reuse ) - # When 2 + # When strategy.async_save(sharded_state_dict, checkpoint_id.data) - # Then 2 + # Then assert strategy.validated_cache_reuse is True assert strategy.cached_global_metadata == cached_metadata - # Third call: Reuse cache + # --- Call 3: Reuse cache --- + # Given mock_statedictsaver.generate_plan.return_value = ( (mocker.MagicMock(), [], None), # Returns None for metadata cached_plan, @@ -556,12 +556,10 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) True, ) - # During third call, async_save should use self.cached_global_metadata - - # When 3 + # When strategy.async_save(sharded_state_dict, checkpoint_id.data) - # Then 3 + # Then # Ensure generate_plan was called without cached_global_metadata _, kwargs = mock_statedictsaver.generate_plan.call_args assert "cached_global_metadata" not in kwargs From dd4efced9a1c8d42ef83d5e1d9507ff589196cb1 Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 2 Mar 2026 20:13:28 +0000 Subject: [PATCH 4/4] resolve comments --- .../adapter/megatron/save_strategies.py | 49 ++++++++--------- .../pytorch/custom_state_dict_saver.py | 13 ++--- .../adapter/megatron/test_save_strategies.py | 53 +++++++++++-------- .../pytorch/test_custom_state_dict_saver.py | 24 +++++---- 4 files changed, 72 insertions(+), 67 deletions(-) diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py index 2dc0bd4..759d56b 100644 --- a/src/ml_flashpoint/adapter/megatron/save_strategies.py +++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py @@ -100,11 +100,11 @@ def __init__( self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver # Cache for state dict saving - self.cached_central_plan: SavePlan | None = None - self.cached_local_plan: SavePlan | None = None - self.cached_global_metadata: Metadata | None = None - self.validated_cache_reuse: bool = False - self.use_cached_ckpt_structure: bool = use_cached_ckpt_structure + self._cached_central_plan: SavePlan | None = None + self._cached_local_plan: SavePlan | None = None + self._cached_global_metadata: Metadata | None = None + self._validated_cache_reuse: bool = False + self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure @override def can_handle_sharded_objects(self) -> bool: @@ -173,40 +173,33 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union # If so, reuse `cached_central_plan` and `cached_global_metadata` # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` # (return None) so `self.cached_global_metadata` is reused - args_cached_plans = None - loaded_all_plans = None - if self.use_cached_ckpt_structure: - if self.cached_global_metadata: - loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None) - - args_cached_plans = ( - self.cached_central_plan, - self.cached_local_plan, - self.validated_cache_reuse, + cached_structure_args = None + if self._use_cached_ckpt_structure: + cached_structure_args = ( + self._cached_central_plan, + self._cached_local_plan, + self._validated_cache_reuse, ) ( - (plan, write_buckets, global_metadata), - self.cached_central_plan, - self.cached_local_plan, - self.validated_cache_reuse, + write_buckets, + global_metadata, + self._cached_central_plan, + self._cached_local_plan, + self._validated_cache_reuse, ) = statedictsaver.generate_plan( checkpoint_id=checkpoint_id, state_dict=pyt_state_dict, storage_writer=self._storage_writer, planner=planner, world_dist_wrapper=world_dist_wrapper, - cached_ckpt_structure=args_cached_plans, - loaded_all_plans=loaded_all_plans, + cached_ckpt_structure=cached_structure_args, ) - if self.validated_cache_reuse: - if global_metadata is None and self.cached_global_metadata: - global_metadata = self.cached_global_metadata - - # If we have a valid global_metadata (either new or reused), cache it for next time - if global_metadata is not None: - self.cached_global_metadata = global_metadata + if global_metadata is None: + global_metadata = self._cached_global_metadata + else: + self._cached_global_metadata = global_metadata # 5. Stage to CPU. staged_write_buckets = self._storage_writer.stage_write_data_buckets( diff --git a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py index b05f67e..2f40c55 100644 --- a/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py +++ b/src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py @@ -46,9 +46,9 @@ def generate_plan( planner: torchdistsaver.SavePlanner, world_dist_wrapper: _DistWrapper, cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None, - loaded_all_plans: list[SavePlan] | None = None, ) -> tuple[ - tuple[SavePlan, list[ObjectWriteBucket], torchdistsaver.Metadata | None], + list[ObjectWriteBucket], + torchdistsaver.Metadata | None, SavePlan, SavePlan, bool, @@ -69,12 +69,12 @@ def generate_plan( world_dist_wrapper: The distributed wrapper for world (all ranks) communication. Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`. cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse). - loaded_all_plans: List of all local plans from the previous checkpoint (for validation). Returns: A tuple containing: - - (final_local_plan, write_buckets, global_metadata) - - final_local_plan (for caching) + - write_buckets + - global_metadata + - central_plan (for caching) - local_plan (for caching) - validated_cache_reuse (bool) """ @@ -130,7 +130,8 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]: write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner) return ( - (final_local_plan, write_buckets, global_metadata), + write_buckets, + global_metadata, central_plan, local_plan, cached_central_plan == central_plan, diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index f3b17f0..72b0fbe 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -169,7 +169,8 @@ def test_async_save_initialization_calls_success( _, ) = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()), + dummy_write_buckets, + mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), False, @@ -212,7 +213,8 @@ def test_async_save_reinitializes_storage_writer_with_thread_count( _, ) = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()), + dummy_write_buckets, + mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), False, @@ -258,7 +260,8 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s ) = async_save_setup mock_planner = MockMCoreSavePlanner.return_value mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), + mocker.MagicMock(), + mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), False, @@ -271,7 +274,6 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s "planner", "world_dist_wrapper", "cached_ckpt_structure", - "loaded_all_plans", } # When @@ -293,7 +295,6 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s assert "world_dist_wrapper" in kwargs assert kwargs["world_dist_wrapper"].use_dist is False assert "cached_ckpt_structure" in kwargs - assert "loaded_all_plans" in kwargs assert "cached_global_metadata" not in kwargs def test_generate_plan_failure(self, mocker, async_save_setup): @@ -318,7 +319,8 @@ def test_async_save_async_fn_call_success( mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (dummy_save_plan, dummy_write_buckets, dummy_metadata), + dummy_write_buckets, + dummy_metadata, mocker.MagicMock(), mocker.MagicMock(), False, @@ -358,7 +360,8 @@ def test_async_save_async_fn_failure(self, mocker, async_save_setup, checkpoint_ mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), + mocker.MagicMock(), + mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), False, @@ -389,7 +392,8 @@ def test_async_save_finalize_fns_calls( mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (dummy_save_plan, dummy_write_buckets, dummy_metadata), + dummy_write_buckets, + dummy_metadata, mocker.MagicMock(), mocker.MagicMock(), False, @@ -453,7 +457,8 @@ def test_finalize_fns_failure( mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup mock_statedictsaver.generate_plan.return_value = ( - (dummy_save_plan, mocker.MagicMock(), dummy_metadata), + mocker.MagicMock(), + dummy_metadata, mocker.MagicMock(), mocker.MagicMock(), False, @@ -495,7 +500,8 @@ def test_async_save_rank_determination( # Mock dependencies to ensure success path mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()), + mocker.MagicMock(), + mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), False, @@ -518,7 +524,8 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) # --- Call 1: No cache --- # Given mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), [], mocker.MagicMock()), + [], + mocker.MagicMock(), cached_plan, # cached_central_plan returned mocker.MagicMock(), False, @@ -528,29 +535,31 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) strategy.async_save(sharded_state_dict, checkpoint_id.data) # Then - assert strategy.cached_central_plan == cached_plan - assert strategy.validated_cache_reuse is False + assert strategy._cached_central_plan == cached_plan + assert strategy._validated_cache_reuse is False # --- Call 2: Cache validation success --- # Given mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), [], cached_metadata), + [], + cached_metadata, cached_plan, mocker.MagicMock(), - True, # validated_cache_reuse + True, ) # When strategy.async_save(sharded_state_dict, checkpoint_id.data) # Then - assert strategy.validated_cache_reuse is True - assert strategy.cached_global_metadata == cached_metadata + assert strategy._validated_cache_reuse is True + assert strategy._cached_global_metadata == cached_metadata # --- Call 3: Reuse cache --- # Given mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), [], None), # Returns None for metadata + [], + None, # Returns None for metadata cached_plan, mocker.MagicMock(), True, @@ -564,7 +573,7 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer) _, kwargs = mock_statedictsaver.generate_plan.call_args assert "cached_global_metadata" not in kwargs # And cached_global_metadata in strategy should still be the same - assert strategy.cached_global_metadata == cached_metadata + assert strategy._cached_global_metadata == cached_metadata def test_async_save_caching_disabled_by_default(self, mocker, async_save_setup, storage_writer): """Tests that caching is disabled by default.""" @@ -576,12 +585,12 @@ def test_async_save_caching_disabled_by_default(self, mocker, async_save_setup, # Call: Returns a plan that could be cached mock_statedictsaver.generate_plan.return_value = ( - (mocker.MagicMock(), [], mocker.MagicMock()), + [], + None, cached_plan, mocker.MagicMock(), False, ) - # When strategy.async_save(sharded_state_dict, checkpoint_id.data) @@ -596,4 +605,4 @@ def test_async_save_caching_disabled_by_default(self, mocker, async_save_setup, _, kwargs = mock_statedictsaver.generate_plan.call_args assert kwargs["cached_ckpt_structure"] is None - assert strategy.use_cached_ckpt_structure is False + assert strategy._use_cached_ckpt_structure is False diff --git a/tests/adapter/pytorch/test_custom_state_dict_saver.py b/tests/adapter/pytorch/test_custom_state_dict_saver.py index 27a16c9..42d367b 100644 --- a/tests/adapter/pytorch/test_custom_state_dict_saver.py +++ b/tests/adapter/pytorch/test_custom_state_dict_saver.py @@ -101,7 +101,8 @@ def test_generate_plan_calls_dependencies_correctly( # When ( - (actual_plan, actual_write_buckets, actual_metadata), + actual_write_buckets, + actual_metadata, _, _, actual_reused, @@ -120,9 +121,6 @@ def test_generate_plan_calls_dependencies_correctly( mock_storage_writer.prepare_write_data_buckets.assert_called_once_with( checkpoint_id, local_plan, mock_save_planner ) - assert actual_plan == local_plan - assert actual_write_buckets == expected_write_buckets - assert actual_plan == local_plan assert actual_write_buckets == expected_write_buckets assert actual_metadata == expected_global_metadata assert actual_reused is False @@ -141,7 +139,8 @@ def test_generate_plan_reuses_cache(self, mocker, mock_storage_writer, mock_save # When ( - (actual_plan, _, actual_metadata), + _, + actual_metadata, _, _, actual_reused, @@ -159,9 +158,9 @@ def test_generate_plan_reuses_cache(self, mocker, mock_storage_writer, mock_save # But we can't easily assert on dist_wrapper methods as they are not mocks here unless we mock them # However, we can check if they are NOT called by checking side effects if we had mocked them # For now, checking return values is good enough proxy - assert actual_plan == cached_plan assert actual_metadata is None assert actual_reused is True + mock_save_planner.create_local_plan.assert_not_called() def test_generate_plan_validates_cache_success(self, mocker, mock_storage_writer, mock_save_planner, dist_wrapper): """Tests that generate_plan validates cache successfully.""" @@ -183,6 +182,7 @@ def test_generate_plan_validates_cache_success(self, mocker, mock_storage_writer _, _, _, + _, actual_reused, ) = custom_state_dict_saver.generate_plan( CheckpointContainerId("/test"), @@ -222,7 +222,8 @@ def test_generate_plan_reduce_scatters_local_plan( # When ( - (returned_local_plan, _, _), + _, + _, _, _, _, @@ -235,7 +236,6 @@ def test_generate_plan_reduce_scatters_local_plan( # We can't directly assert the call arguments for local_step and global_step as they are inner functions. # However, we can assert that reduce_scatter was called with 'plan' as the tag. assert mock_reduce_scatter.call_args[0][0] == "plan" - assert returned_local_plan == expected_returned_plan def test_generate_plan_broadcasts_global_metadata( self, mock_storage_writer, mock_save_planner, dist_wrapper, mocker @@ -259,7 +259,8 @@ def test_generate_plan_broadcasts_global_metadata( # When ( - (_, _, returned_metadata), + _, + returned_metadata, _, _, _, @@ -291,7 +292,8 @@ def test_generate_plan_returns_write_buckets(self, mock_storage_writer, mock_sav # When ( - (_, returned_write_buckets, _), + returned_write_buckets, + _, _, _, _, @@ -324,7 +326,7 @@ def test_generate_plan_signature_compatibility(self, mock_storage_writer, mock_s # Then assert isinstance(result, tuple) - assert len(result) == 4 + assert len(result) == 5 class TestWriteData: