feat: Cache checkpoint saving plan in each local rank. #60

Open

Leahlijuan wants to merge 3 commits into main from feat/cache-plan
Conversation

@Leahlijuan
Collaborator

@Leahlijuan Leahlijuan commented Feb 24, 2026

Fixes #58

  • Tests pass
  • Appropriate changes to documentation are included in the PR

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a plan cache to the checkpoint saving process, aiming to improve efficiency by reusing previously computed save plans. The changes include adding a PlanCache type, functions for computing a plan structure hash, rehydrating plans with current tensors, and sanitizing plans for caching. Additionally, the generate_plan function is modified to utilize the cache, and tests are added to verify the caching mechanism. The goal is to reduce the overhead of global planning by skipping it when the model architecture or state hasn't changed.
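As a rough illustration of the mechanism described above, a plan cache keyed on a structure hash might look like the following sketch. `Item`, `Plan`, `PlanCache`, and all method names here are simplified stand-ins for illustration, not the PR's actual types or hashing code.

```python
import hashlib
from dataclasses import dataclass


@dataclass(frozen=True)
class Item:
    key: str      # fully qualified tensor name (structural)
    shape: tuple  # tensor shape (structural)
    dtype: str    # dtype name (structural)


@dataclass
class Plan:
    items: list


def plan_structure_hash(plan: Plan) -> str:
    """Hash only structural fields; tensor values never affect the hash."""
    h = hashlib.sha256()
    for it in plan.items:
        h.update(repr((it.key, it.shape, it.dtype)).encode())
    return h.hexdigest()


class PlanCache:
    """Reuse a previously finalized plan when the plan structure is unchanged."""

    def __init__(self) -> None:
        self._hash = None
        self._cached_plan = None

    def lookup(self, local_plan: Plan):
        """Return the cached plan on a structural match, else None."""
        if plan_structure_hash(local_plan) == self._hash:
            return self._cached_plan  # hit: global planning can be skipped
        return None

    def store(self, local_plan: Plan, finalized_plan: Plan) -> None:
        """Cache the finalized plan under the local plan's structure hash."""
        self._hash = plan_structure_hash(local_plan)
        self._cached_plan = finalized_plan
```

On a cache hit the rank can skip the collective planning step entirely; any structural change (a new tensor, a changed shape or dtype) changes the hash and forces a fresh global plan.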

@Leahlijuan Leahlijuan marked this pull request as draft February 24, 2026 21:54
@Leahlijuan
Collaborator Author

/gemini summary

@Leahlijuan
Collaborator Author

/gemini review

@google google deleted a comment from gemini-code-assist bot Feb 25, 2026
@gemini-code-assist
Contributor

Summary of Changes

This pull request enhances the distributed checkpoint saving process by caching checkpoint saving plans in each local rank. This optimization bypasses the need for expensive collective communication operations, leading to more efficient checkpointing when only the values of tensors change, not their underlying structure.

Highlights

  • Checkpoint Plan Caching: Introduced a caching mechanism for checkpoint saving plans to optimize distributed checkpointing, avoiding redundant global planning.
  • Performance Optimization: By caching the plan structure, subsequent checkpoint saves can reuse the global plan, significantly reducing communication overhead and improving checkpointing speed.
  • Cache Validation and Reuse: Implemented logic to validate and reuse cached plans across multiple async_save calls, ensuring consistency and efficiency.


Changelog
  • src/ml_flashpoint/adapter/megatron/save_strategies.py
    • Initialized cache-related attributes for state dict saving, including cached_central_plan, cached_local_plan, cached_global_metadata, validated_cache_reuse, and use_cached_ckpt_structure.
    • Modified async_save to incorporate logic for reusing cached checkpoint structures, including validation and updating of cached plans and metadata.
  • src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py
    • Modified generate_plan to incorporate cache lookup and storage, conditionally performing global planning based on cache hits or misses.
    • Added parameters cached_ckpt_structure and loaded_all_plans to generate_plan to support caching logic.
  • tests/adapter/megatron/test_save_strategies.py
    • Updated tests to reflect changes in async_save, including the return values of generate_plan and the arguments passed to it.
    • Added a new test case, test_async_save_caching_flow, to verify the caching behavior across multiple async_save calls.
  • tests/adapter/pytorch/test_custom_state_dict_saver.py
    • Added tests for cache reuse and validation in generate_plan, ensuring that cached plans are correctly utilized when available.
Activity
  • Leahlijuan opened this pull request to address issue Cache the checkpoint saving plan #58.
  • The author has confirmed that all tests pass.
  • The author has confirmed that appropriate documentation changes are included.
  • Gemini-code-assist[bot] suggested implementing a cache eviction policy to manage memory usage, which Leahlijuan addressed by adding cache size limiting logic.
  • Gemini-code-assist[bot] suggested clarifying that tensor data is set to None to prevent memory leaks.
  • Gemini-code-assist[bot] suggested including the checkpoint ID in the log message for better debugging.
  • Gemini-code-assist[bot] noted an opportunity for performance improvement by broadcasting only to local rank 0.
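The eviction suggestion above could take the shape of a small bounded cache that drops its oldest entry on overflow. This is an illustrative sketch under that assumption; `BoundedPlanCache` and its size limit are hypothetical, not the PR's actual cache-limiting code.

```python
from collections import OrderedDict


class BoundedPlanCache:
    """Keep at most max_entries cached plans, evicting the oldest first."""

    def __init__(self, max_entries: int = 4) -> None:
        self._entries: OrderedDict = OrderedDict()
        self._max = max_entries

    def put(self, structure_hash: str, plan) -> None:
        self._entries[structure_hash] = plan
        self._entries.move_to_end(structure_hash)  # treat re-insert as newest
        while len(self._entries) > self._max:
            self._entries.popitem(last=False)  # evict the oldest entry

    def get(self, structure_hash: str):
        return self._entries.get(structure_hash)
```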

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a caching mechanism for the checkpoint saving plan to optimize performance by avoiding redundant distributed communication. The changes are implemented in MLFlashpointMegatronAsyncSaveStrategy and custom_state_dict_saver, with corresponding updates to tests. The overall logic for caching and validation appears correct. My review includes a few suggestions for improving test code clarity and cleanliness by adhering to the repository's style guide, removing duplicated assertions, and cleaning up commented-out code.

@Leahlijuan Leahlijuan requested review from g-husam and kkkapu February 26, 2026 20:50
@github-actions

Python Code Coverage Summary

Package                                      Line Rate          Branch Rate
src.ml_flashpoint                            100%               100%
src.ml_flashpoint.adapter                    100%               100%
src.ml_flashpoint.adapter.megatron           96%                82%
src.ml_flashpoint.adapter.nemo               98%                94%
src.ml_flashpoint.adapter.pytorch            99%                90%
src.ml_flashpoint.checkpoint_object_manager  92%                91%
src.ml_flashpoint.core                       96%                92%
src.ml_flashpoint.replication                81%                81%
Summary                                      95% (2083 / 2198)  90% (485 / 536)

Minimum allowed line rate is 90%

@github-actions

C++ Code Coverage Summary

Package                                                     Line Rate         Branch Rate
src.ml_flashpoint.checkpoint_object_manager.buffer_object   93%               54%
src.ml_flashpoint.checkpoint_object_manager.object_manager  70%               37%
src.ml_flashpoint.replication.transfer_service              79%               40%
Summary                                                     81% (916 / 1126)  43% (687 / 1604)

Minimum allowed line rate is 80%

@Leahlijuan Leahlijuan marked this pull request as ready for review February 26, 2026 21:26
Collaborator

@g-husam g-husam left a comment

Let's make the cache variables and logic/flow clearer. One approach is to decide whether we will use cached data before even calling generate_plan, avoiding the call altogether when the cache applies. The only downside is that it might make metrics less clear, since that function call won't exist for later checkpoints, but that might not matter much.

An alternative is to check at the beginning of generate_plan whether we want to return the cached data as-is, and then do so; if not, use the regular logic. This is a little awkward in the cached scenario, since we just return what we were given, but it keeps the function calls consistent for every checkpoint.
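The second alternative can be sketched as an early return at the top of generate_plan. The planner here is a minimal stand-in and the signature is simplified for illustration; this is not the PR's actual implementation.

```python
def generate_plan(planner, cached_ckpt_structure=None):
    """Return the cached structure as-is when provided; otherwise plan normally."""
    if cached_ckpt_structure is not None:
        # Cached path: return what we were given, so the call site stays
        # identical for every checkpoint while planning is skipped entirely.
        return cached_ckpt_structure
    # Regular path (stubbed): local planning followed by finalization;
    # the global/collective planning step is elided in this sketch.
    local_plan = planner.create_local_plan()
    return planner.finish_plan(local_plan)


class StubPlanner:
    """Minimal planner stand-in for illustration only."""

    def create_local_plan(self):
        return ["local_plan"]

    def finish_plan(self, plan):
        return plan + ["finalized"]
```

This keeps every checkpoint going through the same function call, at the cost of a slightly odd cached branch that just echoes its input.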

self.cached_local_plan: SavePlan | None = None
self.cached_global_metadata: Metadata | None = None
self.validated_cache_reuse: bool = False
self.use_cached_ckpt_structure: bool = use_cached_ckpt_structure
Collaborator

these properties can all be "private" e.g. self._cached_local_plan.

loaded_all_plans=loaded_all_plans,
)

if self.validated_cache_reuse:
Collaborator

Do we need this condition? Why not just always do this "validation"? Though it's not clear how this is a validation; it just sets global_metadata based on some conditions. I'm not sure what this condition is meant for.

# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
loaded_all_plans = None
Collaborator

nit: these Megatron variable names are not clear to me; let's rename them for clarity. Is this just caching all local plans from all ranks? If so, it could be named something like all_rank_local_plans_cached.

planner=planner,
world_dist_wrapper=world_dist_wrapper,
cached_ckpt_structure=args_cached_plans,
loaded_all_plans=loaded_all_plans,
Collaborator

nit: same here; the parameter can be renamed for clarity.

args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure:
if self.cached_global_metadata:
Collaborator

Will this metadata also include the write items and storage metadata that we store after writing? We should make sure that's excluded and still determined after the write, so that when we add that data into the metadata we're not modifying the cached copy.

global_metadata = self.cached_global_metadata

# If we have a valid global_metadata (either new or reused), cache it for next time
if global_metadata is not None:
Collaborator

This condition and the one above are a little confusing. Adding an else between them would probably help, along with removing any unnecessary conditions.

E.g., if we only get an actual global_metadata (not None) when it's new/different, then we can just do:

if global_metadata is not None:
    # update cached metadata
    self.cached_global_metadata = global_metadata
else:
    # use cached metadata
    global_metadata = self.cached_global_metadata

Is the logic above accurate?

A tuple containing:
- (final_local_plan, write_buckets, global_metadata)
- final_local_plan (for caching)
- local_plan (for caching)
Collaborator

Why do we need two copies of the final local plan? This return structure is pretty odd; can we return a single flat tuple, or preferably a proper dataclass/object?
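One way to act on this suggestion is a small dataclass in place of the nested tuple. The type and field names below are illustrative, not the PR's actual API.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GeneratePlanResult:
    """Bundles generate_plan outputs under descriptive names (illustrative)."""

    final_local_plan: Any
    write_buckets: Any
    global_metadata: Optional[Any]
    local_plan: Any  # pre-finalization plan, kept only for caching
```

Call sites can then read result.global_metadata by name instead of unpacking positional tuple members, and a single copy of the final plan suffices.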

storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)

local_plan = planner.create_local_plan()
if not validated_cache_reuse:
Collaborator

This boolean doesn't seem to be used for validation, just for caching. Do we need to condition this line at all? Why not just always do it? We are always updating it via the storage_writer on the next line anyway.


final_local_plan = planner.finish_plan(updated_local_plan)
central_plan = None
if validated_cache_reuse and cached_central_plan:
Collaborator

I would just expect that if we received cached plans as parameters, we validate them if needed and return them as-is. Otherwise, we go through the existing logic.

loaded_all_plans = None
if self.use_cached_ckpt_structure:
if self.cached_global_metadata:
loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
Collaborator

Can we flatten these conditions into one if with an and: if self.use_cached_ckpt_structure and self.cached_global_metadata:


Development

Successfully merging this pull request may close these issues.

Cache the checkpoint saving plan

2 participants