Skip to content

Commit 385cced

Browse files
authored
Merge pull request #129 from coqui-ai/fix_eval
Multiples bug fixes and add on_train_epoch_start callback
2 parents 47781f5 + 5b3cb63 commit 385cced

File tree

4 files changed

+76
-17
lines changed

4 files changed

+76
-17
lines changed

.github/workflows/pypi-release.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ jobs:
1010
build-sdist:
1111
runs-on: ubuntu-20.04
1212
steps:
13-
- uses: actions/checkout@v2
13+
- uses: actions/checkout@v3
1414
- name: Verify tag matches version
1515
run: |
1616
set -ex
@@ -19,7 +19,7 @@ jobs:
1919
if [[ "$version" != "$tag" ]]; then
2020
exit 1
2121
fi
22-
- uses: actions/setup-python@v2
22+
- uses: actions/checkout@v3
(NOTE, review: this replacement looks wrong — a `setup-python` step was swapped for a second `checkout` step, so the `python-version: 3.9` input below it no longer applies. The intended line was almost certainly `- uses: actions/setup-python@v4`.)
2323
with:
2424
python-version: 3.9
2525
- run: |

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,9 @@ jobs:
2121
python-version: [3.8, 3.9, "3.10", "3.11"]
2222
experimental: [false]
2323
steps:
24-
- uses: actions/checkout@v2
24+
- uses: actions/checkout@v3
2525
- name: Set up Python ${{ matrix.python-version }}
26-
uses: coqui-ai/setup-python@pip-cache-key-py-ver
26+
uses: actions/setup-python@v4
2727
with:
2828
python-version: ${{ matrix.python-version }}
2929
architecture: x64

trainer/callbacks.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ def __init__(self) -> None:
77
self.callbacks_on_init_end = []
88
self.callbacks_on_epoch_start = []
99
self.callbacks_on_epoch_end = []
10+
self.callbacks_on_train_epoch_start = []
11+
self.callbacks_on_train_epoch_end = []
1012
self.callbacks_on_train_step_start = []
1113
self.callbacks_on_train_step_end = []
1214
self.callbacks_on_keyboard_interrupt = []
@@ -21,6 +23,10 @@ def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None:
2123
self.callbacks_on_epoch_start.append(value)
2224
elif key == "on_epoch_end":
2325
self.callbacks_on_epoch_end.append(value)
26+
elif key == "on_train_epoch_start":
27+
self.callbacks_on_train_epoch_start.append(value)
28+
elif key == "on_train_epoch_end":
29+
self.callbacks_on_train_epoch_end.append(value)
2430
elif key == "on_train_step_start":
2531
self.callbacks_on_train_step_start.append(value)
2632
elif key == "on_train_step_end":
@@ -102,6 +108,42 @@ def on_epoch_end(self, trainer) -> None:
102108
for callback in self.callbacks_on_epoch_end:
103109
callback(trainer)
104110

111+
def on_train_epoch_start(self, trainer) -> None:
112+
if hasattr(trainer.model, "module"):
113+
if hasattr(trainer.model.module, "on_train_epoch_start"):
114+
trainer.model.module.on_train_epoch_start(trainer)
115+
else:
116+
if hasattr(trainer.model, "on_train_epoch_start"):
117+
trainer.model.on_train_epoch_start(trainer)
118+
119+
if hasattr(trainer.criterion, "on_train_epoch_start"):
120+
trainer.criterion.on_train_epoch_start(trainer)
121+
122+
if hasattr(trainer.optimizer, "on_train_epoch_start"):
123+
trainer.optimizer.on_train_epoch_start(trainer)
124+
125+
if self.callbacks_on_train_epoch_start:
126+
for callback in self.callbacks_on_train_epoch_start:
127+
callback(trainer)
128+
129+
def on_train_epoch_end(self, trainer) -> None:
130+
if hasattr(trainer.model, "module"):
131+
if hasattr(trainer.model.module, "on_train_epoch_end"):
132+
trainer.model.module.on_train_epoch_end(trainer)
133+
else:
134+
if hasattr(trainer.model, "on_train_epoch_end"):
135+
trainer.model.on_train_epoch_end(trainer)
136+
137+
if hasattr(trainer.criterion, "on_train_epoch_end"):
138+
trainer.criterion.on_train_epoch_end(trainer)
139+
140+
if hasattr(trainer.optimizer, "on_train_epoch_end"):
141+
trainer.optimizer.on_train_epoch_end(trainer)
142+
143+
if self.callbacks_on_train_epoch_end:
144+
for callback in self.callbacks_on_train_epoch_end:
145+
callback(trainer)
146+
105147
@staticmethod
106148
def before_backward_pass(trainer, loss_dict) -> None:
107149
if hasattr(trainer.model, "module"):

trainer/trainer.py

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,10 @@ def __init__( # pylint: disable=dangerous-default-value
443443
if not self.config.log_model_step:
444444
self.config.log_model_step = self.config.save_step
445445

446+
# make sure that start_with_eval is disabled if eval is disabled
447+
if not self.config.run_eval and self.start_with_eval:
448+
self.start_with_eval = False
449+
446450
self.total_steps_done = 0
447451
self.epochs_done = 0
448452
self.restore_step = 0
@@ -525,6 +529,16 @@ def __init__( # pylint: disable=dangerous-default-value
525529
# setup optimizer
526530
self.optimizer = self.get_optimizer(self.model, self.config)
527531

532+
# If multiple-optimizer setup with grad accumulation and without custom optimize method raise an error
533+
if (
534+
self.grad_accum_steps != 1
535+
and isinstance(self.optimizer, list)
536+
and not isimplemented(self.model, "optimize")
537+
):
538+
raise ValueError(
539+
" [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called ´optimize` that need to deal with dangling gradients in multiple-optimizer setup!"
540+
)
541+
528542
# CALLBACK
529543
self.callbacks = TrainerCallback()
530544
self.callbacks.parse_callbacks_dict(callbacks)
@@ -1480,6 +1494,8 @@ def train_epoch(self) -> None:
14801494
self.model.train()
14811495
epoch_start_time = time.time()
14821496

1497+
self.callbacks.on_train_epoch_start(self)
1498+
14831499
self.c_logger.print_train_start()
14841500
loader_start_time = time.time()
14851501
# TRAINING EPOCH -> iterate over the training samples
@@ -1502,6 +1518,8 @@ def train_epoch(self) -> None:
15021518
torch.set_grad_enabled(True)
15031519

15041520
epoch_time = time.time() - epoch_start_time
1521+
self.callbacks.on_train_epoch_end(self)
1522+
15051523
# scheduler step
15061524
if self.scheduler is not None and self.config.scheduler_after_epoch:
15071525
if isinstance(self.scheduler, list):
@@ -1884,14 +1902,12 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
18841902
def save_best_model(self) -> None:
18851903
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
18861904

1887-
eval_loss = None
1888-
if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
1889-
eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
1905+
eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
18901906
train_loss = self._pick_target_avg_loss(self.keep_avg_train)
18911907

18921908
# save the model and update the best_loss
18931909
self.best_loss = save_best_model(
1894-
train_loss if eval_loss is None else eval_loss,
1910+
eval_loss if eval_loss else train_loss,
18951911
self.best_loss,
18961912
self.config,
18971913
self.model,
@@ -1908,9 +1924,7 @@ def save_best_model(self) -> None:
19081924
@rank_zero_only
19091925
def save_checkpoint(self) -> None:
19101926
"""Save the current model checkpoint."""
1911-
eval_loss = None
1912-
if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
1913-
eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
1927+
eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
19141928
train_loss = self._pick_target_avg_loss(self.keep_avg_train)
19151929

19161930
save_checkpoint(
@@ -2101,18 +2115,21 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:
21012115

21022116
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
21032117
"""Pick the target loss to compare models"""
2118+
2119+
# if the keep_avg_target is None or empty return None
2120+
if keep_avg_target is None or len(list(keep_avg_target.avg_values.keys())) == 0:
2121+
return None
2122+
21042123
target_avg_loss = None
21052124
# return if target loss defined in the model config
21062125
# if not available in Dict use loss_1 as by default loss
21072126
if "target_loss" in self.config and self.config.target_loss:
21082127
if f"avg_{self.config.target_loss}" in keep_avg_target.avg_values.keys():
21092128
return keep_avg_target[f"avg_{self.config.target_loss}"]
2110-
target_loss = keep_avg_target["avg_loss_1"]
2111-
if target_loss is None:
2112-
raise ValueError(
2113-
" [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
2114-
)
2115-
return target_loss
2129+
2130+
raise ValueError(
2131+
" [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
2132+
)
21162133

21172134
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
21182135
if isinstance(self.optimizer, list):

0 commit comments

Comments (0)