-
Notifications
You must be signed in to change notification settings - Fork 5
feat: Cache checkpoint saving plan in each local rank. #60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,6 @@ | |
|
|
||
| import torch.cuda | ||
| from torch import distributed as torchdist | ||
| from torch.distributed.checkpoint import Metadata | ||
| from torch.distributed.checkpoint import state_dict_saver as torchdistsaver | ||
| from torch.distributed.checkpoint.logger import _dcp_method_logger | ||
| from torch.distributed.checkpoint.planner import SavePlan | ||
|
|
@@ -46,7 +45,14 @@ def generate_plan( | |
| storage_writer: MemoryStorageWriter, | ||
| planner: torchdistsaver.SavePlanner, | ||
| world_dist_wrapper: _DistWrapper, | ||
| ) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]: | ||
| cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None, | ||
| ) -> tuple[ | ||
| list[ObjectWriteBucket], | ||
| torchdistsaver.Metadata | None, | ||
| SavePlan, | ||
| SavePlan, | ||
| bool, | ||
| ]: | ||
| """Performs the planning phase of checkpointing. | ||
|
|
||
| This function is similar to PyTorch's `state_dict_saver.save` but only | ||
|
|
@@ -62,15 +68,28 @@ def generate_plan( | |
| planner: The SavePlanner to use for the save. | ||
| world_dist_wrapper: The distributed wrapper for world (all ranks) communication. | ||
| Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`. | ||
| cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse). | ||
|
|
||
| Returns: | ||
| A tuple containing the updated local plan, write buckets, and global metadata. | ||
| A tuple containing: | ||
| - write_buckets | ||
| - global_metadata | ||
| - central_plan (for caching) | ||
| - local_plan (for caching) | ||
| - validated_cache_reuse (bool) | ||
| """ | ||
| cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False) | ||
| if cached_ckpt_structure: | ||
| cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure | ||
|
|
||
| global_metadata: torchdistsaver.Metadata | None = None | ||
|
|
||
| ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group} | ||
| local_plan = cached_local_plan | ||
|
|
||
| @_dcp_method_logger(**ckpt_kwargs) | ||
| def local_step() -> SavePlan: | ||
| nonlocal local_plan | ||
| storage_meta = storage_writer.storage_meta() | ||
| planner.set_up_planner( | ||
| state_dict=state_dict, | ||
|
|
@@ -79,7 +98,9 @@ def local_step() -> SavePlan: | |
| ) | ||
| storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator) | ||
|
|
||
| local_plan = planner.create_local_plan() | ||
| if not validated_cache_reuse: | ||
Leahlijuan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| local_plan = planner.create_local_plan() | ||
|
|
||
| local_plan = storage_writer.prepare_local_plan(local_plan) | ||
| return local_plan | ||
|
|
||
|
|
@@ -91,19 +112,30 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]: | |
| all_local_plans = storage_writer.prepare_global_plan(all_local_plans) | ||
| return all_local_plans | ||
|
|
||
| with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"): | ||
| _LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...") | ||
| updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step) | ||
|
|
||
| with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"): | ||
| _LOGGER.debug("Executing global_metadata broadcast...") | ||
| # TODO(perf): - can broadcast only to local rank 0 to reduce comms | ||
| global_metadata = world_dist_wrapper.broadcast_object(global_metadata) | ||
|
|
||
| final_local_plan = planner.finish_plan(updated_local_plan) | ||
| central_plan = None | ||
| if validated_cache_reuse and cached_central_plan: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would just expect that if we received cached plans as parameters, we validate if needed and return those as is. Otherwise, we go through the existing logic
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we always received a cached plan and a validate param, if it's already validated, we skip some global coordination steps, if not we generate plan and compare it with cache and also update the validate param |
||
| _LOGGER.debug("Passed cache reusable") | ||
| local_step() | ||
| central_plan = cached_central_plan | ||
| else: | ||
| with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"): | ||
| _LOGGER.debug("Executing plan reduce_scatter to get central_plan...") | ||
| central_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step) | ||
|
|
||
| with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"): | ||
| _LOGGER.debug("Executing global_metadata broadcast...") | ||
Leahlijuan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| global_metadata = world_dist_wrapper.broadcast_object(global_metadata) | ||
|
|
||
| final_local_plan = planner.finish_plan(central_plan) | ||
| write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner) | ||
|
|
||
| return final_local_plan, write_buckets, global_metadata | ||
| return ( | ||
| write_buckets, | ||
| global_metadata, | ||
| central_plan, | ||
| local_plan, | ||
| cached_central_plan == central_plan, | ||
| ) | ||
|
|
||
|
|
||
| @log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.