feat: Cache checkpoint saving plan in each local rank. #60

Leahlijuan wants to merge 3 commits into main from
Conversation
Code Review
This pull request introduces a plan cache to the checkpoint saving process, aiming to improve efficiency by reusing previously computed save plans. The changes include adding a PlanCache type, functions for computing a plan structure hash, rehydrating plans with current tensors, and sanitizing plans for caching. Additionally, the generate_plan function is modified to utilize the cache, and tests are added to verify the caching mechanism. The goal is to reduce the overhead of global planning by skipping it when the model architecture or state hasn't changed.
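For illustration, a structure hash like the one this summary describes might be computed over only the structural fields of each write item, so two plans whose tensor values differ but whose shapes and dtypes match hash identically. `WriteItemStub` and `plan_structure_hash` below are hypothetical stand-ins, not the PR's actual types:

```python
import hashlib
from dataclasses import dataclass
from typing import Iterable, Tuple

# Hypothetical stand-in for a checkpoint write item: only structural
# fields (fqn, shape, dtype) participate in the hash, never tensor values.
@dataclass(frozen=True)
class WriteItemStub:
    fqn: str
    shape: Tuple[int, ...]
    dtype: str

def plan_structure_hash(items: Iterable[WriteItemStub]) -> str:
    """Deterministic hash over the structural fields of a save plan."""
    h = hashlib.sha256()
    # Sort by fully-qualified name so item ordering cannot change the hash.
    for item in sorted(items, key=lambda i: i.fqn):
        h.update(f"{item.fqn}|{item.shape}|{item.dtype};".encode())
    return h.hexdigest()
```

With this scheme, a plan is reusable exactly when its hash matches the cached one; any change in tensor names, shapes, or dtypes invalidates the cache.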
Force-pushed from d41d2a2 to 001a9aa
/gemini summary

/gemini review
Summary of Changes

This pull request enhances the distributed checkpoint saving process by caching checkpoint saving plans in each local rank. This optimization bypasses expensive collective communication operations, making checkpointing more efficient when only the values of tensors change, not their underlying structure.

Highlights
Code Review
This pull request introduces a caching mechanism for the checkpoint saving plan to optimize performance by avoiding redundant distributed communication. The changes are implemented in MLFlashpointMegatronAsyncSaveStrategy and custom_state_dict_saver, with corresponding updates to tests. The overall logic for caching and validation appears correct. My review includes a few suggestions for improving test code clarity and cleanliness by adhering to the repository's style guide, removing duplicated assertions, and cleaning up commented-out code.
Force-pushed from 001a9aa to 5706546
Python Code Coverage Summary
Minimum allowed line rate is

C++ Code Coverage Summary
Minimum allowed line rate is
g-husam left a comment
Let's make the cache variables and logic/flow clearer. One approach is to decide whether we will use the cached data before even calling generate_plan, skipping the call altogether when the cache applies. The only downside is that it might make metrics less clear, since that function call won't exist for later checkpoints, but that may not matter much.

An alternative is to check at the beginning of the function whether we want to return the cached data as-is, and if so, do that; otherwise fall through to the regular logic. This is a little awkward in that, in the cached scenario, we just return what we were given, but it keeps the function calls consistent for every checkpoint.
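The second alternative the comment describes could be sketched roughly as below. `PlanCacheSketch`, `compute_plan`, and the `(plan, was_cached)` return shape are illustrative names under my own assumptions, not code from the PR:

```python
from typing import Any, Callable, Optional, Tuple

class PlanCacheSketch:
    """Illustrative cache wrapper: return the cached plan up front when the
    structure hash matches, otherwise fall through to regular planning."""

    def __init__(self) -> None:
        self._cached_plan: Optional[Any] = None
        self._cached_hash: Optional[str] = None

    def generate_plan(
        self, structure_hash: str, compute_plan: Callable[[], Any]
    ) -> Tuple[Any, bool]:
        # Early check at the top of the function: if the structure is
        # unchanged, return the cached plan as-is, so every checkpoint
        # takes the same call path.
        if self._cached_plan is not None and structure_hash == self._cached_hash:
            return self._cached_plan, True
        # Regular logic: compute, then refresh the cache.
        plan = compute_plan()
        self._cached_plan = plan
        self._cached_hash = structure_hash
        return plan, False
```

The boolean in the return value makes it easy to assert in tests whether the expensive path ran, which may also help with the metrics concern above.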
```python
self.cached_local_plan: SavePlan | None = None
self.cached_global_metadata: Metadata | None = None
self.validated_cache_reuse: bool = False
self.use_cached_ckpt_structure: bool = use_cached_ckpt_structure
```
These properties can all be "private", e.g. self._cached_local_plan.
```python
    loaded_all_plans=loaded_all_plans,
)

if self.validated_cache_reuse:
```
Do we need this condition? Why not just always do this "validation"? Though it's not clear how this is a validation at all; it just sets global_metadata based on some conditions. I'm not sure what this condition is meant for.
```python
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
loaded_all_plans = None
```
nit: these Megatron variable names are not clear to me; let's rename them for clarity. Is this just caching all local plans from all ranks? If so, something like all_rank_local_plans_cached would be clearer.
```python
    planner=planner,
    world_dist_wrapper=world_dist_wrapper,
    cached_ckpt_structure=args_cached_plans,
    loaded_all_plans=loaded_all_plans,
```
nit: same here, can rename the parameter for clarity
```python
args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure:
    if self.cached_global_metadata:
```
Will this metadata also include the write items and storage md that we store after writing? We should make sure that's excluded and still determined after the write, so that when we add that data into the metadata, we're not modifying the cached metadata.
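One way to guarantee the cached object is never mutated by write results is to deep-copy it before attaching post-write storage data. A minimal sketch with hypothetical names (`MetadataStub` and `attach_storage_data` are illustrative, not the real `Metadata` API):

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Hypothetical metadata container: structural fields are cacheable,
# but storage_data is produced by each write and must not leak
# into the cached copy.
@dataclass
class MetadataStub:
    all_local_plans: List[Any]
    storage_data: Dict[str, Any] = field(default_factory=dict)

def attach_storage_data(
    cached_md: MetadataStub, storage_data: Dict[str, Any]
) -> MetadataStub:
    # Deep-copy before attaching write results, so the cached object
    # stays pristine for the next checkpoint.
    md = copy.deepcopy(cached_md)
    md.storage_data = storage_data
    return md
```

The write path would then always operate on the returned copy, while the cache keeps only the structural fields.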
```python
global_metadata = self.cached_global_metadata

# If we have a valid global_metadata (either new or reused), cache it for next time
if global_metadata is not None:
```
This condition and the one above are a little confusing. Adding an else between them would probably help, along with removing any unnecessary conditions. For example, if we only get a non-None global_metadata when it is new/different, this can just be:

```python
if global_metadata is not None:
    # update cached metadata
    self.cached_global_metadata = global_metadata
else:
    # use cached metadata
    global_metadata = self.cached_global_metadata
```

Is the logic above accurate?
```python
A tuple containing:
- (final_local_plan, write_buckets, global_metadata)
- final_local_plan (for caching)
- local_plan (for caching)
```
Why do we need 2 copies of the final local plan? This return structure is pretty odd; can we just return a single flat tuple, or preferably a proper dataclass/object?
```python
storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)

local_plan = planner.create_local_plan()
if not validated_cache_reuse:
```
This boolean doesn't seem to be used for validation, just for caching. Do we need to condition this line at all? Why not just always do it, since we always update it via the storage_writer on the next line anyway?
```python
final_local_plan = planner.finish_plan(updated_local_plan)
central_plan = None
if validated_cache_reuse and cached_central_plan:
```
I would just expect that if we received cached plans as parameters, we validate them if needed and return them as-is. Otherwise, we go through the existing logic.
```python
loaded_all_plans = None
if self.use_cached_ckpt_structure:
    if self.cached_global_metadata:
        loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
```
Can we flatten these conditions into a single if with an and: `if self.use_cached_ckpt_structure and self.cached_global_metadata:`?
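As a sketch, the flattened form of this suggestion, wrapped in a hypothetical helper so it can stand alone (the real code operates on `self` rather than a free function):

```python
from typing import Any, Optional

def load_cached_plans(
    use_cached_ckpt_structure: bool, cached_global_metadata: Optional[Any]
) -> Optional[Any]:
    # Single flattened condition replacing the nested ifs.
    if use_cached_ckpt_structure and cached_global_metadata:
        return getattr(cached_global_metadata, "all_local_plans", None)
    return None
```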
Fixes #58