Skip to content

Commit 2d86257

Browse files
authored
Merge pull request #133 from coqui-ai/revert-131-fix-continue
Revert "fix: make --continue_path work again"
2 parents 53d7345 + 695a699 commit 2d86257

File tree

3 files changed

+6
-27
lines changed

3 files changed

+6
-27
lines changed

tests/test_continue_train.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,8 @@ def test_continue_train():
1414
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
1515
number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
1616

17-
# Continue training from the best model
18-
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
17+
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
1918
run_cli(command_continue)
2019

2120
assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
22-
23-
# Continue training from the last checkpoint
24-
for best in glob.glob(os.path.join(continue_path, "best_model*")):
25-
os.remove(best)
26-
run_cli(command_continue)
27-
28-
# Continue training from a specific checkpoint
29-
restore_path = os.path.join(continue_path, "checkpoint_5.pth")
30-
command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
31-
run_cli(command_continue)
3221
shutil.rmtree(continue_path)

trainer/io.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,7 @@ def save_best_model(
180180
save_func=None,
181181
**kwargs,
182182
):
183-
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
184-
if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
185-
not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
186-
):
183+
if current_loss < best_loss:
187184
best_model_name = f"best_model_{current_step}.pth"
188185
checkpoint_path = os.path.join(out_path, best_model_name)
189186
logger.info(" > BEST MODEL : %s", checkpoint_path)

trainer/trainer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def __init__( # pylint: disable=dangerous-default-value
451451
self.epochs_done = 0
452452
self.restore_step = 0
453453
self.restore_epoch = 0
454-
self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
454+
self.best_loss = float("inf")
455455
self.train_loader = None
456456
self.test_loader = None
457457
self.eval_loader = None
@@ -1724,15 +1724,8 @@ def _restore_best_loss(self):
17241724
logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
17251725
ch = load_fsspec(self.args.restore_path, map_location="cpu")
17261726
if "model_loss" in ch:
1727-
if isinstance(ch["model_loss"], dict):
1728-
self.best_loss = ch["model_loss"]
1729-
# For backwards-compatibility:
1730-
elif isinstance(ch["model_loss"], float):
1731-
if self.config.run_eval:
1732-
self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
1733-
else:
1734-
self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
1735-
logger.info(" > Starting with loaded last best loss %s", self.best_loss)
1727+
self.best_loss = ch["model_loss"]
1728+
logger.info(" > Starting with loaded last best loss %f", self.best_loss)
17361729

17371730
def test(self, model=None, test_samples=None) -> None:
17381731
"""Run evaluation steps on the test data split. You can either provide the model and the test samples
@@ -1914,7 +1907,7 @@ def save_best_model(self) -> None:
19141907

19151908
# save the model and update the best_loss
19161909
self.best_loss = save_best_model(
1917-
{"train_loss": train_loss, "eval_loss": eval_loss},
1910+
eval_loss if eval_loss else train_loss,
19181911
self.best_loss,
19191912
self.config,
19201913
self.model,

0 commit comments

Comments
 (0)