@@ -587,8 +587,9 @@ def fit(self):
587587
588588 self ._lr_scheduler .step ()
589589 self ._maybe_save_hf ()
590- self ._maybe_save (is_snapshot = False )
591- self ._maybe_save (is_snapshot = True )
590+ ckpt_saved = self ._maybe_save (is_snapshot = False )
591+ if not ckpt_saved :
592+ _ = self ._maybe_save (is_snapshot = True )
592593
593594 time_before_get_data = time .time ()
594595
@@ -805,19 +806,19 @@ def warmup_fn(x):
805806 )
806807 return lr_scheduler
807808
808- def _maybe_save (self , is_snapshot : bool = False ):
809+ def _maybe_save (self , is_snapshot : bool = False ) -> bool :
809810 ckp_interval = self ._checkpoint_interval if not is_snapshot else self ._snapshot_interval
810811 if ckp_interval is None :
811- return
812+ return False
812813
813814 if ckp_interval == - 1 : # only save at the end of training
814815 if self ._cur_step != self .total_step :
815- return
816+ return False
816817 else :
817818 if self .cur_step % ckp_interval != 0 and (is_snapshot or self ._cur_step != self .total_step ):
818819 # if is_snapshot, only save at interval
819820 # else save at interval or at the end of training
820- return
821+ return False
821822
822823 checkpoint_path = self ._get_checkpoint_path (epoch = self ._cur_epoch , step = self .cur_step , is_snapshot = is_snapshot )
823824 checkpoint_path .mkdir (parents = True , exist_ok = True )
@@ -901,6 +902,7 @@ def _maybe_save(self, is_snapshot: bool = False):
901902 f .write (self .meta .model_dump_json (indent = 2 ))
902903
903904 dist .barrier ()
905+ return True
904906
905907 def _save_dataloader (self , dataloader_path : Path | str ):
906908 _gathered_list = [None for _ in range (self .data_mesh ["dp" ].size ())]
0 commit comments