Commit 001a9aa

Use Megatron method to cache plan
1 parent 7075ce4 commit 001a9aa
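
In short: this ports the plan-caching optimization from Megatron's `save_state_dict_async_plan` into the MLFlashpoint save path. The save strategy now caches the central `SavePlan`, the local `SavePlan`, and the global `Metadata` across `async_save` calls and hands them back to `generate_plan`; once the central plan has been validated as identical on two consecutive saves, later saves skip `create_local_plan`, the plan `reduce_scatter`, and the `global_metadata` broadcast, reusing the cached metadata instead.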

File tree: 4 files changed, +306 −29 lines

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 41 additions & 1 deletion

@@ -28,6 +28,8 @@
     _replace_state_dict_keys_with_sharded_keys,
     mcore_to_pyt_state_dict,
 )
+from torch.distributed.checkpoint.metadata import Metadata
+from torch.distributed.checkpoint.planner import SavePlan
 from torch.distributed.checkpoint.utils import _DistWrapper
 from typing_extensions import override
 
@@ -94,6 +96,13 @@ def __init__(
         self._storage_writer: MemoryStorageWriter = storage_writer
         self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver
 
+        # Cache for state dict saving
+        self.cached_central_plan: SavePlan | None = None
+        self.cached_local_plan: SavePlan | None = None
+        self.cached_global_metadata: Metadata | None = None
+        self.validated_cache_reuse: bool = False
+        self.use_cached_ckpt_structure: bool = True
+
     @override
     def can_handle_sharded_objects(self) -> bool:
         # Not currently used, but in case it is, ensure this strategy is used for ShardedObjects as well.
@@ -157,14 +166,45 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
         # we also use Megatron's SavePlanner during saving for compatibility.
         planner: MCoreSavePlanner = MCoreSavePlanner(can_run_decentralized_global_plan=False)
         world_dist_wrapper = _DistWrapper(group=None, use_dist=not disable_dist, coordinator_rank=0)
-        plan, write_buckets, global_metadata = statedictsaver.generate_plan(
+        # Try twice to validate that the generated `central_plan` is the same across iterations.
+        # If so, reuse `cached_central_plan` and `cached_global_metadata`.
+        # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
+        # (returns None), so `self.cached_global_metadata` is reused.
+        args_cached_plans = None
+        loaded_all_plans = None
+        if self.use_cached_ckpt_structure:
+            if self.cached_global_metadata:
+                loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
+
+            args_cached_plans = (
+                self.cached_central_plan,
+                self.cached_local_plan,
+                self.validated_cache_reuse,
+            )
+
+        (
+            (plan, write_buckets, global_metadata),
+            self.cached_central_plan,
+            self.cached_local_plan,
+            self.validated_cache_reuse,
+        ) = statedictsaver.generate_plan(
             checkpoint_id=checkpoint_id,
             state_dict=pyt_state_dict,
             storage_writer=self._storage_writer,
             planner=planner,
             world_dist_wrapper=world_dist_wrapper,
+            cached_ckpt_structure=args_cached_plans,
+            loaded_all_plans=loaded_all_plans,
         )
 
+        if self.validated_cache_reuse:
+            if global_metadata is None and self.cached_global_metadata:
+                global_metadata = self.cached_global_metadata
+
+        # If we have a valid global_metadata (either new or reused), cache it for next time.
+        if global_metadata is not None:
+            self.cached_global_metadata = global_metadata
+
         # 5. Stage to CPU.
         staged_write_buckets = self._storage_writer.stage_write_data_buckets(
             checkpoint_id, write_buckets, non_blocking=True
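
For intuition, here is a minimal, runnable sketch of the three-save handshake the strategy implements above. `CacheState` and `save_once` are hypothetical names invented for illustration; the real code routes through `generate_plan`, DCP collectives, and actual `SavePlan`/`Metadata` objects rather than strings and dicts.

from dataclasses import dataclass


@dataclass
class CacheState:
    central_plan: str | None = None
    global_metadata: dict | None = None
    validated: bool = False


def save_once(state: CacheState, fresh_plan: str, fresh_metadata: dict | None) -> dict | None:
    """One checkpoint save; returns the metadata used for this save."""
    if state.validated and state.central_plan is not None:
        # 3rd+ save: skip collective planning and reuse the cached metadata.
        metadata = state.global_metadata
    else:
        # 1st/2nd save: plan via collectives, then record whether the fresh
        # central plan matched the cached one (the "try twice" validation).
        state.validated = state.central_plan == fresh_plan
        state.central_plan = fresh_plan
        metadata = fresh_metadata
    if metadata is not None:
        state.global_metadata = metadata  # cache for the next save
    return metadata


state = CacheState()
save_once(state, "plan-v1", {"all_local_plans": ["rank0"]})  # caches plan, validated=False
save_once(state, "plan-v1", {"all_local_plans": ["rank0"]})  # same plan again, validated=True
m3 = save_once(state, "plan-v1", None)                       # metadata is None, cache reused
assert state.validated and m3 == {"all_local_plans": ["rank0"]}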

src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py

Lines changed: 46 additions & 15 deletions

@@ -22,7 +22,6 @@
 
 import torch.cuda
 from torch import distributed as torchdist
-from torch.distributed.checkpoint import Metadata
 from torch.distributed.checkpoint import state_dict_saver as torchdistsaver
 from torch.distributed.checkpoint.logger import _dcp_method_logger
 from torch.distributed.checkpoint.planner import SavePlan
@@ -46,7 +45,14 @@ def generate_plan(
     storage_writer: MemoryStorageWriter,
     planner: torchdistsaver.SavePlanner,
     world_dist_wrapper: _DistWrapper,
-) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]:
+    cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None,
+    loaded_all_plans: list[SavePlan] | None = None,
+) -> tuple[
+    tuple[SavePlan, list[ObjectWriteBucket], torchdistsaver.Metadata | None],
+    SavePlan,
+    SavePlan,
+    bool,
+]:
     """Performs the planning phase of checkpointing.
 
     This function is similar to PyTorch's `state_dict_saver.save` but only
@@ -62,15 +68,28 @@
         planner: The SavePlanner to use for the save.
         world_dist_wrapper: The distributed wrapper for world (all ranks) communication.
             Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`.
+        cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse).
+        loaded_all_plans: List of all local plans from the previous checkpoint (for validation).
+
     Returns:
-        A tuple containing the updated local plan, write buckets, and global metadata.
+        A tuple containing:
+            - (final_local_plan, write_buckets, global_metadata)
+            - central_plan (for caching)
+            - local_plan (for caching)
+            - validated_cache_reuse (bool)
     """
+    cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
+    if cached_ckpt_structure:
+        cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure
+
     global_metadata: torchdistsaver.Metadata | None = None
 
     ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group}
+    local_plan = cached_local_plan
 
     @_dcp_method_logger(**ckpt_kwargs)
     def local_step() -> SavePlan:
+        nonlocal local_plan
         storage_meta = storage_writer.storage_meta()
         planner.set_up_planner(
             state_dict=state_dict,
@@ -79,7 +98,9 @@ def local_step() -> SavePlan:
         )
         storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)
 
-        local_plan = planner.create_local_plan()
+        if not validated_cache_reuse:
+            local_plan = planner.create_local_plan()
+
         local_plan = storage_writer.prepare_local_plan(local_plan)
         return local_plan
 
@@ -91,19 +112,29 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
         all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
         return all_local_plans
 
-    with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
-        _LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...")
-        updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)
-
-    with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
-        _LOGGER.debug("Executing global_metadata broadcast...")
-        # TODO(perf): - can broadcast only to local rank 0 to reduce comms
-        global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
-
-    final_local_plan = planner.finish_plan(updated_local_plan)
+    central_plan = None
+    if validated_cache_reuse and cached_central_plan:
+        _LOGGER.debug("Passed cache reusable")
+        local_step()
+        central_plan = cached_central_plan
+    else:
+        with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
+            _LOGGER.debug("Executing plan reduce_scatter to get central_plan...")
+            central_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)
+
+        with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
+            _LOGGER.debug("Executing global_metadata broadcast...")
+            global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
+
+    final_local_plan = planner.finish_plan(central_plan)
     write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner)
 
-    return final_local_plan, write_buckets, global_metadata
+    return (
+        (final_local_plan, write_buckets, global_metadata),
+        central_plan,  # cached_central_plan
+        local_plan,  # cached_local_plan
+        cached_central_plan == central_plan,  # validated_cache_reuse
+    )
 
 
 @log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO)
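
A note on the reuse check in `generate_plan`: the returned `validated_cache_reuse` flag is computed as `cached_central_plan == central_plan`. This relies on value equality rather than identity; DCP's `SavePlan` is a dataclass in the PyTorch versions this code appears to target, so `==` compares plan contents field by field. A tiny self-contained illustration with a simplified stand-in type (not the real `SavePlan`):

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class MiniSavePlan:
    # Simplified stand-in mirroring SavePlan's dataclass-style equality.
    items: tuple
    storage_data: Any = None


a = MiniSavePlan(items=("weight", "bias"))
b = MiniSavePlan(items=("weight", "bias"))
assert a == b       # equal contents => the cached plan is reusable
assert a is not b   # distinct objects still compare equal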

tests/adapter/megatron/test_save_strategies.py

Lines changed: 99 additions & 9 deletions

@@ -169,9 +169,10 @@ def test_async_save_initialization_calls_success(
             _,
         ) = async_save_setup
         mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()),
             mocker.MagicMock(),
-            dummy_write_buckets,
             mocker.MagicMock(),
+            False,
         )
 
         mock_memory_storage_writer_cls = mocker.patch(
@@ -211,9 +212,10 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
             _,
         ) = async_save_setup
         mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), dummy_write_buckets, mocker.MagicMock()),
             mocker.MagicMock(),
-            dummy_write_buckets,
             mocker.MagicMock(),
+            False,
         )
 
         # Set a specific thread_count on the original storage_writer
@@ -256,12 +258,21 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
         ) = async_save_setup
         mock_planner = MockMCoreSavePlanner.return_value
         mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()),
             mocker.MagicMock(),
             mocker.MagicMock(),
-            mocker.MagicMock(),
+            False,
         )
 
-        expected_kwarg_keys = {"checkpoint_id", "state_dict", "storage_writer", "planner", "world_dist_wrapper"}
+        expected_kwarg_keys = {
+            "checkpoint_id",
+            "state_dict",
+            "storage_writer",
+            "planner",
+            "world_dist_wrapper",
+            "cached_ckpt_structure",
+            "loaded_all_plans",
+        }
 
         # When
         strategy.async_save(sharded_state_dict, checkpoint_id.data)
@@ -281,6 +292,9 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
         assert kwargs["planner"] is mock_planner
         assert "world_dist_wrapper" in kwargs
         assert kwargs["world_dist_wrapper"].use_dist is False
+        assert "cached_ckpt_structure" in kwargs
+        assert "loaded_all_plans" in kwargs
+        assert "cached_global_metadata" not in kwargs
 
     def test_generate_plan_failure(self, mocker, async_save_setup):
         """Tests that an exception in generate_plan is propagated."""
@@ -303,7 +317,12 @@ def test_async_save_async_fn_call_success(
 
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            (dummy_save_plan, dummy_write_buckets, dummy_metadata),
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
         staged_write_buckets = [
             ObjectWriteBucket(
                 object_id=CheckpointObjectId(f"/test_checkpoint/staged_obj_{i}"),
@@ -339,9 +358,10 @@ def test_async_save_async_fn_failure(self, mocker, async_save_setup, checkpoint_
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
         mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()),
             mocker.MagicMock(),
             mocker.MagicMock(),
-            mocker.MagicMock(),
+            False,
         )
         mock_statedictsaver.write_data.side_effect = Exception("Test Exception")
 
@@ -368,7 +388,12 @@ def test_async_save_finalize_fns_calls(
         finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint")
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, dummy_write_buckets, dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            (dummy_save_plan, dummy_write_buckets, dummy_metadata),
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
 
         mock_memory_storage_writer_cls = mocker.patch(
             "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
@@ -427,7 +452,12 @@ def test_finalize_fns_failure(
         finalize_checkpoint_spy = mocker.spy(checkpoint_saver, "finalize_checkpoint")
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
-        mock_statedictsaver.generate_plan.return_value = (dummy_save_plan, mocker.MagicMock(), dummy_metadata)
+        mock_statedictsaver.generate_plan.return_value = (
+            (dummy_save_plan, mocker.MagicMock(), dummy_metadata),
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            False,
+        )
         mock_statedictsaver.finish_write.side_effect = ValueError("Finish Write Failed")
 
         # When
@@ -465,13 +495,73 @@ def test_async_save_rank_determination(
         # Mock dependencies to ensure success path
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()),
             mocker.MagicMock(),
             mocker.MagicMock(),
-            mocker.MagicMock(),
+            False,
         )
 
         # When
         actual_async_request = strategy.async_save(sharded_state_dict, checkpoint_id.data)
 
         # Then
         assert actual_async_request.async_fn_kwargs["rank"] == expected_rank
+
+    def test_async_save_caching_flow(self, mocker, async_save_setup, storage_writer):
+        """Tests the caching flow across multiple async_save calls."""
+        # Given
+        mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
+        strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
+
+        cached_plan = mocker.MagicMock()
+        cached_metadata = mocker.MagicMock()
+
+        # First call: no cache
+        mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), [], mocker.MagicMock()),
+            cached_plan,  # cached_central_plan returned
+            mocker.MagicMock(),
+            False,
+        )
+
+        # When 1
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then 1
+        assert strategy.cached_central_plan == cached_plan
+        assert strategy.validated_cache_reuse is False
+
+        # Second call: cache validation succeeds
+        mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), [], cached_metadata),
+            cached_plan,
+            mocker.MagicMock(),
+            True,  # validated_cache_reuse
+        )
+
+        # When 2
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then 2
+        assert strategy.validated_cache_reuse is True
+        assert strategy.cached_global_metadata == cached_metadata
+
+        # Third call: reuse cache
+        mock_statedictsaver.generate_plan.return_value = (
+            (mocker.MagicMock(), [], None),  # Returns None for metadata
+            cached_plan,
+            mocker.MagicMock(),
+            True,
+        )
+
+        # During the third call, async_save should use self.cached_global_metadata.
+
+        # When 3
+        strategy.async_save(sharded_state_dict, checkpoint_id.data)
+
+        # Then 3
+        # Ensure generate_plan was called without cached_global_metadata
+        _, kwargs = mock_statedictsaver.generate_plan.call_args
+        assert "cached_global_metadata" not in kwargs
+        # And cached_global_metadata in strategy should still be the same
+        assert strategy.cached_global_metadata == cached_metadata
