Skip to content

Commit b120ea9

Browse files
Fix dataloader reload schedule when resuming from checkpoint (#21492)
1 parent bb7820f commit b120ea9

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,13 @@ def on_save_checkpoint(self) -> dict:
522522
state_dict = super().on_save_checkpoint()
523523
if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
524524
state_dict["combined_loader"] = loader_states
525+
state_dict["last_train_dl_reload_epoch"] = self._last_train_dl_reload_epoch
525526
return state_dict
526527

527528
@override
528529
def on_load_checkpoint(self, state_dict: dict) -> None:
529530
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
531+
self._last_train_dl_reload_epoch = state_dict.get("last_train_dl_reload_epoch", float("-inf"))
530532
super().on_load_checkpoint(state_dict)
531533

532534
def _warn_if_modules_in_eval_mode(self) -> None:

0 commit comments

Comments
 (0)