[FSDP] [feature]Support save/load checkpointing#406
ppraneth wants to merge 52 commits into THUDM:main
Conversation
Can you share some test cases and results?
I’ve written a test case, but I’m not sure if it’s correct. Since I don’t have access to GPUs, could you please check and run it?
Synced in Slack.
Thanks for your great contribution :) I’d suggest switching to FSDP’s distributed checkpointing, as it could make saving/loading the optimizer and LLM weights noticeably faster and reduce peak memory use. Happy to prototype this if you want. We can also let FSDP write …
Hi, thanks for your suggestion! I think we should follow this tutorial and example too, unless the current logic would OOM for large models @ppraneth
For reference, with the current loading logic a 235B model is likely to OOM on most machines. The bf16 weights alone are ~470 GB (235B × 2 bytes), and with a typical 8 GPUs the effective memory footprint can exceed ~3.7 TB. Once this PR lands, I’m happy to follow up with a change that switches us to FSDP distributed checkpointing (writing …
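For the record, the arithmetic behind those numbers can be sketched as follows (a quick back-of-the-envelope check, not code from the PR):

```python
# Assumptions: bf16 weights (2 bytes per parameter), and the current
# loading logic materializing one full (unsharded) copy per rank.
params = 235e9            # 235B parameters
bytes_per_param = 2       # bf16
ranks = 8                 # typical single-node GPU count

full_copy_gb = params * bytes_per_param / 1e9
total_tb = full_copy_gb * ranks / 1e3

print(f"{full_copy_gb:.0f} GB per full copy")     # 470 GB
print(f"{total_tb:.2f} TB across {ranks} ranks")  # 3.76 TB
```

Distributed checkpointing avoids this blow-up because each rank only reads and writes its own shard of the weights.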
I will try to implement it |
tests/test1.py
Outdated
LGTM! But can you remove these tests, i.e. tests/test1.py and tests/test2.py? You can still use them to test locally and paste the results in the PR.
@PopSoda2002 Testing is done.
This is pretty useful for long runs.
@zhuzilin Can you check this?
@ppraneth hi, could you please handle the lint error? Also, may I know whether this is ready to use (i.e. tested and looking good), or does it still have ongoing work/fixes?
```python
if not optimizer_state_dict or not optimizer_state_dict.get("state", {}):
    raise ValueError(f"Optimizer state dictionary is empty for iteration {iteration}")
```
```python
use_safetensors = getattr(args, "save_safe_serialization", False)
```
tiny nit — suggested change:
```python
use_safetensors = args.save_safe_serialization
```
```python
if use_safetensors:
    try:
        from torch.distributed.checkpoint import HuggingFaceStorageWriter

        model_writer = HuggingFaceStorageWriter(
            path=model_subdir, fqn_to_index_mapping={k: 0 for k in model_state_dict.keys()}
        )
        dist_cp.save(state_dict=model_state_dict, storage_writer=model_writer)
    except ImportError as e:
        raise ImportError(
            "Safetensors library is required when save_safe_serialization is True, but it is not installed."
        ) from e
else:
    model_writer = dist_cp.FileSystemWriter(model_subdir)
    dist_cp.save(state_dict={"model": model_state_dict}, storage_writer=model_writer)
```
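As a side note on the `fqn_to_index_mapping` argument used above: mapping every fully-qualified tensor name to the same index directs all tensors into a single safetensors shard. A toy illustration (the tensor names here are made up; the real keys come from the FSDP model's state dict):

```python
# Hypothetical state-dict keys; index 0 for every key means
# "write everything into shard 0" of the safetensors output.
model_state_dict = {
    "embed_tokens.weight": None,
    "layers.0.self_attn.q_proj.weight": None,
    "lm_head.weight": None,
}
fqn_to_index_mapping = {k: 0 for k in model_state_dict.keys()}
print(fqn_to_index_mapping)
```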
tiny nit:
```python
if use_safetensors:
    model_writer = HuggingFaceStorageWriter(...)
else:
    model_writer = dist_cp.FileSystemWriter(...)
dist_cp.save(...)
```
```python
# Load model
if is_safetensors:
    try:
        from torch.distributed.checkpoint import HuggingFaceStorageReader

        model_storage_reader = HuggingFaceStorageReader(path=model_subdir)
        dist_cp.load(state_dict=model_state_dict, storage_reader=model_storage_reader)
    except ImportError as e:
        raise ImportError(
            "Safetensors library is required to load safetensors checkpoint files, but it is not installed."
        ) from e
else:
    model_state_dict = {"model": model_state_dict}
    model_storage_reader = dist_cp.FileSystemReader(model_subdir)
    dist_cp.load(state_dict=model_state_dict, storage_reader=model_storage_reader)
    model_state_dict = model_state_dict["model"]

# Load optimizer (always standard format)
optim_state_dict = {"optim": optimizer_state_dict}
```
nit: wondering whether we can remove that extra "optim" nested key
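For context on that nit: the extra key only wraps the same dict before the load call and unwraps it afterwards. A minimal sketch of the pattern (the actual `dist_cp.load` call is elided):

```python
# Stand-in optimizer state; dist_cp.load would fill this dict in place.
optimizer_state_dict = {"state": {}, "param_groups": []}

# Wrap under an extra "optim" key before loading ...
optim_state_dict = {"optim": optimizer_state_dict}
# ... dist_cp.load(state_dict=optim_state_dict, storage_reader=...) ...
# ... then unwrap afterwards.
restored = optim_state_dict["optim"]

# Both names refer to the same object, which is why the reviewer
# wonders whether the nesting is needed at all.
assert restored is optimizer_state_dict
```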
```python
# Broadcast to all ranks
state_t = torch.tensor([0, 0], dtype=torch.int64, device="cpu")
if dist.get_rank() == 0:
    state_t[0] = loaded_iteration
    state_t[1] = global_step
dist.broadcast(state_t, src=0)
```
tiny nit: `dist.broadcast_object_list` could broadcast these values directly instead of packing them into a tensor.
Closing as solved by #633.


Hi everyone,

This PR adds a save/load checkpointing feature to the FSDP backend and closes #402.

Here are the main changes I've made:

- **Checkpointing Logic:** All the logic for saving and loading lives in a new file, `slime/backends/fsdp_utils/checkpoint.py`, to keep things organized. It handles saving the model, tokenizer, and optimizer state.
- **Actor Integration:** I updated `FSDPTrainRayActor` in `slime/backends/fsdp_utils/actor.py` to use the new functions. It now correctly loads the state from a checkpoint at the start of a run and saves progress when `save_model` is called.
- **New Arguments:** To control this, I added a few command-line arguments to `slime/backends/fsdp_utils/arguments.py`: `--save`, `--load`, `--save-safe-serialization`, and a safety flag `--overwrite-checkpoints` so you don't accidentally overwrite your work.
- **Bug Fixes:** While building this, I also fixed a couple of bugs in the init process to prevent the model from being loaded twice and to make sure the `global_step` counter is restored correctly.

Please check it out and let me know if any changes are needed. I'm happy to make them! Thanks for taking a look.