diff --git a/trainer/trainer.py b/trainer/trainer.py index cc74024..06d2d70 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -1724,7 +1724,7 @@ def _restore_best_loss(self): logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path)) ch = load_fsspec(self.args.restore_path, map_location="cpu") if "model_loss" in ch: - self.best_loss = ch["model_loss"] + self.best_loss = ch["model_loss"]["eval_loss"] logger.info(" > Starting with loaded last best loss %f", self.best_loss) def test(self, model=None, test_samples=None) -> None: