Skip to content

Commit 23b1831

Browse files
committed
[Enhance] When ckpt is saved, do not save snapshot
1 parent d5580f3 commit 23b1831

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

xtuner/v1/train/trainer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,9 @@ def fit(self):
587587

588588
self._lr_scheduler.step()
589589
self._maybe_save_hf()
590-
self._maybe_save(is_snapshot=False)
591-
self._maybe_save(is_snapshot=True)
590+
ckpt_saved = self._maybe_save(is_snapshot=False)
591+
if not ckpt_saved:
592+
_ = self._maybe_save(is_snapshot=True)
592593

593594
time_before_get_data = time.time()
594595

@@ -805,19 +806,19 @@ def warmup_fn(x):
805806
)
806807
return lr_scheduler
807808

808-
def _maybe_save(self, is_snapshot: bool = False):
809+
def _maybe_save(self, is_snapshot: bool = False) -> bool:
809810
ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval
810811
if ckp_interval is None:
811-
return
812+
return False
812813

813814
if ckp_interval == -1: # only save at the end of training
814815
if self._cur_step != self.total_step:
815-
return
816+
return False
816817
else:
817818
if self.cur_step % ckp_interval != 0 and (is_snapshot or self._cur_step != self.total_step):
818819
# if is_snapshot, only save at interval
819820
# else save at interval or at the end of training
820-
return
821+
return False
821822

822823
checkpoint_path = self._get_checkpoint_path(epoch=self._cur_epoch, step=self.cur_step, is_snapshot=is_snapshot)
823824
checkpoint_path.mkdir(parents=True, exist_ok=True)
@@ -901,6 +902,7 @@ def _maybe_save(self, is_snapshot: bool = False):
901902
f.write(self.meta.model_dump_json(indent=2))
902903

903904
dist.barrier()
905+
return True
904906

905907
def _save_dataloader(self, dataloader_path: Path | str):
906908
_gathered_list = [None for _ in range(self.data_mesh["dp"].size())]

0 commit comments

Comments
 (0)