1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import gc
1617import logging
1718import os
19+ import shutil
1820from dataclasses import dataclass , field
1921from pathlib import Path
2022from typing import NamedTuple
2123
2224import torch
23- import torch .distributed .checkpoint as dcp
2425import transformers
2526from safetensors .torch import save_file
2627from torch .distributed .checkpoint .state_dict import (
2930 get_state_dict ,
3031 set_state_dict ,
3132)
33+ from torch .distributed .checkpoint .state_dict_loader import load as dcp_load
34+ from torch .distributed .checkpoint .state_dict_saver import async_save as dcp_async_save
35+ from torch .distributed .checkpoint .state_dict_saver import save as dcp_save
3236from torch .distributed .checkpoint .stateful import Stateful
3337from torchdata .stateful_dataloader import StatefulDataLoader
3438
3539from distributed_config import DistributedConfig
3640
3741
3842logger = logging .getLogger (__name__ )
43+ _ckpt_futures : dict = {}
3944
4045
4146class CheckpointOutput (NamedTuple ):
@@ -82,6 +87,20 @@ def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool:
8287 return False
8388
8489
90+ def prune_checkpoints (ckpt_path : str | os .PathLike , max_checkpoints : int ) -> None :
91+ """Prune checkpoints to keep only the latest `max_checkpoints` checkpoints."""
92+ ckpt_path = Path (ckpt_path )
93+ checkpoints = [f for f in ckpt_path .iterdir () if f .name .startswith ("step_" )]
94+ checkpoints .sort (key = lambda x : int (Path (x ).stem .split ("_" )[1 ]))
95+ if len (checkpoints ) > max_checkpoints :
96+ for checkpoint in checkpoints [:- max_checkpoints ]:
97+ logger .info (f"Pruning checkpoint { checkpoint } " )
98+ if checkpoint .is_dir ():
99+ shutil .rmtree (checkpoint )
100+ else :
101+ os .remove (checkpoint )
102+
103+
85104# ============================================================================
86105# DDP Checkpointing
87106# ============================================================================
@@ -131,6 +150,7 @@ def save_checkpoint_ddp(
131150 epoch : int ,
132151 dist_config : DistributedConfig ,
133152 dataloader : StatefulDataLoader | None = None ,
153+ max_checkpoints : int | None = None ,
134154) -> None :
135155 """Saves the Dataloader state and the DDP checkpoint."""
136156 ckpt_path = Path (ckpt_path )
@@ -157,6 +177,9 @@ def save_checkpoint_ddp(
157177
158178 logger .info (f"Saved DDP checkpoint to { checkpoint_path } " )
159179
180+ if max_checkpoints is not None and dist_config .is_main_process ():
181+ prune_checkpoints (ckpt_path , max_checkpoints )
182+
160183
161184def save_final_model_ddp (
162185 model : torch .nn .Module ,
@@ -243,6 +266,7 @@ def save_checkpoint_mfsdp(
243266 dist_config : DistributedConfig ,
244267 dataloader : StatefulDataLoader | None = None ,
245268 epoch : int = 0 ,
269+ max_checkpoints : int | None = None ,
246270) -> None :
247271 """Save mFSDP distributed checkpoint.
248272
@@ -255,6 +279,7 @@ def save_checkpoint_mfsdp(
255279 dist_config: The distributed configuration.
256280 dataloader: The dataloader to save.
257281 epoch: The epoch number to save the checkpoint.
282+ max_checkpoints: The maximum number of checkpoints to keep.
258283 """
259284 ckpt_path = Path (ckpt_path )
260285 checkpoint_path = ckpt_path / f"step_{ step } "
@@ -279,6 +304,9 @@ def save_checkpoint_mfsdp(
279304 if dist_config .is_main_process ():
280305 logger .info (f"Saved mFSDP checkpoint to { checkpoint_path } " )
281306
307+ if max_checkpoints is not None and dist_config .is_main_process ():
308+ prune_checkpoints (ckpt_path , max_checkpoints )
309+
282310
283311def save_final_model_mfsdp (
284312 model : torch .nn .Module ,
@@ -369,6 +397,7 @@ def load_checkpoint_fsdp2(
369397 ckpt_path : str | os .PathLike ,
370398 dist_config : DistributedConfig ,
371399 dataloader : StatefulDataLoader | None = None ,
400+ process_group : torch .distributed .ProcessGroup | None = None ,
372401) -> CheckpointOutput :
373402 """Load FSDP2 checkpoint.
374403
@@ -379,6 +408,7 @@ def load_checkpoint_fsdp2(
379408 ckpt_path: The directory containing checkpoints.
380409 dist_config: The distributed configuration.
381410 dataloader: The dataloader to load.
411+ process_group: The process group to use for checkpointing.
382412 """
383413 checkpoint_path , _ = get_latest_checkpoint (ckpt_path )
384414 if not checkpoint_path :
@@ -392,7 +422,7 @@ def load_checkpoint_fsdp2(
392422 )
393423
394424 state_dict = {"app" : app_state }
395- dcp . load (state_dict , checkpoint_id = checkpoint_path )
425+ dcp_load (state_dict , checkpoint_id = checkpoint_path , process_group = process_group )
396426
397427 if dataloader is not None :
398428 load_dataloader (
@@ -416,6 +446,9 @@ def save_checkpoint_fsdp2(
416446 epoch : int ,
417447 dist_config : DistributedConfig ,
418448 dataloader : StatefulDataLoader | None = None ,
449+ process_group : torch .distributed .ProcessGroup | None = None ,
450+ max_checkpoints : int | None = None ,
451+ async_save : bool = False ,
419452) -> None :
420453 """Save FSDP2 checkpoint.
421454
@@ -428,6 +461,9 @@ def save_checkpoint_fsdp2(
428461 epoch: The epoch number to save the checkpoint.
429462 dist_config: The distributed configuration.
430463 dataloader: The dataloader to save.
464+ process_group: The process group to use for checkpointing.
465+ max_checkpoints: The maximum number of checkpoints to keep.
466+ async_save: Whether to save the checkpoint asynchronously.
431467 """
432468 ckpt_path = Path (ckpt_path )
433469 checkpoint_path = ckpt_path / f"step_{ step } "
@@ -441,17 +477,24 @@ def save_checkpoint_fsdp2(
441477 )
442478 logger .info (f"Saved FSDP2 dataloader to { ckpt_path } " )
443479
444- state_dict = {
445- "app" : AppState (
446- model = model ,
447- optimizer = optimizer ,
448- scheduler = scheduler ,
449- step = step ,
450- epoch = epoch ,
451- )
452- }
453- dcp .save (state_dict = state_dict , checkpoint_id = checkpoint_path )
454- logger .info (f"Saved distributed FSDP2 checkpoint to { checkpoint_path } " )
480+ # If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
481+ if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures ["fsdp2" ] is not None :
482+ _ckpt_futures ["fsdp2" ].result ()
483+
484+ # Clear GPU cache before checkpointing to free up fragmented memory.
485+ gc .collect ()
486+ torch .cuda .empty_cache ()
487+ torch .distributed .barrier (group = process_group )
488+
489+ state_dict = {"app" : AppState (model = model , optimizer = optimizer , scheduler = scheduler , step = step , epoch = epoch )}
490+ ckpt_save_func = dcp_async_save if async_save else dcp_save
491+ _ckpt_futures ["fsdp2" ] = ckpt_save_func (state_dict , checkpoint_id = checkpoint_path , process_group = process_group )
492+
493+ if dist_config .is_main_process ():
494+ logger .info (f"Saved distributed FSDP2 checkpoint to { checkpoint_path } " )
495+
496+ if max_checkpoints is not None and dist_config .is_main_process ():
497+ prune_checkpoints (ckpt_path , max_checkpoints )
455498
456499
457500def save_final_model_fsdp2 (
0 commit comments