
Commit 8d4849c

fix: make --continue_path work again (#131)
* fix: make --continue_path work again

There were errors when loading models with `--continue_path` because #121 changed
https://github.com/coqui-ai/Trainer/blob/47781f58d2714d8139dc00f57dbf64bcc14402b7/trainer/trainer.py#L1924
to save the `model_loss` as `{"train_loss": train_loss, "eval_loss": eval_loss}` instead of just a float.

https://github.com/coqui-ai/Trainer/blob/47781f58d2714d8139dc00f57dbf64bcc14402b7/trainer/io.py#L195
still saves a float in `model_loss`, so loading the best model would still work fine. Loading a model via
`--restore_path` also works fine because in that case the best loss is reset and not initialised from the
saved model.

This fix:

- changes `save_best_model()` to also save a dict with train and eval loss, so that this is consistent everywhere
- ensures that the model loader can handle both float and dict `model_loss` for backwards compatibility
- adds relevant test cases

* fixup! fix: make --continue_path work again
1 parent 463e763 commit 8d4849c
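A minimal sketch of the failure mode described above (illustrative values only, not code from the repository): once a checkpoint stores `model_loss` as a dict, any code that still treats the best loss as a plain float breaks on comparison.

# Illustration only (not repository code): mixing the old float convention with the
# new dict convention makes the best-model comparison raise a TypeError.
best_loss = {"train_loss": 0.42, "eval_loss": 0.37}  # restored from a new-style checkpoint
current_loss = 0.35  # old code still passes a plain float

try:
    if current_loss < best_loss:  # TypeError: '<' not supported between 'float' and 'dict'
        print("new best model")
except TypeError as err:
    print(f"resuming with --continue_path fails here: {err}")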


3 files changed: +27, -6 lines changed


tests/test_continue_train.py

Lines changed: 12 additions & 1 deletion
@@ -14,8 +14,19 @@ def test_continue_train():
     continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
     number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
 
-    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
+    # Continue training from the best model
+    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
     run_cli(command_continue)
 
     assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
+
+    # Continue training from the last checkpoint
+    for best in glob.glob(os.path.join(continue_path, "best_model*")):
+        os.remove(best)
+    run_cli(command_continue)
+
+    # Continue training from a specific checkpoint
+    restore_path = os.path.join(continue_path, "checkpoint_5.pth")
+    command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
+    run_cli(command_continue)
     shutil.rmtree(continue_path)

trainer/io.py

Lines changed: 4 additions & 1 deletion
@@ -180,7 +180,10 @@ def save_best_model(
     save_func=None,
     **kwargs,
 ):
-    if current_loss < best_loss:
+    use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
+    if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
+        not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
+    ):
         best_model_name = f"best_model_{current_step}.pth"
         checkpoint_path = os.path.join(out_path, best_model_name)
         logger.info(" > BEST MODEL : %s", checkpoint_path)
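For reference, the new selection rule above can be read as a small standalone sketch (a hypothetical helper, not the repository's function): the eval loss decides when both sides have one, otherwise the train loss is used.

# Hypothetical standalone version of the comparison rule (not repository code).
def is_new_best(current_loss: dict, best_loss: dict) -> bool:
    use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
    if use_eval_loss:
        return current_loss["eval_loss"] < best_loss["eval_loss"]
    return current_loss["train_loss"] < best_loss["train_loss"]

# eval losses present on both sides -> the eval loss decides
assert is_new_best({"train_loss": 0.5, "eval_loss": 0.3}, {"train_loss": 0.4, "eval_loss": 0.35})
# run_eval disabled (eval_loss is None) -> fall back to the train loss
assert not is_new_best({"train_loss": 0.5, "eval_loss": None}, {"train_loss": 0.4, "eval_loss": None})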

trainer/trainer.py

Lines changed: 11 additions & 4 deletions
@@ -451,7 +451,7 @@ def __init__( # pylint: disable=dangerous-default-value
         self.epochs_done = 0
         self.restore_step = 0
         self.restore_epoch = 0
-        self.best_loss = float("inf")
+        self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
         self.train_loader = None
         self.test_loader = None
         self.eval_loader = None
@@ -1724,8 +1724,15 @@ def _restore_best_loss(self):
             logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
             ch = load_fsspec(self.args.restore_path, map_location="cpu")
             if "model_loss" in ch:
-                self.best_loss = ch["model_loss"]
-            logger.info(" > Starting with loaded last best loss %f", self.best_loss)
+                if isinstance(ch["model_loss"], dict):
+                    self.best_loss = ch["model_loss"]
+                # For backwards-compatibility:
+                elif isinstance(ch["model_loss"], float):
+                    if self.config.run_eval:
+                        self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
+                    else:
+                        self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
+            logger.info(" > Starting with loaded last best loss %s", self.best_loss)
 
     def test(self, model=None, test_samples=None) -> None:
         """Run evaluation steps on the test data split. You can either provide the model and the test samples
@@ -1907,7 +1914,7 @@ def save_best_model(self) -> None:
 
         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            eval_loss if eval_loss else train_loss,
+            {"train_loss": train_loss, "eval_loss": eval_loss},
             self.best_loss,
             self.config,
             self.model,
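The backwards-compatibility branch in `_restore_best_loss()` can be summarised as a small standalone mapping (a hypothetical helper, not the repository's method): an old float-style `model_loss` is slotted into the dict form, while a new dict-style value is taken over unchanged.

# Illustration only (not repository code) of the mapping described above.
def restore_best_loss(model_loss, run_eval: bool) -> dict:
    if isinstance(model_loss, dict):  # new-style checkpoint: use as-is
        return model_loss
    if run_eval:  # old-style float: treat it as the eval loss
        return {"train_loss": None, "eval_loss": model_loss}
    return {"train_loss": model_loss, "eval_loss": None}

print(restore_best_loss(0.42, run_eval=True))  # {'train_loss': None, 'eval_loss': 0.42}
print(restore_best_loss({"train_loss": 0.42, "eval_loss": 0.37}, run_eval=True))  # unchanged dict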
