Commit 50f1d8a

Handle no eval loss (#121)
* Handle no eval loss
* Assert len(data_loader) > 0
1 parent 33bd187 commit 50f1d8a

File tree

1 file changed: +19 -12 lines changed

trainer/trainer.py

Lines changed: 19 additions & 12 deletions
```diff
@@ -895,6 +895,10 @@ def _get_loader(
         loader = model.get_data_loader(
             config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
         )
+
+        assert (
+            len(loader) > 0
+        ), " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. "
         return loader

     def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
```
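For context, a minimal sketch of the failure mode this assertion guards against, using a plain `torch.utils.data.DataLoader` (the setup below is illustrative, not the project's own `get_data_loader`): a non-empty dataset can still yield an empty loader when no full batch can be formed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical setup: 3 samples, batch size 8, drop_last=True.
dataset = TensorDataset(torch.zeros(3, 10))
loader = DataLoader(dataset, batch_size=8, drop_last=True)

print(len(dataset))  # 3 -> the dataset itself is not empty
print(len(loader))   # 0 -> no full batch fits, so the loader is empty

# The committed assert turns this silent no-op into an immediate failure
# (this line intentionally raises AssertionError in the setup above):
assert len(loader) > 0, (
    " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. "
)
```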
```diff
@@ -1210,11 +1214,9 @@ def optimize(
         )

         # skip the rest if not outputs from the model
-        if not outputs:
-            if loss_dict:
-                raise RuntimeError(" [!] Model must return outputs when losses are computed.")
+        if not loss_dict:
             step_time = time.time() - step_start_time
-            return None, {}, step_time
+            return outputs, {}, step_time

         grad_clip = self._set_grad_clip_per_optimizer(config=config, optimizer_idx=optimizer_idx)
         # optimizer step
```
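The early return now keys on `loss_dict` rather than `outputs`: a step that computes no losses skips the optimizer logic but still hands back whatever the model returned, instead of discarding it. A minimal sketch of the new control flow, with illustrative names rather than the Trainer API:

```python
import time

def optimize_step(outputs, loss_dict, step_start_time):
    # No losses computed this step: skip clipping and the optimizer step,
    # but propagate the model outputs instead of returning None as before.
    if not loss_dict:
        step_time = time.time() - step_start_time
        return outputs, {}, step_time
    # ... gradient clipping and the optimizer step would follow here ...
```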
```diff
@@ -1758,9 +1760,9 @@ def _fit(self) -> None:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-                self.c_logger.print_epoch_end(self.epochs_done, self.keep_avg_eval.avg_values)
             if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
+
             self.c_logger.print_epoch_end(
                 epoch,
                 self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values,
```
```diff
@@ -1882,12 +1884,14 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""

-        # set the target loss to choose the best model
-        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)

         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            target_loss_dict,
+            train_loss if eval_loss is None else eval_loss,
             self.best_loss,
             self.config,
             self.model,
```
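Both `save_best_model` and `save_checkpoint` (next hunk) now follow the same rule: prefer the evaluation loss when eval averages actually exist, otherwise fall back to the training loss. Distilled into a standalone sketch, where the `KeepAverage`-like objects are assumed to expose their running averages as an `avg_values` dict:

```python
def pick_selection_loss(keep_avg_eval, keep_avg_train, pick):
    """Prefer the eval loss when available; otherwise fall back to train loss."""
    eval_loss = None
    if keep_avg_eval and len(keep_avg_eval.avg_values) > 0:
        eval_loss = pick(keep_avg_eval)  # eval ran and accumulated losses
    train_loss = pick(keep_avg_train)    # always available after a train epoch
    return train_loss if eval_loss is None else eval_loss
```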
```diff
@@ -1904,7 +1908,11 @@ def save_best_model(self) -> None:
     @rank_zero_only
     def save_checkpoint(self) -> None:
         """Save the current model checkpoint."""
-        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)
+
         save_checkpoint(
             self.config,
             self.model,
```
```diff
@@ -1913,7 +1921,7 @@ def save_checkpoint(self) -> None:
             self.total_steps_done,
             self.epochs_done,
             self.output_path,
-            model_loss=target_avg_loss,
+            model_loss={"train_loss": train_loss, "eval_loss": eval_loss},
             save_n_checkpoints=self.config.save_n_checkpoints,
             save_func=self.dashboard_logger.save_model,
         )
```
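Note the shape change: `model_loss` was a single scalar and is now a dict carrying both values, with `eval_loss` left as `None` whenever no eval loss was accumulated. A hypothetical consumer reading this field from both old and new checkpoints would need to tolerate either form, for example:

```python
# Illustrative values; eval_loss is None when evaluation produced no averages.
model_loss = {"train_loss": 0.413, "eval_loss": None}

# Hypothetical reader handling both the old scalar and the new dict form:
if isinstance(model_loss, dict):
    loss = model_loss["eval_loss"]
    if loss is None:
        loss = model_loss["train_loss"]
else:
    loss = model_loss
```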
```diff
@@ -2094,7 +2102,6 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:
     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
         """Pick the target loss to compare models"""
         target_avg_loss = None
-
         # return if target loss defined in the model config
         # if not available in Dict use loss_1 as by default loss
         if "target_loss" in self.config and self.config.target_loss:
```
```diff
@@ -2115,7 +2122,7 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
                 target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
             target_avg_loss /= len(self.optimizer)
         else:
-            target_avg_loss = keep_avg_target["avg_loss"]
+            target_avg_loss = keep_avg_target.avg_values.get("avg_loss", 0)
         return target_avg_loss

     def _setup_logger_config(self, log_file: str) -> None:
```
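The switch from `keep_avg_target["avg_loss"]` to `.avg_values.get("avg_loss", 0)` is the core "no eval loss" fix: an epoch that accumulates no eval batches leaves the averages dict empty, so direct indexing would raise a `KeyError`. In plain dict terms:

```python
avg_values = {}  # no eval loss was ever accumulated this epoch

# avg_values["avg_loss"]              # would raise KeyError: 'avg_loss'
loss = avg_values.get("avg_loss", 0)  # falls back to 0 instead
print(loss)  # 0
```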
