@@ -895,6 +895,10 @@ def _get_loader(
         loader = model.get_data_loader(
             config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
         )
+
+        assert (
+            len(loader) > 0
+        ), "❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0."
         return loader
 
     def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
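# A minimal sketch (not part of the diff) of the case the new assert guards
# against: with drop_last=True, a dataset smaller than the batch size yields
# zero batches, so len(loader) == 0 even though the dataset is non-empty.
import torch
from torch.utils.data import DataLoader, TensorDataset

tiny = TensorDataset(torch.zeros(3, 2))                  # only 3 samples
empty_loader = DataLoader(tiny, batch_size=8, drop_last=True)
print(len(empty_loader))  # 0 -> the assert above would fail fast here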
@@ -1210,11 +1214,9 @@ def optimize(
         )
 
         # skip the rest if not outputs from the model
-        if not outputs:
-            if loss_dict:
-                raise RuntimeError("[!] Model must return outputs when losses are computed.")
+        if not loss_dict:
             step_time = time.time() - step_start_time
-            return None, {}, step_time
+            return outputs, {}, step_time
 
         grad_clip = self._set_grad_clip_per_optimizer(config=config, optimizer_idx=optimizer_idx)
         # optimizer step
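# A minimal sketch (stand-in names, not part of the diff) of the revised
# early-return contract: when a step produces no losses, the model outputs are
# handed back with an empty loss dict instead of (None, {}), so callers can
# still log or inspect outputs from loss-free steps.
import time
from typing import Any, Dict, Tuple

def optimize_early_return(outputs: Any, loss_dict: Dict, step_start_time: float) -> Tuple[Any, Dict, float]:
    if not loss_dict:  # nothing to backpropagate this step
        step_time = time.time() - step_start_time
        return outputs, {}, step_time  # outputs preserved, losses empty
    raise NotImplementedError("grad clipping and the optimizer step follow here")

print(optimize_early_return({"logits": [0.1]}, {}, time.time()))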
@@ -1758,9 +1760,9 @@ def _fit(self) -> None:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-                self.c_logger.print_epoch_end(self.epochs_done, self.keep_avg_eval.avg_values)
             if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
+
             self.c_logger.print_epoch_end(
                 epoch,
                 self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values,
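# A minimal sketch (stand-in names): after the move, epoch-end logging runs
# exactly once per epoch, after the optional test run, and sources its stats
# the same way the new print_epoch_end call does.
from typing import Dict

def epoch_end_stats(run_eval: bool, eval_avgs: Dict, train_avgs: Dict) -> Dict:
    # mirrors: keep_avg_eval.avg_values if config.run_eval else keep_avg_train.avg_values
    return eval_avgs if run_eval else train_avgs

print(epoch_end_stats(True, {"avg_loss": 0.12}, {"avg_loss": 0.34}))   # {'avg_loss': 0.12}
print(epoch_end_stats(False, {"avg_loss": 0.12}, {"avg_loss": 0.34}))  # {'avg_loss': 0.34}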
@@ -1882,12 +1884,14 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller than the previous one."""
 
-        # set the target loss to choose the best model
-        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)
 
         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            target_loss_dict,
+            train_loss if eval_loss is None else eval_loss,
             self.best_loss,
             self.config,
             self.model,
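# A minimal sketch (stand-in names) of the selection rule introduced above:
# the eval loss wins whenever eval averages were collected this epoch,
# otherwise the train loss is compared against best_loss.
from typing import Optional

def best_model_target(train_loss: float, eval_loss: Optional[float]) -> float:
    # mirrors: train_loss if eval_loss is None else eval_loss
    return train_loss if eval_loss is None else eval_loss

assert best_model_target(0.5, None) == 0.5  # no eval epoch ran
assert best_model_target(0.5, 0.3) == 0.3   # eval epoch ran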
@@ -1904,7 +1908,11 @@ def save_best_model(self) -> None:
     @rank_zero_only
     def save_checkpoint(self) -> None:
         """Save the current model checkpoint."""
-        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)
+
         save_checkpoint(
             self.config,
             self.model,
@@ -1913,7 +1921,7 @@ def save_checkpoint(self) -> None:
             self.total_steps_done,
             self.epochs_done,
             self.output_path,
-            model_loss=target_avg_loss,
+            model_loss={"train_loss": train_loss, "eval_loss": eval_loss},
             save_n_checkpoints=self.config.save_n_checkpoints,
             save_func=self.dashboard_logger.save_model,
         )
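# A minimal sketch (illustrative values): model_loss is now a dict carrying
# both losses rather than a single float, and eval_loss stays None when no
# eval epoch has run, so tools reading checkpoint metadata should accept
# either shape across trainer versions.
meta = {"model_loss": {"train_loss": 0.41, "eval_loss": None}}
loss = meta["model_loss"]
print(loss["eval_loss"] if loss["eval_loss"] is not None else loss["train_loss"])  # 0.41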
@@ -2094,7 +2102,6 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:
     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
         """Pick the target loss to compare models"""
         target_avg_loss = None
-
         # return the target loss if it is defined in the model config;
         # if not available in the dict, use loss_1 as the default loss
         if "target_loss" in self.config and self.config.target_loss:
@@ -2115,7 +2122,7 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
                 target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
             target_avg_loss /= len(self.optimizer)
         else:
-            target_avg_loss = keep_avg_target["avg_loss"]
+            target_avg_loss = keep_avg_target.avg_values.get("avg_loss", 0)
         return target_avg_loss
 
     def _setup_logger_config(self, log_file: str) -> None:
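# A minimal sketch, assuming KeepAverage stores its running means in an
# avg_values dict: the switch from keep_avg_target["avg_loss"] to
# .avg_values.get("avg_loss", 0) makes the lookup non-raising.
avg_values = {}                        # e.g. nothing has been averaged yet
# avg_values["avg_loss"]               # would raise KeyError
print(avg_values.get("avg_loss", 0))   # 0 -> the new safe default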