@@ -169,9 +169,10 @@ def test_async_save_initialization_calls_success(
169169 _ ,
170170 ) = async_save_setup
171171 mock_statedictsaver .generate_plan .return_value = (
172+ (mocker .MagicMock (), dummy_write_buckets , mocker .MagicMock ()),
172173 mocker .MagicMock (),
173- dummy_write_buckets ,
174174 mocker .MagicMock (),
175+ False ,
175176 )
176177
177178 mock_memory_storage_writer_cls = mocker .patch (
@@ -211,9 +212,10 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
211212 _ ,
212213 ) = async_save_setup
213214 mock_statedictsaver .generate_plan .return_value = (
215+ (mocker .MagicMock (), dummy_write_buckets , mocker .MagicMock ()),
214216 mocker .MagicMock (),
215- dummy_write_buckets ,
216217 mocker .MagicMock (),
218+ False ,
217219 )
218220
219221 # Set a specific thread_count on the original storage_writer
@@ -256,12 +258,21 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
256258 ) = async_save_setup
257259 mock_planner = MockMCoreSavePlanner .return_value
258260 mock_statedictsaver .generate_plan .return_value = (
261+ (mocker .MagicMock (), mocker .MagicMock (), mocker .MagicMock ()),
259262 mocker .MagicMock (),
260263 mocker .MagicMock (),
261- mocker . MagicMock () ,
264+ False ,
262265 )
263266
264- expected_kwarg_keys = {"checkpoint_id" , "state_dict" , "storage_writer" , "planner" , "world_dist_wrapper" }
267+ expected_kwarg_keys = {
268+ "checkpoint_id" ,
269+ "state_dict" ,
270+ "storage_writer" ,
271+ "planner" ,
272+ "world_dist_wrapper" ,
273+ "cached_ckpt_structure" ,
274+ "loaded_all_plans" ,
275+ }
265276
266277 # When
267278 strategy .async_save (sharded_state_dict , checkpoint_id .data )
@@ -281,6 +292,9 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
281292 assert kwargs ["planner" ] is mock_planner
282293 assert "world_dist_wrapper" in kwargs
283294 assert kwargs ["world_dist_wrapper" ].use_dist is False
295+ assert "cached_ckpt_structure" in kwargs
296+ assert "loaded_all_plans" in kwargs
297+ assert "cached_global_metadata" not in kwargs
284298
285299 def test_generate_plan_failure (self , mocker , async_save_setup ):
286300 """Tests that an exception in generate_plan is propagated."""
@@ -303,7 +317,12 @@ def test_async_save_async_fn_call_success(
303317
304318 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
305319 strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
306- mock_statedictsaver .generate_plan .return_value = (dummy_save_plan , dummy_write_buckets , dummy_metadata )
320+ mock_statedictsaver .generate_plan .return_value = (
321+ (dummy_save_plan , dummy_write_buckets , dummy_metadata ),
322+ mocker .MagicMock (),
323+ mocker .MagicMock (),
324+ False ,
325+ )
307326 staged_write_buckets = [
308327 ObjectWriteBucket (
309328 object_id = CheckpointObjectId (f"/test_checkpoint/staged_obj_{ i } " ),
@@ -339,9 +358,10 @@ def test_async_save_async_fn_failure(self, mocker, async_save_setup, checkpoint_
339358 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
340359 strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
341360 mock_statedictsaver .generate_plan .return_value = (
361+ (mocker .MagicMock (), mocker .MagicMock (), mocker .MagicMock ()),
342362 mocker .MagicMock (),
343363 mocker .MagicMock (),
344- mocker . MagicMock () ,
364+ False ,
345365 )
346366 mock_statedictsaver .write_data .side_effect = Exception ("Test Exception" )
347367
@@ -368,7 +388,12 @@ def test_async_save_finalize_fns_calls(
368388 finalize_checkpoint_spy = mocker .spy (checkpoint_saver , "finalize_checkpoint" )
369389 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
370390 strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
371- mock_statedictsaver .generate_plan .return_value = (dummy_save_plan , dummy_write_buckets , dummy_metadata )
391+ mock_statedictsaver .generate_plan .return_value = (
392+ (dummy_save_plan , dummy_write_buckets , dummy_metadata ),
393+ mocker .MagicMock (),
394+ mocker .MagicMock (),
395+ False ,
396+ )
372397
373398 mock_memory_storage_writer_cls = mocker .patch (
374399 "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
@@ -427,7 +452,12 @@ def test_finalize_fns_failure(
427452 finalize_checkpoint_spy = mocker .spy (checkpoint_saver , "finalize_checkpoint" )
428453 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
429454 strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
430- mock_statedictsaver .generate_plan .return_value = (dummy_save_plan , mocker .MagicMock (), dummy_metadata )
455+ mock_statedictsaver .generate_plan .return_value = (
456+ (dummy_save_plan , mocker .MagicMock (), dummy_metadata ),
457+ mocker .MagicMock (),
458+ mocker .MagicMock (),
459+ False ,
460+ )
431461 mock_statedictsaver .finish_write .side_effect = ValueError ("Finish Write Failed" )
432462
433463 # When
@@ -465,13 +495,73 @@ def test_async_save_rank_determination(
465495 # Mock dependencies to ensure success path
466496 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
467497 mock_statedictsaver .generate_plan .return_value = (
498+ (mocker .MagicMock (), mocker .MagicMock (), mocker .MagicMock ()),
468499 mocker .MagicMock (),
469500 mocker .MagicMock (),
470- mocker . MagicMock () ,
501+ False ,
471502 )
472503
473504 # When
474505 actual_async_request = strategy .async_save (sharded_state_dict , checkpoint_id .data )
475506
476507 # Then
477508 assert actual_async_request .async_fn_kwargs ["rank" ] == expected_rank
509+
510+ def test_async_save_caching_flow (self , mocker , async_save_setup , storage_writer ):
511+ """Tests the caching flow across multiple async_save calls."""
512+ # Given
513+ mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
514+ strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
515+
516+ cached_plan = mocker .MagicMock ()
517+ cached_metadata = mocker .MagicMock ()
518+
519+ # First call: No cache
520+ mock_statedictsaver .generate_plan .return_value = (
521+ (mocker .MagicMock (), [], mocker .MagicMock ()),
522+ cached_plan , # cached_central_plan returned
523+ mocker .MagicMock (),
524+ False ,
525+ )
526+
527+ # When 1
528+ strategy .async_save (sharded_state_dict , checkpoint_id .data )
529+
530+ # Then 1
531+ assert strategy .cached_central_plan == cached_plan
532+ assert strategy .validated_cache_reuse is False
533+
534+ # Second call: Cache validation success
535+ mock_statedictsaver .generate_plan .return_value = (
536+ (mocker .MagicMock (), [], cached_metadata ),
537+ cached_plan ,
538+ mocker .MagicMock (),
539+ True , # validated_cache_reuse
540+ )
541+
542+ # When 2
543+ strategy .async_save (sharded_state_dict , checkpoint_id .data )
544+
545+ # Then 2
546+ assert strategy .validated_cache_reuse is True
547+ assert strategy .cached_global_metadata == cached_metadata
548+
549+ # Third call: Reuse cache
550+ mock_statedictsaver .generate_plan .return_value = (
551+ (mocker .MagicMock (), [], None ), # Returns None for metadata
552+ cached_plan ,
553+ mocker .MagicMock (),
554+ True ,
555+ )
556+
557+ # During third call, async_save should use self.cached_global_metadata
558+
559+ # When 3
560+ strategy .async_save (sharded_state_dict , checkpoint_id .data )
561+
562+ # Then 3
563+ # Ensure generate_plan was called without cached_global_metadata
564+ _ , kwargs = mock_statedictsaver .generate_plan .call_args
565+ assert "cached_global_metadata" not in kwargs
566+ # And cached_global_metadata in strategy should still be the same
567+ assert strategy .cached_global_metadata == cached_metadata
0 commit comments