Skip to content

Commit 446f9ce

Browse files
committed
resolve comments
1 parent 5706546 commit 446f9ce

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,9 +131,9 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
131131

132132
return (
133133
(final_local_plan, write_buckets, global_metadata),
134-
central_plan, # cached_central_plan
135-
local_plan, # cached_local_plan
136-
cached_central_plan == central_plan, # validated_cache_reuse
134+
central_plan,
135+
local_plan,
136+
cached_central_plan == central_plan,
137137
)
138138

139139

tests/adapter/megatron/test_save_strategies.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -512,56 +512,54 @@ def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer)
512512
# Given
513513
mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
514514
strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
515-
# Enable caching significantly for this test
516-
strategy.use_cached_ckpt_structure = True
517-
518515
cached_plan = mocker.MagicMock()
519516
cached_metadata = mocker.MagicMock()
520517

521-
# First call: No cache
518+
# --- Call 1: No cache ---
519+
# Given
522520
mock_statedictsaver.generate_plan.return_value = (
523521
(mocker.MagicMock(), [], mocker.MagicMock()),
524522
cached_plan, # cached_central_plan returned
525523
mocker.MagicMock(),
526524
False,
527525
)
528526

529-
# When 1
527+
# When
530528
strategy.async_save(sharded_state_dict, checkpoint_id.data)
531529

532-
# Then 1
530+
# Then
533531
assert strategy.cached_central_plan == cached_plan
534532
assert strategy.validated_cache_reuse is False
535533

536-
# Second call: Cache validation success
534+
# --- Call 2: Cache validation success ---
535+
# Given
537536
mock_statedictsaver.generate_plan.return_value = (
538537
(mocker.MagicMock(), [], cached_metadata),
539538
cached_plan,
540539
mocker.MagicMock(),
541540
True, # validated_cache_reuse
542541
)
543542

544-
# When 2
543+
# When
545544
strategy.async_save(sharded_state_dict, checkpoint_id.data)
546545

547-
# Then 2
546+
# Then
548547
assert strategy.validated_cache_reuse is True
549548
assert strategy.cached_global_metadata == cached_metadata
550549

551-
# Third call: Reuse cache
550+
# --- Call 3: Reuse cache ---
551+
# Given
552552
mock_statedictsaver.generate_plan.return_value = (
553553
(mocker.MagicMock(), [], None), # Returns None for metadata
554554
cached_plan,
555555
mocker.MagicMock(),
556556
True,
557557
)
558558

559-
# During third call, async_save should use self.cached_global_metadata
560-
561-
# When 3
559+
# When
562560
strategy.async_save(sharded_state_dict, checkpoint_id.data)
563561

564-
# Then 3
562+
# Then
565563
# Ensure generate_plan was called without cached_global_metadata
566564
_, kwargs = mock_statedictsaver.generate_plan.call_args
567565
assert "cached_global_metadata" not in kwargs

0 commit comments

Comments
 (0)