
[FSDP][feature] Support save/load checkpointing #406

Closed

ppraneth wants to merge 52 commits into THUDM:main from ppraneth:support-fsdp_save

Conversation

@ppraneth
Contributor

Hi everyone,

This PR adds a save/load checkpointing feature to the FSDP backend and closes #402.

Here are the main changes I've made:

  • Checkpointing Logic: I added all the logic for saving and loading into a new file, slime/backends/fsdp_utils/checkpoint.py, to keep things organized. This handles saving the model, tokenizer, and optimizer state.

  • Actor Integration: I updated the FSDPTrainRayActor in slime/backends/fsdp_utils/actor.py to use the new functions. It now correctly loads the state from a checkpoint at the start of a run and saves progress when save_model is called.

  • New Arguments: To control this, I added a few command-line arguments to slime/backends/fsdp_utils/arguments.py: --save, --load, --save-safe-serialization and a safety flag --overwrite-checkpoints so you don't accidentally overwrite your work.

  • Bug Fixes: While building this, I also fixed a couple of bugs in the init process to prevent the model from being loaded twice and to make sure the global_step counter is restored correctly.

Please check it out and let me know if any changes are needed. I'm happy to make them! Thanks for taking a look.
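As a rough illustration of the --overwrite-checkpoints guard described above, the path resolution could look like this (a minimal sketch; `resolve_checkpoint_dir` and the `iter_XXXXXXX` directory layout are illustrative, not the actual slime code):

```python
import os


def resolve_checkpoint_dir(save_root: str, iteration: int, overwrite: bool) -> str:
    """Return the directory for this iteration's checkpoint, refusing to
    clobber an existing checkpoint unless overwrite is explicitly allowed."""
    ckpt_dir = os.path.join(save_root, f"iter_{iteration:07d}")
    if os.path.exists(ckpt_dir) and not overwrite:
        raise FileExistsError(
            f"{ckpt_dir} already exists; pass --overwrite-checkpoints to replace it"
        )
    return ckpt_dir
```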

@PopSoda2002
Collaborator

Can you share some test case and result?

@ppraneth
Contributor Author

Can you share some test case and result?

I’ve written a test case, but I’m not sure if it’s correct. Since I don’t have access to GPUs, could you please check it and run it?

@PopSoda2002
Collaborator

Synced in Slack.

@leng-yue

leng-yue commented Oct 3, 2025

Thanks for your great contribution :)

I’d suggest switching to FSDP’s distributed checkpointing, as it could make saving/loading the optimizer and LLM weights noticeably faster and reduce peak memory use. Happy to prototype this if you want.

We can also let FSDP write safetensors directly (see the PyTorch blog on HuggingFace safetensors support) Example: link.

@PopSoda2002
Collaborator

PopSoda2002 commented Oct 5, 2025

Thanks for your great contribution :)

I’d suggest switching to FSDP’s distributed checkpointing, as it could make saving/loading the optimizer and LLM weights noticeably faster and reduce peak memory use. Happy to prototype this if you want.

We can also let FSDP write safetensors directly (see the PyTorch blog on HuggingFace safetensors support) Example: link.

Hi, thanks for your suggestion! I think we should follow this tutorial and example too; otherwise the current logic could OOM for large models. @ppraneth

@leng-yue

leng-yue commented Oct 6, 2025

Thanks for your great contribution :)
I’d suggest switching to FSDP’s distributed checkpointing, as it could make saving/loading the optimizer and LLM weights noticeably faster and reduce peak memory use. Happy to prototype this if you want.
We can also let FSDP write safetensors directly (see the PyTorch blog on HuggingFace safetensors support) Example: link.

Hi thanks for your suggestion! I think we should follow this tutorial and example too, unless there would be OOM for large model with current logic @ppraneth

For reference, with the current loading logic a 235B model is likely to OOM on most machines. The bf16 weights alone are ~470 GB (235B × 2 bytes), and with typical 8 GPUs the effective memory can exceed ~3.7 TB.
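The back-of-the-envelope math checks out:

```python
# Memory estimate for materializing full bf16 weights of a 235B-parameter model.
params = 235e9
bytes_per_param = 2  # bf16

full_copy_gb = params * bytes_per_param / 1e9  # one full copy of the weights
# If all 8 ranks each materialize a full copy (the failure mode of
# non-distributed checkpointing), the combined footprint is enormous.
total_tb = full_copy_gb * 8 / 1e3

print(round(full_copy_gb), round(total_tb, 2))  # prints: 470 3.76
```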

Once this PR lands, I’m happy to follow up with a change that switches us to FSDP distributed checkpointing (writing safetensors directly) to make load/save faster and reduce peak memory.

@ppraneth
Contributor Author

ppraneth commented Oct 6, 2025

Thanks for your great contribution :)
I’d suggest switching to FSDP’s distributed checkpointing, as it could make saving/loading the optimizer and LLM weights noticeably faster and reduce peak memory use. Happy to prototype this if you want.
We can also let FSDP write safetensors directly (see the PyTorch blog on HuggingFace safetensors support) Example: link.

Hi thanks for your suggestion! I think we should follow this tutorial and example too, unless there would be OOM for large model with current logic @ppraneth

I will try to implement it

@ppraneth ppraneth requested a review from leng-yue October 11, 2025 12:08

@leng-yue leng-yue left a comment

LGTM

tests/test1.py Outdated
Collaborator

LGTM! But could you remove tests like tests/test1.py and tests/test2.py from the PR? You can still use them to test locally and paste the results here.

@ppraneth
Contributor Author

[screenshots: test results]

@PopSoda2002 Testing is done

@fzyzcjy
Collaborator

fzyzcjy commented Oct 25, 2025

this is pretty useful for long runs

@ppraneth
Contributor Author

@zhuzilin Can you check this?

@fzyzcjy
Collaborator

fzyzcjy commented Oct 28, 2025

@ppraneth Hi, could you please fix the lint error? Also, is this ready to use (i.e. tested and working), or is there still ongoing work?

if not optimizer_state_dict or not optimizer_state_dict.get("state", {}):
    raise ValueError(f"Optimizer state dictionary is empty for iteration {iteration}")

use_safetensors = getattr(args, "save_safe_serialization", False)
Collaborator

@fzyzcjy fzyzcjy Oct 28, 2025

tiny nit

Suggested change:
- use_safetensors = getattr(args, "save_safe_serialization", False)
+ use_safetensors = args.save_safe_serialization

Collaborator

@fzyzcjy fzyzcjy left a comment

super tiny nits

Comment on lines +37 to +51
if use_safetensors:
    try:
        from torch.distributed.checkpoint import HuggingFaceStorageWriter

        model_writer = HuggingFaceStorageWriter(
            path=model_subdir, fqn_to_index_mapping={k: 0 for k in model_state_dict.keys()}
        )
        dist_cp.save(state_dict=model_state_dict, storage_writer=model_writer)
    except ImportError as e:
        raise ImportError(
            "Safetensors library is required when save_safe_serialization is True, but it is not installed."
        ) from e
else:
    model_writer = dist_cp.FileSystemWriter(model_subdir)
    dist_cp.save(state_dict={"model": model_state_dict}, storage_writer=model_writer)
Collaborator

tiny nit:

if use_safetensors:
  model_writer = HuggingFaceStorageWriter(..)
else:
  model_writer = dist_cp.FileSystemWriter(..)
dist_cp.save(..)

Comment on lines +114 to +129
# Load model
if is_safetensors:
    try:
        from torch.distributed.checkpoint import HuggingFaceStorageReader

        model_storage_reader = HuggingFaceStorageReader(path=model_subdir)
        dist_cp.load(state_dict=model_state_dict, storage_reader=model_storage_reader)
    except ImportError as e:
        raise ImportError(
            "Safetensors library is required to load safetensors checkpoint files, but it is not installed."
        ) from e
else:
    model_state_dict = {"model": model_state_dict}
    model_storage_reader = dist_cp.FileSystemReader(model_subdir)
    dist_cp.load(state_dict=model_state_dict, storage_reader=model_storage_reader)
    model_state_dict = model_state_dict["model"]
Collaborator

(same)

model_state_dict = model_state_dict["model"]

# Load optimizer (always standard format)
optim_state_dict = {"optim": optimizer_state_dict}
Collaborator

nit: wondering whether we can remove that extra "optim" nested key

Comment on lines +160 to +165
# Broadcast to all ranks
state_t = torch.tensor([0, 0], dtype=torch.int64, device="cpu")
if dist.get_rank() == 0:
    state_t[0] = loaded_iteration
    state_t[1] = global_step
dist.broadcast(state_t, src=0)
Collaborator

tiny nit: dist.broadcast_object_list
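The `dist.broadcast_object_list` variant the reviewer suggests would replace the manual tensor packing with something like this (a sketch; the function name is illustrative):

```python
import torch.distributed as dist


def broadcast_resume_state(loaded_iteration: int, global_step: int) -> tuple:
    """Share (iteration, global_step) from rank 0 with all ranks without
    packing them into an int64 tensor by hand."""
    obj = [(loaded_iteration, global_step)] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]
```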

@zhuzilin
Contributor

Closing as solved by #633.

@zhuzilin zhuzilin closed this Oct 30, 2025