diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 85bfb65c0ea6e..ba59f693e67bb 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -238,6 +238,16 @@ def __init__( self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end self._enable_version_counter = enable_version_counter + self.dirpath: Optional[_PATH] = dirpath + self.filename = filename + self.kth_value: Optional[Tensor] = None + self._mode = mode + + self.__init_state() + self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) + self.__validate_init_configuration() + + def __init_state(self) -> None: self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None @@ -248,19 +258,12 @@ def __init__( self.last_model_path = "" self._last_checkpoint_saved = "" - self.kth_value: Tensor - self.dirpath: Optional[_PATH] - self.__init_monitor_mode(mode) - self.__init_ckpt_dir(dirpath, filename) - self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) - self.__validate_init_configuration() - @property @override def state_key(self) -> str: return self._generate_state_key( monitor=self.monitor, - mode=self.mode, + mode=self._mode, every_n_train_steps=self._every_n_train_steps, every_n_epochs=self._every_n_epochs, train_time_interval=self._train_time_interval, @@ -268,6 +271,10 @@ def state_key(self) -> str: @override def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.__init_state() + self.__set_monitor_mode(self._mode) + self.__set_ckpt_dir(self.dirpath, self.filename) + dirpath = self.__resolve_ckpt_dir(trainer) dirpath = trainer.strategy.broadcast(dirpath) self.dirpath = dirpath @@ -469,7 +476,7 @@ def __validate_init_configuration(self) -> None: " configuration. No quantity for top_k to track." ) - def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + def __set_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): @@ -478,7 +485,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> self.dirpath = dirpath self.filename = filename - def __init_monitor_mode(self, mode: str) -> None: + def __set_monitor_mode(self, mode: str) -> None: torch_inf = torch.tensor(torch.inf) mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}