Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/ml_flashpoint/adapter/megatron/save_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
_replace_state_dict_keys_with_sharded_keys,
mcore_to_pyt_state_dict,
)
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner import SavePlan
from torch.distributed.checkpoint.utils import _DistWrapper
from typing_extensions import override

Expand Down Expand Up @@ -83,17 +85,27 @@ def __init__(
storage_writer: MemoryStorageWriter,
backend: str = default_backend_format_name(),
version: int = default_backend_format_version(),
use_cached_ckpt_structure: bool = False,
):
"""
Args:
storage_writer (MemoryStorageWriter): The storage writer to use for saving operations.
backend (str, optional): The name of the backend format. Defaults to "ml_flashpoint", which is recommended.
version (int, optional): The version of the checkpoint format. Defaults to the latest version.
use_cached_ckpt_structure (bool, optional): Whether to reuse the checkpoint structure (plan)
from the previous save. Defaults to False.
"""
super().__init__(backend=backend, version=version)
self._storage_writer: MemoryStorageWriter = storage_writer
self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver

# Cache for state dict saving
self._cached_central_plan: SavePlan | None = None
self._cached_local_plan: SavePlan | None = None
self._cached_global_metadata: Metadata | None = None
self._validated_cache_reuse: bool = False
self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure

@override
def can_handle_sharded_objects(self) -> bool:
# Not currently used, but in case it is, ensure this strategy is used for ShardedObjects as well.
Expand Down Expand Up @@ -157,14 +169,38 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
# we also use Megatron's SavePlanner during saving for compatibility.
planner: MCoreSavePlanner = MCoreSavePlanner(can_run_decentralized_global_plan=False)
world_dist_wrapper = _DistWrapper(group=None, use_dist=not disable_dist, coordinator_rank=0)
plan, write_buckets, global_metadata = statedictsaver.generate_plan(
# Try twice to validate that the generated `central_plan` is identical across iterations.
# If so, reuse `self._cached_central_plan` and `self._cached_global_metadata`.
# From the 3rd iteration onward, `generate_plan` does not regenerate `global_metadata`
# (it returns None), so `self._cached_global_metadata` is reused instead.
cached_structure_args = None
if self._use_cached_ckpt_structure:
cached_structure_args = (
self._cached_central_plan,
self._cached_local_plan,
self._validated_cache_reuse,
)

(
write_buckets,
global_metadata,
self._cached_central_plan,
self._cached_local_plan,
self._validated_cache_reuse,
) = statedictsaver.generate_plan(
checkpoint_id=checkpoint_id,
state_dict=pyt_state_dict,
storage_writer=self._storage_writer,
planner=planner,
world_dist_wrapper=world_dist_wrapper,
cached_ckpt_structure=cached_structure_args,
)

if global_metadata is None:
global_metadata = self._cached_global_metadata
else:
self._cached_global_metadata = global_metadata

# 5. Stage to CPU.
staged_write_buckets = self._storage_writer.stage_write_data_buckets(
checkpoint_id, write_buckets, non_blocking=True
Expand Down
10 changes: 9 additions & 1 deletion src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_cached_ckpt_structure: bool = False,
) -> MLFlashpointAutoResume:
"""Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.

Expand All @@ -62,6 +63,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
Defaults to False.
Returns:
An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`.
"""
Expand Down Expand Up @@ -90,6 +93,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count=write_thread_count,
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
use_optimized_save=use_optimized_save,
use_cached_ckpt_structure=use_cached_ckpt_structure,
)

default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
Expand All @@ -111,6 +115,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
write_thread_count: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_cached_ckpt_structure: bool = False,
):
"""Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.

Expand Down Expand Up @@ -138,6 +143,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
Defaults to False.

Returns:
None. The trainer's checkpoint_io is modified in-place.
Expand Down Expand Up @@ -218,7 +225,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
),
mp_manager=ctx.Manager(),
thread_count=write_thread_count,
)
),
use_cached_ckpt_structure=use_cached_ckpt_structure,
)
load_strategy = MLFlashpointMegatronLoadStrategy(
replication_manager=replication_manager,
Expand Down
62 changes: 47 additions & 15 deletions src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import torch.cuda
from torch import distributed as torchdist
from torch.distributed.checkpoint import Metadata
from torch.distributed.checkpoint import state_dict_saver as torchdistsaver
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.planner import SavePlan
Expand All @@ -46,7 +45,14 @@ def generate_plan(
storage_writer: MemoryStorageWriter,
planner: torchdistsaver.SavePlanner,
world_dist_wrapper: _DistWrapper,
) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]:
cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None,
) -> tuple[
list[ObjectWriteBucket],
torchdistsaver.Metadata | None,
SavePlan,
SavePlan,
bool,
]:
"""Performs the planning phase of checkpointing.

This function is similar to PyTorch's `state_dict_saver.save` but only
Expand All @@ -62,15 +68,28 @@ def generate_plan(
planner: The SavePlanner to use for the save.
world_dist_wrapper: The distributed wrapper for world (all ranks) communication.
Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`.
cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse).

Returns:
A tuple containing the updated local plan, write buckets, and global metadata.
A tuple containing:
- write_buckets
- global_metadata
- central_plan (for caching)
- local_plan (for caching)
- validated_cache_reuse (bool)
"""
cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
if cached_ckpt_structure:
cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure

global_metadata: torchdistsaver.Metadata | None = None

ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group}
local_plan = cached_local_plan

@_dcp_method_logger(**ckpt_kwargs)
def local_step() -> SavePlan:
nonlocal local_plan
storage_meta = storage_writer.storage_meta()
planner.set_up_planner(
state_dict=state_dict,
Expand All @@ -79,7 +98,9 @@ def local_step() -> SavePlan:
)
storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)

local_plan = planner.create_local_plan()
if not validated_cache_reuse:
local_plan = planner.create_local_plan()

local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan

Expand All @@ -91,19 +112,30 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans

with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
_LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...")
updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)

with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
_LOGGER.debug("Executing global_metadata broadcast...")
# TODO(perf): - can broadcast only to local rank 0 to reduce comms
global_metadata = world_dist_wrapper.broadcast_object(global_metadata)

final_local_plan = planner.finish_plan(updated_local_plan)
central_plan = None
if validated_cache_reuse and cached_central_plan:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just expect that if we received cached plans as parameters, we validate if needed and return those as is. Otherwise, we go through the existing logic

Copy link
Collaborator Author

@Leahlijuan Leahlijuan Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always receive a cached plan and a validation flag. If the plan has already been validated, we skip some global coordination steps; if not, we generate a plan, compare it with the cached one, and update the validation flag accordingly.

_LOGGER.debug("Passed cache reusable")
local_step()
central_plan = cached_central_plan
else:
with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
_LOGGER.debug("Executing plan reduce_scatter to get central_plan...")
central_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)

with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
_LOGGER.debug("Executing global_metadata broadcast...")
global_metadata = world_dist_wrapper.broadcast_object(global_metadata)

final_local_plan = planner.finish_plan(central_plan)
write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner)

return final_local_plan, write_buckets, global_metadata
return (
write_buckets,
global_metadata,
central_plan,
local_plan,
cached_central_plan == central_plan,
)


@log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO)
Expand Down
Loading
Loading