@@ -512,56 +512,54 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer)
512512 # Given
513513 mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
514514 strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
515- # Enable caching significantly for this test
516- strategy .use_cached_ckpt_structure = True
517-
518515 cached_plan = mocker .MagicMock ()
519516 cached_metadata = mocker .MagicMock ()
520517
521- # First call: No cache
518+ # --- Call 1: No cache ---
519+ # Given
522520 mock_statedictsaver .generate_plan .return_value = (
523521 (mocker .MagicMock (), [], mocker .MagicMock ()),
524522 cached_plan , # cached_central_plan returned
525523 mocker .MagicMock (),
526524 False ,
527525 )
528526
529- # When 1
527+ # When
530528 strategy .async_save (sharded_state_dict , checkpoint_id .data )
531529
532- # Then 1
530+ # Then
533531 assert strategy .cached_central_plan == cached_plan
534532 assert strategy .validated_cache_reuse is False
535533
536- # Second call: Cache validation success
534+ # --- Call 2: Cache validation success ---
535+ # Given
537536 mock_statedictsaver .generate_plan .return_value = (
538537 (mocker .MagicMock (), [], cached_metadata ),
539538 cached_plan ,
540539 mocker .MagicMock (),
541540 True , # validated_cache_reuse
542541 )
543542
544- # When 2
543+ # When
545544 strategy .async_save (sharded_state_dict , checkpoint_id .data )
546545
547- # Then 2
546+ # Then
548547 assert strategy .validated_cache_reuse is True
549548 assert strategy .cached_global_metadata == cached_metadata
550549
551- # Third call: Reuse cache
550+ # --- Call 3: Reuse cache ---
551+ # Given
552552 mock_statedictsaver .generate_plan .return_value = (
553553 (mocker .MagicMock (), [], None ), # Returns None for metadata
554554 cached_plan ,
555555 mocker .MagicMock (),
556556 True ,
557557 )
558558
559- # During third call, async_save should use self.cached_global_metadata
560-
561- # When 3
559+ # When
562560 strategy .async_save (sharded_state_dict , checkpoint_id .data )
563561
564- # Then 3
562+ # Then
565563 # Ensure generate_plan was called without cached_global_metadata
566564 _ , kwargs = mock_statedictsaver .generate_plan .call_args
567565 assert "cached_global_metadata" not in kwargs
0 commit comments