diff --git a/trainer/io.py b/trainer/io.py index 6e08aea..dfda581 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -180,7 +180,7 @@ def save_best_model( save_func=None, **kwargs, ): - if current_loss < best_loss: + if current_loss < best_loss and current_step > keep_after: best_model_name = f"best_model_{current_step}.pth" checkpoint_path = os.path.join(out_path, best_model_name) logger.info(" > BEST MODEL : %s", checkpoint_path)